error
This commit is contained in:
parent
44500afd3f
commit
5d370e1077
13 changed files with 529 additions and 519 deletions
152
api/validator.go
Normal file
152
api/validator.go
Normal file
|
@ -0,0 +1,152 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// Validate is a global validator instance
|
||||
var Validate = validator.New()
|
||||
|
||||
// RegisterCustomValidations registers custom validation functions
|
||||
func RegisterCustomValidations() {
|
||||
// Register custom validation for issuer
|
||||
Validate.RegisterValidation("issuer", validateIssuer)
|
||||
|
||||
// Register custom validation for XSS prevention
|
||||
Validate.RegisterValidation("no_xss", validateNoXSS)
|
||||
|
||||
// Register custom validation for OTP secret
|
||||
Validate.RegisterValidation("otpsecret", validateOTPSecret)
|
||||
}
|
||||
|
||||
// validateOTPSecret validates that the OTP secret is in valid base32 format
|
||||
func validateOTPSecret(fl validator.FieldLevel) bool {
|
||||
secret := fl.Field().String()
|
||||
|
||||
// Check if the secret is not empty
|
||||
if secret == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the secret is in base32 format (A-Z, 2-7)
|
||||
base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
|
||||
if !base32Regex.MatchString(secret) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the length is valid (must be at least 16 characters)
|
||||
if len(secret) < 16 || len(secret) > 128 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateIssuer validates that the issuer field contains only allowed characters
|
||||
func validateIssuer(fl validator.FieldLevel) bool {
|
||||
issuer := fl.Field().String()
|
||||
|
||||
// Empty issuer is valid (since it's optional)
|
||||
if issuer == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Allow alphanumeric characters, spaces, and common punctuation
|
||||
issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// Validate is a global validator instance
|
||||
var Validate = validator.New()
|
||||
|
||||
// RegisterCustomValidations registers custom validation functions
|
||||
func RegisterCustomValidations() {
|
||||
// Register custom validation for issuer
|
||||
Validate.RegisterValidation("issuer", validateIssuer)
|
||||
|
||||
// Register custom validation for XSS prevention
|
||||
Validate.RegisterValidation("no_xss", validateNoXSS)
|
||||
|
||||
// Register custom validation for OTP secret
|
||||
Validate.RegisterValidation("otpsecret", validateOTPSecret)
|
||||
}
|
||||
|
||||
// validateOTPSecret validates that the OTP secret is in valid base32 format
|
||||
func validateOTPSecret(fl validator.FieldLevel) bool {
|
||||
secret := fl.Field().String()
|
||||
|
||||
// Check if the secret is not empty
|
||||
if secret == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the secret is in base32 format (A-Z, 2-7)
|
||||
base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
|
||||
if !base32Regex.MatchString(secret) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the length is valid (must be at least 16 characters)
|
||||
if len(secret) < 16 || len(secret) > 128 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
)
|
||||
if !issuerRegex.MatchString(issuer) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(issuer) > 100 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateNoXSS validates that the field doesn't contain potential XSS payloads
|
||||
func validateNoXSS(fl validator.FieldLevel) bool {
|
||||
value := fl.Field().String()
|
||||
|
||||
// Check for HTML encoding
|
||||
if strings.Contains(value, "&#") ||
|
||||
strings.Contains(value, "<") ||
|
||||
strings.Contains(value, ">") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for common XSS patterns
|
||||
suspiciousPatterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
|
||||
regexp.MustCompile(`(?i)javascript:`),
|
||||
regexp.MustCompile(`(?i)data:text/html`),
|
||||
regexp.MustCompile(`(?i)on\w+\s*=`),
|
||||
regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
|
||||
regexp.MustCompile(`(?i)<\s*iframe`),
|
||||
regexp.MustCompile(`(?i)<\s*object`),
|
||||
regexp.MustCompile(`(?i)<\s*embed`),
|
||||
regexp.MustCompile(`(?i)<\s*style`),
|
||||
regexp.MustCompile(`(?i)<\s*form`),
|
||||
regexp.MustCompile(`(?i)<\s*applet`),
|
||||
regexp.MustCompile(`(?i)<\s*meta`),
|
||||
}
|
||||
|
||||
for _, pattern := range suspiciousPatterns {
|
||||
if pattern.MatchString(value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"otpm/api"
|
||||
"otpm/config"
|
||||
"otpm/database"
|
||||
"otpm/handlers"
|
||||
|
@ -90,6 +91,9 @@ func Execute() error {
|
|||
authService := services.NewAuthService(cfg, userRepo)
|
||||
otpService := services.NewOTPService(otpRepo)
|
||||
|
||||
// Register custom validations
|
||||
api.RegisterCustomValidations()
|
||||
|
||||
// Initialize handlers
|
||||
authHandler := handlers.NewAuthHandler(authService)
|
||||
otpHandler := handlers.NewOTPHandler(otpService)
|
||||
|
|
|
@ -85,9 +85,9 @@ func setDefaults() {
|
|||
|
||||
// Database defaults
|
||||
viper.SetDefault("database.driver", "sqlite3")
|
||||
viper.SetDefault("database.max_open_conns", 25)
|
||||
viper.SetDefault("database.max_idle_conns", 25)
|
||||
viper.SetDefault("database.max_lifetime", "5m")
|
||||
viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection
|
||||
viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection
|
||||
viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling
|
||||
viper.SetDefault("database.skip_migration", false)
|
||||
|
||||
// JWT defaults
|
||||
|
|
|
@ -25,11 +25,20 @@ func New(cfg *config.DatabaseConfig) (*DB, error) {
|
|||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool with optimized settings
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
|
||||
db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn
|
||||
db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes
|
||||
// Configure connection pool based on database type
|
||||
if cfg.Driver == "sqlite3" {
|
||||
// SQLite is a file-based database - simpler connection settings
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxLifetime(0) // Connections don't need to be recycled
|
||||
db.SetConnMaxIdleTime(0)
|
||||
} else {
|
||||
// For other databases (MySQL, PostgreSQL etc.)
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
|
||||
db.SetConnMaxLifetime(30 * time.Minute)
|
||||
db.SetConnMaxIdleTime(5 * time.Minute)
|
||||
}
|
||||
|
||||
// Verify connection with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
@ -45,9 +54,16 @@ func New(cfg *config.DatabaseConfig) (*DB, error) {
|
|||
|
||||
// WithTx executes a function within a transaction with retry logic
|
||||
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
|
||||
const maxRetries = 3
|
||||
var maxRetries int
|
||||
var lastErr error
|
||||
|
||||
// Adjust retry settings based on database type
|
||||
if db.DriverName() == "sqlite3" {
|
||||
maxRetries = 5 // SQLite needs more retries due to busy timeouts
|
||||
} else {
|
||||
maxRetries = 3
|
||||
}
|
||||
|
||||
// Default transaction options
|
||||
opts := &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
|
|
|
@ -1,6 +1,26 @@
|
|||
CREATE TABLE IF NOT EXISTS otp (
|
||||
id SERIAL PRIMARY KEY,
|
||||
openid VARCHAR(255) UNIQUE NOT NULL,
|
||||
token VARCHAR(255),
|
||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
openid VARCHAR(255) NOT NULL,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
issuer VARCHAR(255),
|
||||
secret VARCHAR(255) NOT NULL,
|
||||
algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1',
|
||||
digits INTEGER NOT NULL DEFAULT 6,
|
||||
period INTEGER NOT NULL DEFAULT 30,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(user_id, name),
|
||||
UNIQUE(openid)
|
||||
);
|
||||
|
||||
-- Add index for faster lookups
|
||||
CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid);
|
||||
|
||||
-- Trigger to update the updated_at timestamp
|
||||
CREATE TRIGGER IF NOT EXISTS update_otp_timestamp
|
||||
AFTER UPDATE ON otp
|
||||
BEGIN
|
||||
UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END;
|
|
@ -1,5 +1,5 @@
|
|||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
openid VARCHAR(255) UNIQUE NOT NULL,
|
||||
session_key VARCHAR(255) UNIQUE NOT NULL
|
||||
);
|
||||
|
|
27
go.mod
27
go.mod
|
@ -5,42 +5,32 @@ go 1.23.0
|
|||
toolchain go1.23.9
|
||||
|
||||
require (
|
||||
github.com/go-sql-driver/mysql v1.8.1
|
||||
github.com/go-playground/validator/v10 v10.26.0
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/julienschmidt/httprouter v1.3.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/prometheus/client_golang v1.22.0
|
||||
github.com/spf13/viper v1.19.0
|
||||
modernc.org/sqlite v1.32.0
|
||||
golang.org/x/crypto v0.38.0
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.26.0 // indirect
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/prometheus/client_golang v1.22.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
|
@ -50,7 +40,6 @@ require (
|
|||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/crypto v0.38.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
|
||||
golang.org/x/net v0.34.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
|
@ -58,10 +47,4 @@ require (
|
|||
google.golang.org/protobuf v1.36.5 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
|
||||
modernc.org/libc v1.55.3 // indirect
|
||||
modernc.org/mathutil v1.6.0 // indirect
|
||||
modernc.org/memory v1.8.0 // indirect
|
||||
modernc.org/strutil v1.2.0 // indirect
|
||||
modernc.org/token v1.1.0 // indirect
|
||||
)
|
||||
|
|
74
go.sum
74
go.sum
|
@ -4,19 +4,18 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
|||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
|
@ -27,43 +26,36 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv
|
|||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
|
||||
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
|
||||
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
|
||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
|
@ -77,11 +69,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
|||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
|
||||
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
||||
|
@ -92,8 +81,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
|||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
|
||||
|
@ -106,8 +93,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
|||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
|
@ -118,55 +106,19 @@ golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
|||
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
||||
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
|
||||
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
|
||||
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
|
||||
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
|
||||
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
|
||||
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ=
|
||||
modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ=
|
||||
modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y=
|
||||
modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s=
|
||||
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
|
||||
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
|
||||
modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw=
|
||||
modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU=
|
||||
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI=
|
||||
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4=
|
||||
modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U=
|
||||
modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w=
|
||||
modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
|
||||
modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo=
|
||||
modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E=
|
||||
modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU=
|
||||
modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
|
||||
modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
|
||||
modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc=
|
||||
modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss=
|
||||
modernc.org/sqlite v1.32.0 h1:6BM4uGza7bWypsw4fdLRsLxut6bHe4c58VeqjRgST8s=
|
||||
modernc.org/sqlite v1.32.0/go.mod h1:UqoylwmTb9F+IqXERT8bW9zzOWN8qwAIcLdzeBZs4hA=
|
||||
modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA=
|
||||
modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"otpm/services"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication related requests
|
||||
|
@ -27,7 +28,7 @@ func NewAuthHandler(authService *services.AuthService) *AuthHandler {
|
|||
|
||||
// LoginRequest represents a login request
|
||||
type LoginRequest struct {
|
||||
Code string `json:"code"`
|
||||
Code string `json:"code" validate:"required,min=32,max=128"`
|
||||
}
|
||||
|
||||
// LoginResponse represents a login response
|
||||
|
@ -36,14 +37,19 @@ type LoginResponse struct {
|
|||
OpenID string `json:"openid"`
|
||||
}
|
||||
|
||||
// TokenRequest represents a token verification request
|
||||
type TokenRequest struct {
|
||||
Token string `validate:"required,min=32"`
|
||||
}
|
||||
|
||||
// Login handles WeChat login
|
||||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
start := time.Now()
|
||||
|
||||
// Limit request body size to prevent DOS
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
|
||||
|
||||
// Parse request
|
||||
// Parse and validate request
|
||||
var req LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
|
@ -52,11 +58,11 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if req.Code == "" {
|
||||
// Validate using validator
|
||||
if err := api.Validate.Struct(req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Code is required")
|
||||
log.Printf("Login request validation failed: empty code")
|
||||
fmt.Sprintf("Invalid request parameters: %v", err))
|
||||
log.Printf("Login request validation failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -79,7 +85,7 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// VerifyToken handles token verification
|
||||
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
start := time.Now()
|
||||
|
||||
// Get token from Authorization header
|
||||
|
@ -100,10 +106,13 @@ func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
token := authHeader[7:]
|
||||
if len(token) < 32 { // Basic length check
|
||||
|
||||
// Validate token using validator
|
||||
tokenReq := TokenRequest{Token: token}
|
||||
if err := api.Validate.Struct(tokenReq); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Invalid token length")
|
||||
log.Printf("Token verification failed: token too short")
|
||||
"Invalid token format")
|
||||
log.Printf("Token verification failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -139,9 +148,9 @@ func maskToken(token string) string {
|
|||
}
|
||||
|
||||
// Routes returns all routes for the auth handler
|
||||
func (h *AuthHandler) Routes() map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/login": h.Login,
|
||||
"/verify-token": h.VerifyToken,
|
||||
func (h *AuthHandler) Routes() map[string]httprouter.Handle {
|
||||
return map[string]httprouter.Handle{
|
||||
"/api/login": h.Login,
|
||||
"/api/verify-token": h.VerifyToken,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,11 +2,9 @@ package handlers
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"otpm/api"
|
||||
"otpm/middleware"
|
||||
|
@ -14,7 +12,7 @@ import (
|
|||
"otpm/services"
|
||||
)
|
||||
|
||||
// OTPHandler handles OTP related requests
|
||||
// OTPHandler handles OTP-related HTTP requests
|
||||
type OTPHandler struct {
|
||||
otpService *services.OTPService
|
||||
}
|
||||
|
@ -26,90 +24,53 @@ func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
|
|||
}
|
||||
}
|
||||
|
||||
// CreateOTPRequest represents a request to create an OTP
|
||||
type CreateOTPRequest struct {
|
||||
Name string `json:"name"`
|
||||
Issuer string `json:"issuer"`
|
||||
Secret string `json:"secret"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Digits int `json:"digits"`
|
||||
Period int `json:"period"`
|
||||
// Routes returns the routes for OTP operations
|
||||
func (h *OTPHandler) Routes() map[string]httprouter.Handle {
|
||||
return map[string]httprouter.Handle{
|
||||
"POST /api/otp": h.CreateOTP,
|
||||
"GET /api/otps": h.ListOTPs,
|
||||
"GET /api/otp/:id": h.GetOTP,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateOTP handles OTP creation
|
||||
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Limit request body size
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 10*1024) // 10KB max for OTP creation
|
||||
|
||||
// CreateOTP handles the creation of a new OTP
|
||||
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
|
||||
if !ok {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
log.Printf("CreateOTP unauthorized attempt")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request
|
||||
var req CreateOTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
fmt.Sprintf("Invalid request body: %v", err))
|
||||
log.Printf("CreateOTP request parse error for user %s: %v", userID, err)
|
||||
// Parse request body
|
||||
var params models.OTPParams
|
||||
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body"))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate OTP parameters
|
||||
if req.Secret == "" {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Secret is required")
|
||||
log.Printf("CreateOTP validation failed for user %s: empty secret", userID)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate algorithm
|
||||
supportedAlgos := map[string]bool{
|
||||
"SHA1": true,
|
||||
"SHA256": true,
|
||||
"SHA512": true,
|
||||
}
|
||||
if !supportedAlgos[strings.ToUpper(req.Algorithm)] {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Unsupported algorithm. Supported: SHA1, SHA256, SHA512")
|
||||
log.Printf("CreateOTP validation failed for user %s: unsupported algorithm %s",
|
||||
userID, req.Algorithm)
|
||||
// Validate request
|
||||
if err := api.Validate.Struct(params); err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Create OTP
|
||||
otp, err := h.otpService.CreateOTP(r.Context(), userID, models.OTPParams{
|
||||
Name: req.Name,
|
||||
Issuer: req.Issuer,
|
||||
Secret: req.Secret,
|
||||
Algorithm: req.Algorithm,
|
||||
Digits: req.Digits,
|
||||
Period: req.Period,
|
||||
})
|
||||
|
||||
otp, err := h.otpService.CreateOTP(r.Context(), userID, params)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
|
||||
log.Printf("CreateOTP failed for user %s: %v", userID, err)
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful creation (mask secret in logs)
|
||||
log.Printf("OTP created for user %s (took %v): name=%s issuer=%s algo=%s digits=%d period=%d",
|
||||
userID, time.Since(start), req.Name, req.Issuer, req.Algorithm, req.Digits, req.Period)
|
||||
|
||||
// Return response
|
||||
api.NewResponseWriter(w).WriteSuccess(otp)
|
||||
}
|
||||
|
||||
// ListOTPs handles listing all OTPs for a user
|
||||
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
|
||||
if !ok {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
@ -121,166 +82,33 @@ func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Return response
|
||||
api.NewResponseWriter(w).WriteSuccess(otps)
|
||||
}
|
||||
|
||||
// GetOTPCode handles generating OTP code
|
||||
func (h *OTPHandler) GetOTPCode(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// GetOTP handles getting a specific OTP
|
||||
func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
log.Printf("GetOTPCode unauthorized attempt from IP %s", r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
otpID = strings.TrimSuffix(otpID, "/code")
|
||||
|
||||
// Validate OTP ID format
|
||||
if len(otpID) != 36 { // Assuming UUID format
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
|
||||
"Invalid OTP ID format")
|
||||
log.Printf("GetOTPCode invalid OTP ID format: %s (user %s)", otpID, userID)
|
||||
return
|
||||
}
|
||||
|
||||
// Rate limiting check could be added here
|
||||
// (would require redis or similar rate limiter)
|
||||
|
||||
// Generate code
|
||||
code, expiresIn, err := h.otpService.GenerateCode(r.Context(), otpID, userID)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
log.Printf("GetOTPCode failed for user %s OTP %s: %v", userID, otpID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Log successful generation (without actual code)
|
||||
log.Printf("OTP code generated for user %s OTP %s (took %v, expires in %ds)",
|
||||
userID, otpID, time.Since(start), expiresIn)
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]interface{}{
|
||||
"code": code,
|
||||
"expires_in": expiresIn,
|
||||
})
|
||||
}
|
||||
|
||||
// VerifyOTPRequest represents a request to verify an OTP code
|
||||
type VerifyOTPRequest struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
// VerifyOTP handles OTP code verification
|
||||
func (h *OTPHandler) VerifyOTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
|
||||
if !ok {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
otpID = strings.TrimSuffix(otpID, "/verify")
|
||||
|
||||
// Parse request
|
||||
var req VerifyOTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
|
||||
otpID := ps.ByName("id")
|
||||
if otpID == "" {
|
||||
api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID"))
|
||||
return
|
||||
}
|
||||
|
||||
// Verify code
|
||||
valid, err := h.otpService.VerifyCode(r.Context(), otpID, userID, req.Code)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
|
||||
"valid": valid,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateOTPRequest represents a request to update an OTP
|
||||
type UpdateOTPRequest struct {
|
||||
Name string `json:"name"`
|
||||
Issuer string `json:"issuer"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Digits int `json:"digits"`
|
||||
Period int `json:"period"`
|
||||
}
|
||||
|
||||
// UpdateOTP handles OTP update
|
||||
func (h *OTPHandler) UpdateOTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
|
||||
// Parse request
|
||||
var req UpdateOTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Update OTP
|
||||
otp, err := h.otpService.UpdateOTP(r.Context(), otpID, userID, models.OTPParams{
|
||||
Name: req.Name,
|
||||
Issuer: req.Issuer,
|
||||
Algorithm: req.Algorithm,
|
||||
Digits: req.Digits,
|
||||
Period: req.Period,
|
||||
})
|
||||
|
||||
// Get OTP
|
||||
otp, err := h.otpService.GetOTP(r.Context(), otpID, userID)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Return response
|
||||
api.NewResponseWriter(w).WriteSuccess(otp)
|
||||
}
|
||||
|
||||
// DeleteOTP handles OTP deletion
|
||||
func (h *OTPHandler) DeleteOTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context
|
||||
userID, err := middleware.GetUserID(r)
|
||||
if err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Get OTP ID from URL
|
||||
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
|
||||
|
||||
// Delete OTP
|
||||
if err := h.otpService.DeleteOTP(r.Context(), otpID, userID); err != nil {
|
||||
api.NewResponseWriter(w).WriteError(api.InternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
api.NewResponseWriter(w).WriteSuccess(map[string]string{
|
||||
"message": "OTP deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// Routes returns all routes for the OTP handler
|
||||
func (h *OTPHandler) Routes() map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/otp": h.CreateOTP,
|
||||
"/otp/": h.ListOTPs,
|
||||
"/otp/{id}": h.UpdateOTP,
|
||||
"/otp/{id}/code": h.GetOTPCode,
|
||||
"/otp/{id}/verify": h.VerifyOTP,
|
||||
}
|
||||
}
|
||||
|
|
191
models/otp.go
191
models/otp.go
|
@ -2,194 +2,65 @@ package models
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// OTP represents a TOTP configuration
|
||||
type OTP struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
UserID string `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Issuer string `db:"issuer" json:"issuer"`
|
||||
Secret string `db:"secret" json:"-"` // Never expose secret in JSON
|
||||
Algorithm string `db:"algorithm" json:"algorithm"`
|
||||
Digits int `db:"digits" json:"digits"`
|
||||
Period int `db:"period" json:"period"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ID int64 `json:"id" db:"id"`
|
||||
UserID string `json:"user_id" db:"user_id" validate:"required"`
|
||||
OpenID string `json:"openid" db:"openid" validate:"required"`
|
||||
Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"`
|
||||
Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"`
|
||||
Secret string `json:"secret" db:"secret" validate:"required,otpsecret"`
|
||||
Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
|
||||
Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"`
|
||||
Period int `json:"period" db:"period" validate:"required,min=30,max=60"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
}
|
||||
|
||||
// OTPParams represents common OTP parameters used in creation and update
|
||||
type OTPParams struct {
|
||||
Name string
|
||||
Issuer string
|
||||
Secret string
|
||||
Algorithm string
|
||||
Digits int
|
||||
Period int
|
||||
Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
|
||||
Issuer string `json:"issuer" validate:"omitempty,issuer"`
|
||||
Secret string `json:"secret" validate:"required,otpsecret"`
|
||||
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
|
||||
Digits int `json:"digits" validate:"omitempty,min=6,max=8"`
|
||||
Period int `json:"period" validate:"omitempty,min=30,max=60"`
|
||||
}
|
||||
|
||||
// OTPRepository handles OTP data operations
|
||||
// OTPRepository handles OTP data storage
|
||||
type OTPRepository struct {
|
||||
db *sqlx.DB
|
||||
// Add your database connection or ORM here
|
||||
}
|
||||
|
||||
// NewOTPRepository creates a new OTPRepository
|
||||
func NewOTPRepository(db *sqlx.DB) *OTPRepository {
|
||||
return &OTPRepository{db: db}
|
||||
// Create creates a new OTP record
|
||||
func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
|
||||
// Implement database creation logic
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByID finds an OTP by ID and user ID
|
||||
func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
|
||||
var otp OTP
|
||||
query := `SELECT * FROM otps WHERE id = ? AND user_id = ?`
|
||||
err := r.db.GetContext(ctx, &otp, query, id, userID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("otp not found: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find otp: %w", err)
|
||||
}
|
||||
return &otp, nil
|
||||
// Implement database lookup logic
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// FindAllByUserID finds all OTPs for a user
|
||||
func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
|
||||
var otps []*OTP
|
||||
query := `SELECT * FROM otps WHERE user_id = ? ORDER BY created_at DESC`
|
||||
err := r.db.SelectContext(ctx, &otps, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find otps: %w", err)
|
||||
}
|
||||
return otps, nil
|
||||
// Implement database query logic
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create creates a new OTP
|
||||
func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
|
||||
query := `
|
||||
INSERT INTO otps (id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
now := time.Now()
|
||||
otp.CreatedAt = now
|
||||
otp.UpdatedAt = now
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
otp.ID,
|
||||
otp.UserID,
|
||||
otp.Name,
|
||||
otp.Issuer,
|
||||
otp.Secret,
|
||||
otp.Algorithm,
|
||||
otp.Digits,
|
||||
otp.Period,
|
||||
otp.CreatedAt,
|
||||
otp.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create otp: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update updates an existing OTP
|
||||
// Update updates an existing OTP record
|
||||
func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
|
||||
query := `
|
||||
UPDATE otps
|
||||
SET name = ?, issuer = ?, algorithm = ?, digits = ?, period = ?, updated_at = ?
|
||||
WHERE id = ? AND user_id = ?
|
||||
`
|
||||
otp.UpdatedAt = time.Now()
|
||||
|
||||
result, err := r.db.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
otp.Name,
|
||||
otp.Issuer,
|
||||
otp.Algorithm,
|
||||
otp.Digits,
|
||||
otp.Period,
|
||||
otp.UpdatedAt,
|
||||
otp.ID,
|
||||
otp.UserID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update otp: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("otp not found or not owned by user")
|
||||
}
|
||||
|
||||
// Implement database update logic
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an OTP
|
||||
// Delete deletes an OTP record
|
||||
func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
|
||||
query := `DELETE FROM otps WHERE id = ? AND user_id = ?`
|
||||
result, err := r.db.ExecContext(ctx, query, id, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete otp: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("otp not found or not owned by user")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CountByUserID counts the number of OTPs for a user
|
||||
func (r *OTPRepository) CountByUserID(ctx context.Context, userID string) (int, error) {
|
||||
var count int
|
||||
query := `SELECT COUNT(*) FROM otps WHERE user_id = ?`
|
||||
err := r.db.GetContext(ctx, &count, query, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count otps: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Transaction executes a function within a transaction
|
||||
func (r *OTPRepository) Transaction(ctx context.Context, fn func(*sqlx.Tx) error) error {
|
||||
tx, err := r.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
tx.Rollback()
|
||||
panic(p)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
return fmt.Errorf("tx failed: %v, rollback failed: %v", err, rbErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// Implement database deletion logic
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -13,18 +13,20 @@ import (
|
|||
|
||||
"otpm/config"
|
||||
"otpm/middleware"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
// Server represents the HTTP server
|
||||
type Server struct {
|
||||
server *http.Server
|
||||
router *http.ServeMux
|
||||
router *httprouter.Router
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// New creates a new server
|
||||
func New(cfg *config.Config) *Server {
|
||||
router := http.NewServeMux()
|
||||
router := httprouter.New()
|
||||
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
|
||||
|
@ -111,29 +113,46 @@ func (s *Server) Shutdown() error {
|
|||
}
|
||||
|
||||
// Router returns the router
|
||||
func (s *Server) Router() *http.ServeMux {
|
||||
func (s *Server) Router() *httprouter.Router {
|
||||
return s.router
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all routes
|
||||
func (s *Server) RegisterRoutes(routes map[string]http.Handler) {
|
||||
func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) {
|
||||
for pattern, handler := range routes {
|
||||
s.router.Handle(pattern, handler)
|
||||
s.router.Handle("GET", pattern, handler)
|
||||
s.router.Handle("POST", pattern, handler)
|
||||
s.router.Handle("PUT", pattern, handler)
|
||||
s.router.Handle("DELETE", pattern, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterAuthRoutes registers routes that require authentication
|
||||
func (s *Server) RegisterAuthRoutes(routes map[string]http.Handler) {
|
||||
func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) {
|
||||
for pattern, handler := range routes {
|
||||
// Apply authentication middleware
|
||||
authHandler := middleware.Auth(s.config.JWT.Secret)(handler)
|
||||
s.router.Handle(pattern, authHandler)
|
||||
authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
// Convert httprouter.Handle to http.HandlerFunc for middleware
|
||||
wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Store params in request context
|
||||
ctx := context.WithValue(r.Context(), "params", ps)
|
||||
handler(w, r.WithContext(ctx), ps)
|
||||
})
|
||||
|
||||
// Apply auth middleware
|
||||
middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
s.router.Handle("GET", pattern, authHandler)
|
||||
s.router.Handle("POST", pattern, authHandler)
|
||||
s.router.Handle("PUT", pattern, authHandler)
|
||||
s.router.Handle("DELETE", pattern, authHandler)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHealthCheck registers an enhanced health check endpoint
|
||||
func (s *Server) RegisterHealthCheck() {
|
||||
s.router.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
response := map[string]interface{}{
|
||||
"status": "ok",
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
|
@ -145,9 +164,8 @@ func (s *Server) RegisterHealthCheck() {
|
|||
}
|
||||
|
||||
// Add database status if configured
|
||||
if s.config.Database.DSN != "" { // Changed from URL to DSN to match config
|
||||
if s.config.Database.DSN != "" {
|
||||
dbStatus := "ok"
|
||||
// Removed DB ping check since we don't have DB instance in config
|
||||
response["database"] = dbStatus
|
||||
}
|
||||
|
||||
|
|
|
@ -13,11 +13,54 @@ import (
|
|||
|
||||
var (
|
||||
validate *validator.Validate
|
||||
|
||||
// 自定义验证规则
|
||||
customValidations = map[string]validator.Func{
|
||||
"otpsecret": validateOTPSecret,
|
||||
"password": validatePassword,
|
||||
"otpsecret": validateOTPSecret,
|
||||
"password": validatePassword,
|
||||
"issuer": validateIssuer,
|
||||
"otpauth_uri": validateOTPAuthURI,
|
||||
"no_xss": validateNoXSS,
|
||||
}
|
||||
|
||||
// 常见的弱密码列表(实际使用时应该使用更完整的列表)
|
||||
commonPasswords = map[string]bool{
|
||||
"password123": true,
|
||||
"12345678": true,
|
||||
"qwerty123": true,
|
||||
"admin123": true,
|
||||
"letmein": true,
|
||||
"welcome": true,
|
||||
"password": true,
|
||||
"admin": true,
|
||||
}
|
||||
|
||||
// 预编译的XSS检测正则表达式
|
||||
xssPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
|
||||
regexp.MustCompile(`(?i)javascript:`),
|
||||
regexp.MustCompile(`(?i)data:text/html`),
|
||||
regexp.MustCompile(`(?i)on\w+\s*=`),
|
||||
regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
|
||||
regexp.MustCompile(`(?i)<\s*iframe`),
|
||||
regexp.MustCompile(`(?i)<\s*object`),
|
||||
regexp.MustCompile(`(?i)<\s*embed`),
|
||||
regexp.MustCompile(`(?i)<\s*style`),
|
||||
regexp.MustCompile(`(?i)<\s*form`),
|
||||
regexp.MustCompile(`(?i)<\s*applet`),
|
||||
regexp.MustCompile(`(?i)<\s*meta`),
|
||||
regexp.MustCompile(`(?i)expression\s*\(`),
|
||||
regexp.MustCompile(`(?i)url\s*\(`),
|
||||
}
|
||||
|
||||
// 预编译的正则表达式
|
||||
base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)
|
||||
issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`)
|
||||
otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`)
|
||||
upperRegex = regexp.MustCompile(`[A-Z]`)
|
||||
lowerRegex = regexp.MustCompile(`[a-z]`)
|
||||
numberRegex = regexp.MustCompile(`[0-9]`)
|
||||
specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -83,19 +126,37 @@ func NewValidationError(errors validator.ValidationErrors) *ValidationError {
|
|||
func getErrorMessage(err validator.FieldError) string {
|
||||
switch err.Tag() {
|
||||
case "required":
|
||||
return "This field is required"
|
||||
return "此字段为必填项"
|
||||
case "email":
|
||||
return "Invalid email address"
|
||||
return "请输入有效的电子邮件地址"
|
||||
case "min":
|
||||
return fmt.Sprintf("Must be at least %s characters long", err.Param())
|
||||
if err.Type().Kind() == reflect.String {
|
||||
return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param())
|
||||
}
|
||||
return fmt.Sprintf("必须大于或等于 %s", err.Param())
|
||||
case "max":
|
||||
return fmt.Sprintf("Must be at most %s characters long", err.Param())
|
||||
if err.Type().Kind() == reflect.String {
|
||||
return fmt.Sprintf("长度不能超过 %s 个字符", err.Param())
|
||||
}
|
||||
return fmt.Sprintf("必须小于或等于 %s", err.Param())
|
||||
case "len":
|
||||
return fmt.Sprintf("长度必须为 %s 个字符", err.Param())
|
||||
case "oneof":
|
||||
return fmt.Sprintf("必须是以下值之一: %s", err.Param())
|
||||
case "otpsecret":
|
||||
return "Invalid OTP secret format"
|
||||
return "OTP密钥格式无效,必须是有效的Base32编码"
|
||||
case "password":
|
||||
return "Password must be at least 8 characters long and contain at least one uppercase letter, one lowercase letter, one number, and one special character"
|
||||
return "密码必须至少10个字符,并包含大写字母、小写字母,以及数字或特殊字符"
|
||||
case "issuer":
|
||||
return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号"
|
||||
case "otpauth_uri":
|
||||
return "OTP认证URI格式无效"
|
||||
case "no_xss":
|
||||
return "输入包含潜在的不安全内容"
|
||||
case "numeric":
|
||||
return "必须是数字"
|
||||
default:
|
||||
return fmt.Sprintf("Failed validation on tag: %s", err.Tag())
|
||||
return fmt.Sprintf("验证失败: %s", err.Tag())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,40 +165,136 @@ func getErrorMessage(err validator.FieldError) string {
|
|||
// validateOTPSecret validates an OTP secret
|
||||
func validateOTPSecret(fl validator.FieldLevel) bool {
|
||||
secret := fl.Field().String()
|
||||
|
||||
if secret == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// OTP secret should be base32 encoded
|
||||
matched, _ := regexp.MatchString(`^[A-Z2-7]+=*$`, secret)
|
||||
return matched
|
||||
if !base32Regex.MatchString(secret) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check length (typical OTP secrets are 16-64 characters)
|
||||
validLength := len(secret) >= 16 && len(secret) <= 128
|
||||
|
||||
return validLength
|
||||
}
|
||||
|
||||
// validatePassword validates a password
|
||||
func validatePassword(fl validator.FieldLevel) bool {
|
||||
password := fl.Field().String()
|
||||
// At least 8 characters long
|
||||
if len(password) < 8 {
|
||||
|
||||
// At least 10 characters long
|
||||
if len(password) < 10 {
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
hasUpper = regexp.MustCompile(`[A-Z]`).MatchString(password)
|
||||
hasLower = regexp.MustCompile(`[a-z]`).MatchString(password)
|
||||
hasNumber = regexp.MustCompile(`[0-9]`).MatchString(password)
|
||||
hasSpecial = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password)
|
||||
)
|
||||
// Check if it's a common password
|
||||
if commonPasswords[strings.ToLower(password)] {
|
||||
return false
|
||||
}
|
||||
|
||||
return hasUpper && hasLower && hasNumber && hasSpecial
|
||||
// Check character types
|
||||
hasUpper := upperRegex.MatchString(password)
|
||||
hasLower := lowerRegex.MatchString(password)
|
||||
hasNumber := numberRegex.MatchString(password)
|
||||
hasSpecial := specialRegex.MatchString(password)
|
||||
|
||||
// Ensure password has enough complexity
|
||||
complexity := 0
|
||||
if hasUpper {
|
||||
complexity++
|
||||
}
|
||||
if hasLower {
|
||||
complexity++
|
||||
}
|
||||
if hasNumber {
|
||||
complexity++
|
||||
}
|
||||
if hasSpecial {
|
||||
complexity++
|
||||
}
|
||||
|
||||
return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial)
|
||||
}
|
||||
|
||||
// validateIssuer validates an issuer name
|
||||
func validateIssuer(fl validator.FieldLevel) bool {
|
||||
issuer := fl.Field().String()
|
||||
|
||||
if issuer == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Issuer should not contain special characters that could cause problems in URLs
|
||||
if !issuerRegex.MatchString(issuer) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check length
|
||||
validLength := len(issuer) >= 1 && len(issuer) <= 100
|
||||
|
||||
return validLength
|
||||
}
|
||||
|
||||
// validateOTPAuthURI validates an otpauth URI
|
||||
func validateOTPAuthURI(fl validator.FieldLevel) bool {
|
||||
uri := fl.Field().String()
|
||||
|
||||
if uri == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Basic format check for otpauth URI
|
||||
// Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD
|
||||
return otpauthRegex.MatchString(uri)
|
||||
}
|
||||
|
||||
// validateNoXSS checks if a string contains potential XSS payloads
|
||||
func validateNoXSS(fl validator.FieldLevel) bool {
|
||||
value := fl.Field().String()
|
||||
|
||||
// 检查基本的HTML编码
|
||||
if strings.Contains(value, "&#") ||
|
||||
strings.Contains(value, "<") ||
|
||||
strings.Contains(value, ">") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查十六进制编码
|
||||
if strings.Contains(strings.ToLower(value), "\\x3c") || // <
|
||||
strings.Contains(strings.ToLower(value), "\\x3e") { // >
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查Unicode编码
|
||||
if strings.Contains(strings.ToLower(value), "\\u003c") || // <
|
||||
strings.Contains(strings.ToLower(value), "\\u003e") { // >
|
||||
return false
|
||||
}
|
||||
|
||||
// 使用预编译的正则表达式检查XSS模式
|
||||
for _, pattern := range xssPatterns {
|
||||
if pattern.MatchString(value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Request validation structs
|
||||
|
||||
// LoginRequest represents a login request
|
||||
type LoginRequest struct {
|
||||
Code string `json:"code" validate:"required"`
|
||||
Code string `json:"code" validate:"required,len=6|len=8,numeric"`
|
||||
}
|
||||
|
||||
// CreateOTPRequest represents a request to create an OTP
|
||||
type CreateOTPRequest struct {
|
||||
Name string `json:"name" validate:"required,min=1,max=100"`
|
||||
Issuer string `json:"issuer" validate:"required,min=1,max=100"`
|
||||
Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
|
||||
Issuer string `json:"issuer" validate:"required,issuer,no_xss"`
|
||||
Secret string `json:"secret" validate:"required,otpsecret"`
|
||||
Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
|
||||
Digits int `json:"digits" validate:"required,oneof=6 8"`
|
||||
|
@ -146,8 +303,8 @@ type CreateOTPRequest struct {
|
|||
|
||||
// UpdateOTPRequest represents a request to update an OTP
|
||||
type UpdateOTPRequest struct {
|
||||
Name string `json:"name" validate:"omitempty,min=1,max=100"`
|
||||
Issuer string `json:"issuer" validate:"omitempty,min=1,max=100"`
|
||||
Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"`
|
||||
Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"`
|
||||
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
|
||||
Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
|
||||
Period int `json:"period" validate:"omitempty,oneof=30 60"`
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue