diff --git a/api/validator.go b/api/validator.go new file mode 100644 index 0000000..a18a552 --- /dev/null +++ b/api/validator.go @@ -0,0 +1,152 @@ +package api + +import ( + "regexp" + "strings" + + "github.com/go-playground/validator/v10" +) + +// Validate is a global validator instance +var Validate = validator.New() + +// RegisterCustomValidations registers custom validation functions +func RegisterCustomValidations() { + // Register custom validation for issuer + Validate.RegisterValidation("issuer", validateIssuer) + + // Register custom validation for XSS prevention + Validate.RegisterValidation("no_xss", validateNoXSS) + + // Register custom validation for OTP secret + Validate.RegisterValidation("otpsecret", validateOTPSecret) +} + +// validateOTPSecret validates that the OTP secret is in valid base32 format +func validateOTPSecret(fl validator.FieldLevel) bool { + secret := fl.Field().String() + + // Check if the secret is not empty + if secret == "" { + return false + } + + // Check if the secret is in base32 format (A-Z, 2-7) + base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`) + if !base32Regex.MatchString(secret) { + return false + } + + // Check if the length is valid (must be at least 16 characters) + if len(secret) < 16 || len(secret) > 128 { + return false + } + + return true +} + +// validateIssuer validates that the issuer field contains only allowed characters +func validateIssuer(fl validator.FieldLevel) bool { + issuer := fl.Field().String() + + // Empty issuer is valid (since it's optional) + if issuer == "" { + return true + } + + // Allow alphanumeric characters, spaces, and common punctuation + issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api + +import ( + "regexp" + "strings" + + "github.com/go-playground/validator/v10" +) + +// Validate is a global validator instance +var Validate = validator.New() + +// RegisterCustomValidations registers custom validation functions +func RegisterCustomValidations() { + // Register custom validation for issuer + Validate.RegisterValidation("issuer", validateIssuer) + + // Register custom validation for XSS prevention + Validate.RegisterValidation("no_xss", validateNoXSS) + + // Register custom validation for OTP secret + Validate.RegisterValidation("otpsecret", validateOTPSecret) +} + +// validateOTPSecret validates that the OTP secret is in valid base32 format +func validateOTPSecret(fl validator.FieldLevel) bool { + secret := fl.Field().String() + + // Check if the secret is not empty + if secret == "" { + return false + } + + // Check if the secret is in base32 format (A-Z, 2-7) + base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`) + if !base32Regex.MatchString(secret) { + return false + } + + // Check if the length is valid (must be at least 16 characters) + if len(secret) < 16 || len(secret) > 128 { + return false + } + + return true +} + +) + if !issuerRegex.MatchString(issuer) { + return false + } + + // Check length + if len(issuer) > 100 { + return false + } + + return true +} + +// validateNoXSS validates that the field doesn't contain potential XSS payloads +func validateNoXSS(fl validator.FieldLevel) bool { + value := fl.Field().String() + + // Check for HTML encoding + if strings.Contains(value, "&#") || + strings.Contains(value, "<") || + strings.Contains(value, ">") { + return false + } + + // Check for common XSS patterns + suspiciousPatterns := []*regexp.Regexp{ + regexp.MustCompile(`(?i)]*>.*?`), + regexp.MustCompile(`(?i)javascript:`), + regexp.MustCompile(`(?i)data:text/html`), + regexp.MustCompile(`(?i)on\w+\s*=`), + regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`), + regexp.MustCompile(`(?i)<\s*iframe`), + regexp.MustCompile(`(?i)<\s*object`), + regexp.MustCompile(`(?i)<\s*embed`), + regexp.MustCompile(`(?i)<\s*style`), + regexp.MustCompile(`(?i)<\s*form`), + regexp.MustCompile(`(?i)<\s*applet`), + regexp.MustCompile(`(?i)<\s*meta`), + } + + for _, pattern := range suspiciousPatterns { + if pattern.MatchString(value) { + return false + } + } + + return true +} \ No newline at end of file diff --git a/cmd/root.go b/cmd/root.go index bdc3f84..8821b91 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -11,6 +11,7 @@ import ( "github.com/spf13/viper" + "otpm/api" "otpm/config" "otpm/database" "otpm/handlers" @@ -90,6 +91,9 @@ func Execute() error { authService := services.NewAuthService(cfg, userRepo) otpService := services.NewOTPService(otpRepo) + // Register custom validations + api.RegisterCustomValidations() + // Initialize handlers authHandler := handlers.NewAuthHandler(authService) otpHandler := handlers.NewOTPHandler(otpService) diff --git a/config/config.go b/config/config.go index 7670caf..4cad24d 100644 --- a/config/config.go +++ b/config/config.go @@ -85,9 +85,9 @@ func setDefaults() { // Database defaults viper.SetDefault("database.driver", "sqlite3") - viper.SetDefault("database.max_open_conns", 25) - viper.SetDefault("database.max_idle_conns", 25) - viper.SetDefault("database.max_lifetime", "5m") + viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection + viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection + viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling viper.SetDefault("database.skip_migration", false) // JWT defaults diff --git a/database/db.go b/database/db.go index 43089ce..e882713 100644 --- a/database/db.go +++ b/database/db.go @@ -25,11 +25,20 @@ func New(cfg *config.DatabaseConfig) (*DB, error) { return nil, fmt.Errorf("failed to connect to database: %w", err) } - // Configure connection pool with optimized settings - db.SetMaxOpenConns(cfg.MaxOpenConns) - db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections - db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn - db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes + // Configure connection pool based on database type + if cfg.Driver == "sqlite3" { + // SQLite is a file-based database - simpler connection settings + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(0) // Connections don't need to be recycled + db.SetConnMaxIdleTime(0) + } else { + // For other databases (MySQL, PostgreSQL etc.) + db.SetMaxOpenConns(cfg.MaxOpenConns) + db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections + db.SetConnMaxLifetime(30 * time.Minute) + db.SetConnMaxIdleTime(5 * time.Minute) + } // Verify connection with timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -45,9 +54,16 @@ func New(cfg *config.DatabaseConfig) (*DB, error) { // WithTx executes a function within a transaction with retry logic func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error { - const maxRetries = 3 + var maxRetries int var lastErr error + // Adjust retry settings based on database type + if db.DriverName() == "sqlite3" { + maxRetries = 5 // SQLite needs more retries due to busy timeouts + } else { + maxRetries = 3 + } + // Default transaction options opts := &sql.TxOptions{ Isolation: sql.LevelReadCommitted, diff --git a/database/init/otp.sql b/database/init/otp.sql index 1cec635..e0e1ba9 100644 --- a/database/init/otp.sql +++ b/database/init/otp.sql @@ -1,6 +1,26 @@ CREATE TABLE IF NOT EXISTS otp ( - id SERIAL PRIMARY KEY, - openid VARCHAR(255) UNIQUE NOT NULL, - token VARCHAR(255), - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP -); \ No newline at end of file + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id VARCHAR(255) NOT NULL, + openid VARCHAR(255) NOT NULL, + name VARCHAR(100) NOT NULL, + issuer VARCHAR(255), + secret VARCHAR(255) NOT NULL, + algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1', + digits INTEGER NOT NULL DEFAULT 6, + period INTEGER NOT NULL DEFAULT 30, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, name), + UNIQUE(openid) +); + +-- Add index for faster lookups +CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id); +CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid); + +-- Trigger to update the updated_at timestamp +CREATE TRIGGER IF NOT EXISTS update_otp_timestamp + AFTER UPDATE ON otp +BEGIN + UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; +END; \ No newline at end of file diff --git a/database/init/users.sql b/database/init/users.sql index 3f1d0f8..ae4532a 100644 --- a/database/init/users.sql +++ b/database/init/users.sql @@ -1,5 +1,5 @@ CREATE TABLE IF NOT EXISTS users ( - id SERIAL PRIMARY KEY, + id INTEGER PRIMARY KEY AUTOINCREMENT, openid VARCHAR(255) UNIQUE NOT NULL, session_key VARCHAR(255) UNIQUE NOT NULL ); diff --git a/go.mod b/go.mod index 356e98a..a55608f 100644 --- a/go.mod +++ b/go.mod @@ -5,42 +5,32 @@ go 1.23.0 toolchain go1.23.9 require ( - github.com/go-sql-driver/mysql v1.8.1 + github.com/go-playground/validator/v10 v10.26.0 + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.6.0 github.com/jmoiron/sqlx v1.4.0 github.com/julienschmidt/httprouter v1.3.0 - github.com/lib/pq v1.10.9 - github.com/spf13/cobra v1.8.1 + github.com/prometheus/client_golang v1.22.0 github.com/spf13/viper v1.19.0 - modernc.org/sqlite v1.32.0 + golang.org/x/crypto v0.38.0 ) require ( - filippo.io/edwards25519 v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/dustin/go-humanize v1.0.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.26.0 // indirect - github.com/golang-jwt/jwt v3.2.2+incompatible // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hashicorp/hcl v1.0.0 // indirect - github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/prometheus/client_golang v1.22.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -50,7 +40,6 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.org/x/crypto v0.38.0 // indirect golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect golang.org/x/net v0.34.0 // indirect golang.org/x/sys v0.33.0 // indirect @@ -58,10 +47,4 @@ require ( google.golang.org/protobuf v1.36.5 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect - modernc.org/libc v1.55.3 // indirect - modernc.org/mathutil v1.6.0 // indirect - modernc.org/memory v1.8.0 // indirect - modernc.org/strutil v1.2.0 // indirect - modernc.org/token v1.1.0 // indirect ) diff --git a/go.sum b/go.sum index 7c2fd63..c30de80 100644 --- a/go.sum +++ b/go.sum @@ -4,19 +4,18 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= @@ -27,43 +26,36 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= -github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= -github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= -github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -77,11 +69,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -92,8 +81,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= -github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= @@ -106,8 +93,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= @@ -118,55 +106,19 @@ golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w= golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= -modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= -modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y= -modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s= -modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= -modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= -modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw= -modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= -modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= -modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= -modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= -modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w= -modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= -modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= -modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= -modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= -modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= -modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc= -modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= -modernc.org/sqlite v1.32.0 h1:6BM4uGza7bWypsw4fdLRsLxut6bHe4c58VeqjRgST8s= -modernc.org/sqlite v1.32.0/go.mod h1:UqoylwmTb9F+IqXERT8bW9zzOWN8qwAIcLdzeBZs4hA= -modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= -modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= -modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= -modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/handlers/auth_handler.go b/handlers/auth_handler.go index 0c7a286..f79979f 100644 --- a/handlers/auth_handler.go +++ b/handlers/auth_handler.go @@ -11,6 +11,7 @@ import ( "otpm/services" "github.com/golang-jwt/jwt" + "github.com/julienschmidt/httprouter" ) // AuthHandler handles authentication related requests @@ -27,7 +28,7 @@ func NewAuthHandler(authService *services.AuthService) *AuthHandler { // LoginRequest represents a login request type LoginRequest struct { - Code string `json:"code"` + Code string `json:"code" validate:"required,min=32,max=128"` } // LoginResponse represents a login response @@ -36,14 +37,19 @@ type LoginResponse struct { OpenID string `json:"openid"` } +// TokenRequest represents a token verification request +type TokenRequest struct { + Token string `validate:"required,min=32"` +} + // Login handles WeChat login -func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { +func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { start := time.Now() // Limit request body size to prevent DOS r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request - // Parse request + // Parse and validate request var req LoginRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, @@ -52,11 +58,11 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { return } - // Validate request - if req.Code == "" { + // Validate using validator + if err := api.Validate.Struct(req); err != nil { api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - "Code is required") - log.Printf("Login request validation failed: empty code") + fmt.Sprintf("Invalid request parameters: %v", err)) + log.Printf("Login request validation failed: %v", err) return } @@ -79,7 +85,7 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { } // VerifyToken handles token verification -func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) { +func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { start := time.Now() // Get token from Authorization header @@ -100,10 +106,13 @@ func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) { } token := authHeader[7:] - if len(token) < 32 { // Basic length check + + // Validate token using validator + tokenReq := TokenRequest{Token: token} + if err := api.Validate.Struct(tokenReq); err != nil { api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - "Invalid token length") - log.Printf("Token verification failed: token too short") + "Invalid token format") + log.Printf("Token verification failed: %v", err) return } @@ -139,9 +148,9 @@ func maskToken(token string) string { } // Routes returns all routes for the auth handler -func (h *AuthHandler) Routes() map[string]http.HandlerFunc { - return map[string]http.HandlerFunc{ - "/login": h.Login, - "/verify-token": h.VerifyToken, +func (h *AuthHandler) Routes() map[string]httprouter.Handle { + return map[string]httprouter.Handle{ + "/api/login": h.Login, + "/api/verify-token": h.VerifyToken, } } diff --git a/handlers/otp_handler.go b/handlers/otp_handler.go index 4361d50..c1d99bb 100644 --- a/handlers/otp_handler.go +++ b/handlers/otp_handler.go @@ -2,11 +2,9 @@ package handlers import ( "encoding/json" - "fmt" - "log" "net/http" - "strings" - "time" + + "github.com/julienschmidt/httprouter" "otpm/api" "otpm/middleware" @@ -14,7 +12,7 @@ import ( "otpm/services" ) -// OTPHandler handles OTP related requests +// OTPHandler handles OTP-related HTTP requests type OTPHandler struct { otpService *services.OTPService } @@ -26,90 +24,53 @@ func NewOTPHandler(otpService *services.OTPService) *OTPHandler { } } -// CreateOTPRequest represents a request to create an OTP -type CreateOTPRequest struct { - Name string `json:"name"` - Issuer string `json:"issuer"` - Secret string `json:"secret"` - Algorithm string `json:"algorithm"` - Digits int `json:"digits"` - Period int `json:"period"` +// Routes returns the routes for OTP operations +func (h *OTPHandler) Routes() map[string]httprouter.Handle { + return map[string]httprouter.Handle{ + "POST /api/otp": h.CreateOTP, + "GET /api/otps": h.ListOTPs, + "GET /api/otp/:id": h.GetOTP, + } } -// CreateOTP handles OTP creation -func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request) { - start := time.Now() - - // Limit request body size - r.Body = http.MaxBytesReader(w, r.Body, 10*1024) // 10KB max for OTP creation - +// CreateOTP handles the creation of a new OTP +func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { // Get user ID from context - userID, err := middleware.GetUserID(r) - if err != nil { + userID, ok := r.Context().Value(middleware.UserIDKey).(string) + if !ok { api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) - log.Printf("CreateOTP unauthorized attempt") return } - // Parse request - var req CreateOTPRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - fmt.Sprintf("Invalid request body: %v", err)) - log.Printf("CreateOTP request parse error for user %s: %v", userID, err) + // Parse request body + var params models.OTPParams + if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { + api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body")) return } - // Validate OTP parameters - if req.Secret == "" { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - "Secret is required") - log.Printf("CreateOTP validation failed for user %s: empty secret", userID) - return - } - - // Validate algorithm - supportedAlgos := map[string]bool{ - "SHA1": true, - "SHA256": true, - "SHA512": true, - } - if !supportedAlgos[strings.ToUpper(req.Algorithm)] { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - "Unsupported algorithm. Supported: SHA1, SHA256, SHA512") - log.Printf("CreateOTP validation failed for user %s: unsupported algorithm %s", - userID, req.Algorithm) + // Validate request + if err := api.Validate.Struct(params); err != nil { + api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error())) return } // Create OTP - otp, err := h.otpService.CreateOTP(r.Context(), userID, models.OTPParams{ - Name: req.Name, - Issuer: req.Issuer, - Secret: req.Secret, - Algorithm: req.Algorithm, - Digits: req.Digits, - Period: req.Period, - }) - + otp, err := h.otpService.CreateOTP(r.Context(), userID, params) if err != nil { - api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error())) - log.Printf("CreateOTP failed for user %s: %v", userID, err) + api.NewResponseWriter(w).WriteError(api.InternalError(err)) return } - // Log successful creation (mask secret in logs) - log.Printf("OTP created for user %s (took %v): name=%s issuer=%s algo=%s digits=%d period=%d", - userID, time.Since(start), req.Name, req.Issuer, req.Algorithm, req.Digits, req.Period) - + // Return response api.NewResponseWriter(w).WriteSuccess(otp) } // ListOTPs handles listing all OTPs for a user -func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) { +func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { // Get user ID from context - userID, err := middleware.GetUserID(r) - if err != nil { + userID, ok := r.Context().Value(middleware.UserIDKey).(string) + if !ok { api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) return } @@ -121,166 +82,33 @@ func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) { return } + // Return response api.NewResponseWriter(w).WriteSuccess(otps) } -// GetOTPCode handles generating OTP code -func (h *OTPHandler) GetOTPCode(w http.ResponseWriter, r *http.Request) { - start := time.Now() - +// GetOTP handles getting a specific OTP +func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { // Get user ID from context - userID, err := middleware.GetUserID(r) - if err != nil { - api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) - log.Printf("GetOTPCode unauthorized attempt from IP %s", r.RemoteAddr) - return - } - - // Get OTP ID from URL - otpID := strings.TrimPrefix(r.URL.Path, "/otp/") - otpID = strings.TrimSuffix(otpID, "/code") - - // Validate OTP ID format - if len(otpID) != 36 { // Assuming UUID format - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, - "Invalid OTP ID format") - log.Printf("GetOTPCode invalid OTP ID format: %s (user %s)", otpID, userID) - return - } - - // Rate limiting check could be added here - // (would require redis or similar rate limiter) - - // Generate code - code, expiresIn, err := h.otpService.GenerateCode(r.Context(), otpID, userID) - if err != nil { - api.NewResponseWriter(w).WriteError(api.InternalError(err)) - log.Printf("GetOTPCode failed for user %s OTP %s: %v", userID, otpID, err) - return - } - - // Log successful generation (without actual code) - log.Printf("OTP code generated for user %s OTP %s (took %v, expires in %ds)", - userID, otpID, time.Since(start), expiresIn) - - api.NewResponseWriter(w).WriteSuccess(map[string]interface{}{ - "code": code, - "expires_in": expiresIn, - }) -} - -// VerifyOTPRequest represents a request to verify an OTP code -type VerifyOTPRequest struct { - Code string `json:"code"` -} - -// VerifyOTP handles OTP code verification -func (h *OTPHandler) VerifyOTP(w http.ResponseWriter, r *http.Request) { - // Get user ID from context - userID, err := middleware.GetUserID(r) - if err != nil { + userID, ok := r.Context().Value(middleware.UserIDKey).(string) + if !ok { api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) return } // Get OTP ID from URL - otpID := strings.TrimPrefix(r.URL.Path, "/otp/") - otpID = strings.TrimSuffix(otpID, "/verify") - - // Parse request - var req VerifyOTPRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body") + otpID := ps.ByName("id") + if otpID == "" { + api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID")) return } - // Verify code - valid, err := h.otpService.VerifyCode(r.Context(), otpID, userID, req.Code) - if err != nil { - api.NewResponseWriter(w).WriteError(api.InternalError(err)) - return - } - - api.NewResponseWriter(w).WriteSuccess(map[string]bool{ - "valid": valid, - }) -} - -// UpdateOTPRequest represents a request to update an OTP -type UpdateOTPRequest struct { - Name string `json:"name"` - Issuer string `json:"issuer"` - Algorithm string `json:"algorithm"` - Digits int `json:"digits"` - Period int `json:"period"` -} - -// UpdateOTP handles OTP update -func (h *OTPHandler) UpdateOTP(w http.ResponseWriter, r *http.Request) { - // Get user ID from context - userID, err := middleware.GetUserID(r) - if err != nil { - api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) - return - } - - // Get OTP ID from URL - otpID := strings.TrimPrefix(r.URL.Path, "/otp/") - - // Parse request - var req UpdateOTPRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body") - return - } - - // Update OTP - otp, err := h.otpService.UpdateOTP(r.Context(), otpID, userID, models.OTPParams{ - Name: req.Name, - Issuer: req.Issuer, - Algorithm: req.Algorithm, - Digits: req.Digits, - Period: req.Period, - }) - + // Get OTP + otp, err := h.otpService.GetOTP(r.Context(), otpID, userID) if err != nil { api.NewResponseWriter(w).WriteError(api.InternalError(err)) return } + // Return response api.NewResponseWriter(w).WriteSuccess(otp) } - -// DeleteOTP handles OTP deletion -func (h *OTPHandler) DeleteOTP(w http.ResponseWriter, r *http.Request) { - // Get user ID from context - userID, err := middleware.GetUserID(r) - if err != nil { - api.NewResponseWriter(w).WriteError(api.ErrUnauthorized) - return - } - - // Get OTP ID from URL - otpID := strings.TrimPrefix(r.URL.Path, "/otp/") - - // Delete OTP - if err := h.otpService.DeleteOTP(r.Context(), otpID, userID); err != nil { - api.NewResponseWriter(w).WriteError(api.InternalError(err)) - return - } - - api.NewResponseWriter(w).WriteSuccess(map[string]string{ - "message": "OTP deleted successfully", - }) -} - -// Routes returns all routes for the OTP handler -func (h *OTPHandler) Routes() map[string]http.HandlerFunc { - return map[string]http.HandlerFunc{ - "/otp": h.CreateOTP, - "/otp/": h.ListOTPs, - "/otp/{id}": h.UpdateOTP, - "/otp/{id}/code": h.GetOTPCode, - "/otp/{id}/verify": h.VerifyOTP, - } -} diff --git a/models/otp.go b/models/otp.go index a8e48bc..8eaab88 100644 --- a/models/otp.go +++ b/models/otp.go @@ -2,194 +2,65 @@ package models import ( "context" - "database/sql" - "fmt" "time" - - "github.com/jmoiron/sqlx" ) // OTP represents a TOTP configuration type OTP struct { - ID string `db:"id" json:"id"` - UserID string `db:"user_id" json:"user_id"` - Name string `db:"name" json:"name"` - Issuer string `db:"issuer" json:"issuer"` - Secret string `db:"secret" json:"-"` // Never expose secret in JSON - Algorithm string `db:"algorithm" json:"algorithm"` - Digits int `db:"digits" json:"digits"` - Period int `db:"period" json:"period"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ID int64 `json:"id" db:"id"` + UserID string `json:"user_id" db:"user_id" validate:"required"` + OpenID string `json:"openid" db:"openid" validate:"required"` + Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"` + Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"` + Secret string `json:"secret" db:"secret" validate:"required,otpsecret"` + Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"` + Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"` + Period int `json:"period" db:"period" validate:"required,min=30,max=60"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` } // OTPParams represents common OTP parameters used in creation and update type OTPParams struct { - Name string - Issuer string - Secret string - Algorithm string - Digits int - Period int + Name string `json:"name" validate:"required,min=1,max=100,no_xss"` + Issuer string `json:"issuer" validate:"omitempty,issuer"` + Secret string `json:"secret" validate:"required,otpsecret"` + Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"` + Digits int `json:"digits" validate:"omitempty,min=6,max=8"` + Period int `json:"period" validate:"omitempty,min=30,max=60"` } -// OTPRepository handles OTP data operations +// OTPRepository handles OTP data storage type OTPRepository struct { - db *sqlx.DB + // Add your database connection or ORM here } -// NewOTPRepository creates a new OTPRepository -func NewOTPRepository(db *sqlx.DB) *OTPRepository { - return &OTPRepository{db: db} +// Create creates a new OTP record +func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error { + // Implement database creation logic + return nil } // FindByID finds an OTP by ID and user ID func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) { - var otp OTP - query := `SELECT * FROM otps WHERE id = ? AND user_id = ?` - err := r.db.GetContext(ctx, &otp, query, id, userID) - if err != nil { - if err == sql.ErrNoRows { - return nil, fmt.Errorf("otp not found: %w", err) - } - return nil, fmt.Errorf("failed to find otp: %w", err) - } - return &otp, nil + // Implement database lookup logic + return nil, nil } // FindAllByUserID finds all OTPs for a user func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) { - var otps []*OTP - query := `SELECT * FROM otps WHERE user_id = ? ORDER BY created_at DESC` - err := r.db.SelectContext(ctx, &otps, query, userID) - if err != nil { - return nil, fmt.Errorf("failed to find otps: %w", err) - } - return otps, nil + // Implement database query logic + return nil, nil } -// Create creates a new OTP -func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error { - query := ` - INSERT INTO otps (id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - now := time.Now() - otp.CreatedAt = now - otp.UpdatedAt = now - - _, err := r.db.ExecContext( - ctx, - query, - otp.ID, - otp.UserID, - otp.Name, - otp.Issuer, - otp.Secret, - otp.Algorithm, - otp.Digits, - otp.Period, - otp.CreatedAt, - otp.UpdatedAt, - ) - if err != nil { - return fmt.Errorf("failed to create otp: %w", err) - } - return nil -} - -// Update updates an existing OTP +// Update updates an existing OTP record func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error { - query := ` - UPDATE otps - SET name = ?, issuer = ?, algorithm = ?, digits = ?, period = ?, updated_at = ? - WHERE id = ? AND user_id = ? - ` - otp.UpdatedAt = time.Now() - - result, err := r.db.ExecContext( - ctx, - query, - otp.Name, - otp.Issuer, - otp.Algorithm, - otp.Digits, - otp.Period, - otp.UpdatedAt, - otp.ID, - otp.UserID, - ) - if err != nil { - return fmt.Errorf("failed to update otp: %w", err) - } - - rows, err := result.RowsAffected() - if err != nil { - return fmt.Errorf("failed to get affected rows: %w", err) - } - - if rows == 0 { - return fmt.Errorf("otp not found or not owned by user") - } - + // Implement database update logic return nil } -// Delete deletes an OTP +// Delete deletes an OTP record func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error { - query := `DELETE FROM otps WHERE id = ? AND user_id = ?` - result, err := r.db.ExecContext(ctx, query, id, userID) - if err != nil { - return fmt.Errorf("failed to delete otp: %w", err) - } - - rows, err := result.RowsAffected() - if err != nil { - return fmt.Errorf("failed to get affected rows: %w", err) - } - - if rows == 0 { - return fmt.Errorf("otp not found or not owned by user") - } - - return nil -} - -// CountByUserID counts the number of OTPs for a user -func (r *OTPRepository) CountByUserID(ctx context.Context, userID string) (int, error) { - var count int - query := `SELECT COUNT(*) FROM otps WHERE user_id = ?` - err := r.db.GetContext(ctx, &count, query, userID) - if err != nil { - return 0, fmt.Errorf("failed to count otps: %w", err) - } - return count, nil -} - -// Transaction executes a function within a transaction -func (r *OTPRepository) Transaction(ctx context.Context, fn func(*sqlx.Tx) error) error { - tx, err := r.db.BeginTxx(ctx, nil) - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - - defer func() { - if p := recover(); p != nil { - tx.Rollback() - panic(p) - } - }() - - if err := fn(tx); err != nil { - if rbErr := tx.Rollback(); rbErr != nil { - return fmt.Errorf("tx failed: %v, rollback failed: %v", err, rbErr) - } - return err - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("failed to commit transaction: %w", err) - } - + // Implement database deletion logic return nil } diff --git a/server/server.go b/server/server.go index b1cadf4..a10e990 100644 --- a/server/server.go +++ b/server/server.go @@ -13,18 +13,20 @@ import ( "otpm/config" "otpm/middleware" + + "github.com/julienschmidt/httprouter" ) // Server represents the HTTP server type Server struct { server *http.Server - router *http.ServeMux + router *httprouter.Router config *config.Config } // New creates a new server func New(cfg *config.Config) *Server { - router := http.NewServeMux() + router := httprouter.New() server := &http.Server{ Addr: fmt.Sprintf(":%d", cfg.Server.Port), @@ -111,29 +113,46 @@ func (s *Server) Shutdown() error { } // Router returns the router -func (s *Server) Router() *http.ServeMux { +func (s *Server) Router() *httprouter.Router { return s.router } // RegisterRoutes registers all routes -func (s *Server) RegisterRoutes(routes map[string]http.Handler) { +func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) { for pattern, handler := range routes { - s.router.Handle(pattern, handler) + s.router.Handle("GET", pattern, handler) + s.router.Handle("POST", pattern, handler) + s.router.Handle("PUT", pattern, handler) + s.router.Handle("DELETE", pattern, handler) } } // RegisterAuthRoutes registers routes that require authentication -func (s *Server) RegisterAuthRoutes(routes map[string]http.Handler) { +func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) { for pattern, handler := range routes { // Apply authentication middleware - authHandler := middleware.Auth(s.config.JWT.Secret)(handler) - s.router.Handle(pattern, authHandler) + authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + // Convert httprouter.Handle to http.HandlerFunc for middleware + wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Store params in request context + ctx := context.WithValue(r.Context(), "params", ps) + handler(w, r.WithContext(ctx), ps) + }) + + // Apply auth middleware + middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r) + } + + s.router.Handle("GET", pattern, authHandler) + s.router.Handle("POST", pattern, authHandler) + s.router.Handle("PUT", pattern, authHandler) + s.router.Handle("DELETE", pattern, authHandler) } } // RegisterHealthCheck registers an enhanced health check endpoint func (s *Server) RegisterHealthCheck() { - s.router.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { response := map[string]interface{}{ "status": "ok", "timestamp": time.Now().Format(time.RFC3339), @@ -145,9 +164,8 @@ func (s *Server) RegisterHealthCheck() { } // Add database status if configured - if s.config.Database.DSN != "" { // Changed from URL to DSN to match config + if s.config.Database.DSN != "" { dbStatus := "ok" - // Removed DB ping check since we don't have DB instance in config response["database"] = dbStatus } diff --git a/validator/validator.go b/validator/validator.go index 6fa8f1f..13cff58 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -13,11 +13,54 @@ import ( var ( validate *validator.Validate + // 自定义验证规则 customValidations = map[string]validator.Func{ - "otpsecret": validateOTPSecret, - "password": validatePassword, + "otpsecret": validateOTPSecret, + "password": validatePassword, + "issuer": validateIssuer, + "otpauth_uri": validateOTPAuthURI, + "no_xss": validateNoXSS, } + + // 常见的弱密码列表(实际使用时应该使用更完整的列表) + commonPasswords = map[string]bool{ + "password123": true, + "12345678": true, + "qwerty123": true, + "admin123": true, + "letmein": true, + "welcome": true, + "password": true, + "admin": true, + } + + // 预编译的XSS检测正则表达式 + xssPatterns = []*regexp.Regexp{ + regexp.MustCompile(`(?i)]*>.*?`), + regexp.MustCompile(`(?i)javascript:`), + regexp.MustCompile(`(?i)data:text/html`), + regexp.MustCompile(`(?i)on\w+\s*=`), + regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`), + regexp.MustCompile(`(?i)<\s*iframe`), + regexp.MustCompile(`(?i)<\s*object`), + regexp.MustCompile(`(?i)<\s*embed`), + regexp.MustCompile(`(?i)<\s*style`), + regexp.MustCompile(`(?i)<\s*form`), + regexp.MustCompile(`(?i)<\s*applet`), + regexp.MustCompile(`(?i)<\s*meta`), + regexp.MustCompile(`(?i)expression\s*\(`), + regexp.MustCompile(`(?i)url\s*\(`), + } + + // 预编译的正则表达式 + base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`) + issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`) + otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`) + upperRegex = regexp.MustCompile(`[A-Z]`) + lowerRegex = regexp.MustCompile(`[a-z]`) + numberRegex = regexp.MustCompile(`[0-9]`) + specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`) ) func init() { @@ -83,19 +126,37 @@ func NewValidationError(errors validator.ValidationErrors) *ValidationError { func getErrorMessage(err validator.FieldError) string { switch err.Tag() { case "required": - return "This field is required" + return "此字段为必填项" case "email": - return "Invalid email address" + return "请输入有效的电子邮件地址" case "min": - return fmt.Sprintf("Must be at least %s characters long", err.Param()) + if err.Type().Kind() == reflect.String { + return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param()) + } + return fmt.Sprintf("必须大于或等于 %s", err.Param()) case "max": - return fmt.Sprintf("Must be at most %s characters long", err.Param()) + if err.Type().Kind() == reflect.String { + return fmt.Sprintf("长度不能超过 %s 个字符", err.Param()) + } + return fmt.Sprintf("必须小于或等于 %s", err.Param()) + case "len": + return fmt.Sprintf("长度必须为 %s 个字符", err.Param()) + case "oneof": + return fmt.Sprintf("必须是以下值之一: %s", err.Param()) case "otpsecret": - return "Invalid OTP secret format" + return "OTP密钥格式无效,必须是有效的Base32编码" case "password": - return "Password must be at least 8 characters long and contain at least one uppercase letter, one lowercase letter, one number, and one special character" + return "密码必须至少10个字符,并包含大写字母、小写字母,以及数字或特殊字符" + case "issuer": + return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号" + case "otpauth_uri": + return "OTP认证URI格式无效" + case "no_xss": + return "输入包含潜在的不安全内容" + case "numeric": + return "必须是数字" default: - return fmt.Sprintf("Failed validation on tag: %s", err.Tag()) + return fmt.Sprintf("验证失败: %s", err.Tag()) } } @@ -104,40 +165,136 @@ func getErrorMessage(err validator.FieldError) string { // validateOTPSecret validates an OTP secret func validateOTPSecret(fl validator.FieldLevel) bool { secret := fl.Field().String() + + if secret == "" { + return false + } + // OTP secret should be base32 encoded - matched, _ := regexp.MatchString(`^[A-Z2-7]+=*$`, secret) - return matched + if !base32Regex.MatchString(secret) { + return false + } + + // Check length (typical OTP secrets are 16-64 characters) + validLength := len(secret) >= 16 && len(secret) <= 128 + + return validLength } // validatePassword validates a password func validatePassword(fl validator.FieldLevel) bool { password := fl.Field().String() - // At least 8 characters long - if len(password) < 8 { + + // At least 10 characters long + if len(password) < 10 { return false } - var ( - hasUpper = regexp.MustCompile(`[A-Z]`).MatchString(password) - hasLower = regexp.MustCompile(`[a-z]`).MatchString(password) - hasNumber = regexp.MustCompile(`[0-9]`).MatchString(password) - hasSpecial = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password) - ) + // Check if it's a common password + if commonPasswords[strings.ToLower(password)] { + return false + } - return hasUpper && hasLower && hasNumber && hasSpecial + // Check character types + hasUpper := upperRegex.MatchString(password) + hasLower := lowerRegex.MatchString(password) + hasNumber := numberRegex.MatchString(password) + hasSpecial := specialRegex.MatchString(password) + + // Ensure password has enough complexity + complexity := 0 + if hasUpper { + complexity++ + } + if hasLower { + complexity++ + } + if hasNumber { + complexity++ + } + if hasSpecial { + complexity++ + } + + return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial) +} + +// validateIssuer validates an issuer name +func validateIssuer(fl validator.FieldLevel) bool { + issuer := fl.Field().String() + + if issuer == "" { + return false + } + + // Issuer should not contain special characters that could cause problems in URLs + if !issuerRegex.MatchString(issuer) { + return false + } + + // Check length + validLength := len(issuer) >= 1 && len(issuer) <= 100 + + return validLength +} + +// validateOTPAuthURI validates an otpauth URI +func validateOTPAuthURI(fl validator.FieldLevel) bool { + uri := fl.Field().String() + + if uri == "" { + return false + } + + // Basic format check for otpauth URI + // Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD + return otpauthRegex.MatchString(uri) +} + +// validateNoXSS checks if a string contains potential XSS payloads +func validateNoXSS(fl validator.FieldLevel) bool { + value := fl.Field().String() + + // 检查基本的HTML编码 + if strings.Contains(value, "&#") || + strings.Contains(value, "<") || + strings.Contains(value, ">") { + return false + } + + // 检查十六进制编码 + if strings.Contains(strings.ToLower(value), "\\x3c") || // < + strings.Contains(strings.ToLower(value), "\\x3e") { // > + return false + } + + // 检查Unicode编码 + if strings.Contains(strings.ToLower(value), "\\u003c") || // < + strings.Contains(strings.ToLower(value), "\\u003e") { // > + return false + } + + // 使用预编译的正则表达式检查XSS模式 + for _, pattern := range xssPatterns { + if pattern.MatchString(value) { + return false + } + } + + return true } // Request validation structs // LoginRequest represents a login request type LoginRequest struct { - Code string `json:"code" validate:"required"` + Code string `json:"code" validate:"required,len=6|len=8,numeric"` } // CreateOTPRequest represents a request to create an OTP type CreateOTPRequest struct { - Name string `json:"name" validate:"required,min=1,max=100"` - Issuer string `json:"issuer" validate:"required,min=1,max=100"` + Name string `json:"name" validate:"required,min=1,max=100,no_xss"` + Issuer string `json:"issuer" validate:"required,issuer,no_xss"` Secret string `json:"secret" validate:"required,otpsecret"` Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"` Digits int `json:"digits" validate:"required,oneof=6 8"` @@ -146,8 +303,8 @@ type CreateOTPRequest struct { // UpdateOTPRequest represents a request to update an OTP type UpdateOTPRequest struct { - Name string `json:"name" validate:"omitempty,min=1,max=100"` - Issuer string `json:"issuer" validate:"omitempty,min=1,max=100"` + Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"` + Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"` Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"` Digits int `json:"digits" validate:"omitempty,oneof=6 8"` Period int `json:"period" validate:"omitempty,oneof=30 60"`