diff --git a/api/validator.go b/api/validator.go
new file mode 100644
index 0000000..a18a552
--- /dev/null
+++ b/api/validator.go
@@ -0,0 +1,152 @@
+package api
+
+import (
+ "regexp"
+ "strings"
+
+ "github.com/go-playground/validator/v10"
+)
+
+// Validate is a global validator instance
+var Validate = validator.New()
+
+// RegisterCustomValidations registers custom validation functions
+func RegisterCustomValidations() {
+ // Register custom validation for issuer
+ Validate.RegisterValidation("issuer", validateIssuer)
+
+ // Register custom validation for XSS prevention
+ Validate.RegisterValidation("no_xss", validateNoXSS)
+
+ // Register custom validation for OTP secret
+ Validate.RegisterValidation("otpsecret", validateOTPSecret)
+}
+
+// validateOTPSecret validates that the OTP secret is in valid base32 format
+func validateOTPSecret(fl validator.FieldLevel) bool {
+ secret := fl.Field().String()
+
+ // Check if the secret is not empty
+ if secret == "" {
+ return false
+ }
+
+ // Check if the secret is in base32 format (A-Z, 2-7)
+ base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
+ if !base32Regex.MatchString(secret) {
+ return false
+ }
+
+ // Check if the length is valid (must be at least 16 characters)
+ if len(secret) < 16 || len(secret) > 128 {
+ return false
+ }
+
+ return true
+}
+
+// validateIssuer validates that the issuer field contains only allowed characters
+func validateIssuer(fl validator.FieldLevel) bool {
+ issuer := fl.Field().String()
+
+ // Empty issuer is valid (since it's optional)
+ if issuer == "" {
+ return true
+ }
+
+ // Allow alphanumeric characters, spaces, and common punctuation
+ issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api
+
+import (
+ "regexp"
+ "strings"
+
+ "github.com/go-playground/validator/v10"
+)
+
+// Validate is a global validator instance
+var Validate = validator.New()
+
+// RegisterCustomValidations registers custom validation functions
+func RegisterCustomValidations() {
+ // Register custom validation for issuer
+ Validate.RegisterValidation("issuer", validateIssuer)
+
+ // Register custom validation for XSS prevention
+ Validate.RegisterValidation("no_xss", validateNoXSS)
+
+ // Register custom validation for OTP secret
+ Validate.RegisterValidation("otpsecret", validateOTPSecret)
+}
+
+// validateOTPSecret validates that the OTP secret is in valid base32 format
+func validateOTPSecret(fl validator.FieldLevel) bool {
+ secret := fl.Field().String()
+
+ // Check if the secret is not empty
+ if secret == "" {
+ return false
+ }
+
+ // Check if the secret is in base32 format (A-Z, 2-7)
+ base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
+ if !base32Regex.MatchString(secret) {
+ return false
+ }
+
+ // Check if the length is valid (must be at least 16 characters)
+ if len(secret) < 16 || len(secret) > 128 {
+ return false
+ }
+
+ return true
+}
+
+)
+ if !issuerRegex.MatchString(issuer) {
+ return false
+ }
+
+ // Check length
+ if len(issuer) > 100 {
+ return false
+ }
+
+ return true
+}
+
+// validateNoXSS validates that the field doesn't contain potential XSS payloads
+func validateNoXSS(fl validator.FieldLevel) bool {
+ value := fl.Field().String()
+
+ // Check for HTML encoding
+ if strings.Contains(value, "") ||
+ strings.Contains(value, "<") ||
+ strings.Contains(value, ">") {
+ return false
+ }
+
+ // Check for common XSS patterns
+ suspiciousPatterns := []*regexp.Regexp{
+ regexp.MustCompile(`(?i)`),
+ regexp.MustCompile(`(?i)javascript:`),
+ regexp.MustCompile(`(?i)data:text/html`),
+ regexp.MustCompile(`(?i)on\w+\s*=`),
+ regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
+ regexp.MustCompile(`(?i)<\s*iframe`),
+ regexp.MustCompile(`(?i)<\s*object`),
+ regexp.MustCompile(`(?i)<\s*embed`),
+ regexp.MustCompile(`(?i)<\s*style`),
+ regexp.MustCompile(`(?i)<\s*form`),
+ regexp.MustCompile(`(?i)<\s*applet`),
+ regexp.MustCompile(`(?i)<\s*meta`),
+ }
+
+ for _, pattern := range suspiciousPatterns {
+ if pattern.MatchString(value) {
+ return false
+ }
+ }
+
+ return true
+}
\ No newline at end of file
diff --git a/cmd/root.go b/cmd/root.go
index bdc3f84..8821b91 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -11,6 +11,7 @@ import (
"github.com/spf13/viper"
+ "otpm/api"
"otpm/config"
"otpm/database"
"otpm/handlers"
@@ -90,6 +91,9 @@ func Execute() error {
authService := services.NewAuthService(cfg, userRepo)
otpService := services.NewOTPService(otpRepo)
+ // Register custom validations
+ api.RegisterCustomValidations()
+
// Initialize handlers
authHandler := handlers.NewAuthHandler(authService)
otpHandler := handlers.NewOTPHandler(otpService)
diff --git a/config/config.go b/config/config.go
index 7670caf..4cad24d 100644
--- a/config/config.go
+++ b/config/config.go
@@ -85,9 +85,9 @@ func setDefaults() {
// Database defaults
viper.SetDefault("database.driver", "sqlite3")
- viper.SetDefault("database.max_open_conns", 25)
- viper.SetDefault("database.max_idle_conns", 25)
- viper.SetDefault("database.max_lifetime", "5m")
+ viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection
+ viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection
+ viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling
viper.SetDefault("database.skip_migration", false)
// JWT defaults
diff --git a/database/db.go b/database/db.go
index 43089ce..e882713 100644
--- a/database/db.go
+++ b/database/db.go
@@ -25,11 +25,20 @@ func New(cfg *config.DatabaseConfig) (*DB, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
- // Configure connection pool with optimized settings
- db.SetMaxOpenConns(cfg.MaxOpenConns)
- db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
- db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn
- db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes
+ // Configure connection pool based on database type
+ if cfg.Driver == "sqlite3" {
+ // SQLite is a file-based database - simpler connection settings
+ db.SetMaxOpenConns(1)
+ db.SetMaxIdleConns(1)
+ db.SetConnMaxLifetime(0) // Connections don't need to be recycled
+ db.SetConnMaxIdleTime(0)
+ } else {
+ // For other databases (MySQL, PostgreSQL etc.)
+ db.SetMaxOpenConns(cfg.MaxOpenConns)
+ db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
+ db.SetConnMaxLifetime(30 * time.Minute)
+ db.SetConnMaxIdleTime(5 * time.Minute)
+ }
// Verify connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -45,9 +54,16 @@ func New(cfg *config.DatabaseConfig) (*DB, error) {
// WithTx executes a function within a transaction with retry logic
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
- const maxRetries = 3
+ var maxRetries int
var lastErr error
+ // Adjust retry settings based on database type
+ if db.DriverName() == "sqlite3" {
+ maxRetries = 5 // SQLite needs more retries due to busy timeouts
+ } else {
+ maxRetries = 3
+ }
+
// Default transaction options
opts := &sql.TxOptions{
Isolation: sql.LevelReadCommitted,
diff --git a/database/init/otp.sql b/database/init/otp.sql
index 1cec635..e0e1ba9 100644
--- a/database/init/otp.sql
+++ b/database/init/otp.sql
@@ -1,6 +1,26 @@
CREATE TABLE IF NOT EXISTS otp (
- id SERIAL PRIMARY KEY,
- openid VARCHAR(255) UNIQUE NOT NULL,
- token VARCHAR(255),
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-);
\ No newline at end of file
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id VARCHAR(255) NOT NULL,
+ openid VARCHAR(255) NOT NULL,
+ name VARCHAR(100) NOT NULL,
+ issuer VARCHAR(255),
+ secret VARCHAR(255) NOT NULL,
+ algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1',
+ digits INTEGER NOT NULL DEFAULT 6,
+ period INTEGER NOT NULL DEFAULT 30,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, name),
+ UNIQUE(openid)
+);
+
+-- Add index for faster lookups
+CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id);
+CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid);
+
+-- Trigger to update the updated_at timestamp
+CREATE TRIGGER IF NOT EXISTS update_otp_timestamp
+ AFTER UPDATE ON otp
+BEGIN
+ UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
+END;
\ No newline at end of file
diff --git a/database/init/users.sql b/database/init/users.sql
index 3f1d0f8..ae4532a 100644
--- a/database/init/users.sql
+++ b/database/init/users.sql
@@ -1,5 +1,5 @@
CREATE TABLE IF NOT EXISTS users (
- id SERIAL PRIMARY KEY,
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
openid VARCHAR(255) UNIQUE NOT NULL,
session_key VARCHAR(255) UNIQUE NOT NULL
);
diff --git a/go.mod b/go.mod
index 356e98a..a55608f 100644
--- a/go.mod
+++ b/go.mod
@@ -5,42 +5,32 @@ go 1.23.0
toolchain go1.23.9
require (
- github.com/go-sql-driver/mysql v1.8.1
+ github.com/go-playground/validator/v10 v10.26.0
+ github.com/golang-jwt/jwt v3.2.2+incompatible
+ github.com/google/uuid v1.6.0
github.com/jmoiron/sqlx v1.4.0
github.com/julienschmidt/httprouter v1.3.0
- github.com/lib/pq v1.10.9
- github.com/spf13/cobra v1.8.1
+ github.com/prometheus/client_golang v1.22.0
github.com/spf13/viper v1.19.0
- modernc.org/sqlite v1.32.0
+ golang.org/x/crypto v0.38.0
)
require (
- filippo.io/edwards25519 v1.1.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
- github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
- github.com/go-playground/validator/v10 v10.26.0 // indirect
- github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
- github.com/google/uuid v1.6.0 // indirect
- github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
- github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
- github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
- github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
- github.com/prometheus/client_golang v1.22.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
- github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
@@ -50,7 +40,6 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
- golang.org/x/crypto v0.38.0 // indirect
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sys v0.33.0 // indirect
@@ -58,10 +47,4 @@ require (
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
- modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
- modernc.org/libc v1.55.3 // indirect
- modernc.org/mathutil v1.6.0 // indirect
- modernc.org/memory v1.8.0 // indirect
- modernc.org/strutil v1.2.0 // indirect
- modernc.org/token v1.1.0 // indirect
)
diff --git a/go.sum b/go.sum
index 7c2fd63..c30de80 100644
--- a/go.sum
+++ b/go.sum
@@ -4,19 +4,18 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
-github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
+github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
+github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
@@ -27,43 +26,36 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
-github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
-github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
-github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
-github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
+github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
-github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
-github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
-github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
+github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
+github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
+github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
-github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
-github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
-github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
-github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -77,11 +69,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
-github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
-github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
-github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
-github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
-github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
@@ -92,8 +81,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
-github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
-github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
@@ -106,8 +93,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
@@ -118,55 +106,19 @@ golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
-golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
-golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
-golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
-golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
-golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
-golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
-golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
-golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
-golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
-gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ=
-modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ=
-modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y=
-modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s=
-modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
-modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
-modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw=
-modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU=
-modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI=
-modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4=
-modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U=
-modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w=
-modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
-modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo=
-modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E=
-modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU=
-modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
-modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
-modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc=
-modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss=
-modernc.org/sqlite v1.32.0 h1:6BM4uGza7bWypsw4fdLRsLxut6bHe4c58VeqjRgST8s=
-modernc.org/sqlite v1.32.0/go.mod h1:UqoylwmTb9F+IqXERT8bW9zzOWN8qwAIcLdzeBZs4hA=
-modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA=
-modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0=
-modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
-modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
diff --git a/handlers/auth_handler.go b/handlers/auth_handler.go
index 0c7a286..f79979f 100644
--- a/handlers/auth_handler.go
+++ b/handlers/auth_handler.go
@@ -11,6 +11,7 @@ import (
"otpm/services"
"github.com/golang-jwt/jwt"
+ "github.com/julienschmidt/httprouter"
)
// AuthHandler handles authentication related requests
@@ -27,7 +28,7 @@ func NewAuthHandler(authService *services.AuthService) *AuthHandler {
// LoginRequest represents a login request
type LoginRequest struct {
- Code string `json:"code"`
+ Code string `json:"code" validate:"required,min=32,max=128"`
}
// LoginResponse represents a login response
@@ -36,14 +37,19 @@ type LoginResponse struct {
OpenID string `json:"openid"`
}
+// TokenRequest represents a token verification request
+type TokenRequest struct {
+ Token string `validate:"required,min=32"`
+}
+
// Login handles WeChat login
-func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
+func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
start := time.Now()
// Limit request body size to prevent DOS
r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
- // Parse request
+ // Parse and validate request
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
@@ -52,11 +58,11 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
return
}
- // Validate request
- if req.Code == "" {
+ // Validate using validator
+ if err := api.Validate.Struct(req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- "Code is required")
- log.Printf("Login request validation failed: empty code")
+ fmt.Sprintf("Invalid request parameters: %v", err))
+ log.Printf("Login request validation failed: %v", err)
return
}
@@ -79,7 +85,7 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
}
// VerifyToken handles token verification
-func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
+func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
start := time.Now()
// Get token from Authorization header
@@ -100,10 +106,13 @@ func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
}
token := authHeader[7:]
- if len(token) < 32 { // Basic length check
+
+ // Validate token using validator
+ tokenReq := TokenRequest{Token: token}
+ if err := api.Validate.Struct(tokenReq); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- "Invalid token length")
- log.Printf("Token verification failed: token too short")
+ "Invalid token format")
+ log.Printf("Token verification failed: %v", err)
return
}
@@ -139,9 +148,9 @@ func maskToken(token string) string {
}
// Routes returns all routes for the auth handler
-func (h *AuthHandler) Routes() map[string]http.HandlerFunc {
- return map[string]http.HandlerFunc{
- "/login": h.Login,
- "/verify-token": h.VerifyToken,
+func (h *AuthHandler) Routes() map[string]httprouter.Handle {
+ return map[string]httprouter.Handle{
+ "/api/login": h.Login,
+ "/api/verify-token": h.VerifyToken,
}
}
diff --git a/handlers/otp_handler.go b/handlers/otp_handler.go
index 4361d50..c1d99bb 100644
--- a/handlers/otp_handler.go
+++ b/handlers/otp_handler.go
@@ -2,11 +2,9 @@ package handlers
import (
"encoding/json"
- "fmt"
- "log"
"net/http"
- "strings"
- "time"
+
+ "github.com/julienschmidt/httprouter"
"otpm/api"
"otpm/middleware"
@@ -14,7 +12,7 @@ import (
"otpm/services"
)
-// OTPHandler handles OTP related requests
+// OTPHandler handles OTP-related HTTP requests
type OTPHandler struct {
otpService *services.OTPService
}
@@ -26,90 +24,53 @@ func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
}
}
-// CreateOTPRequest represents a request to create an OTP
-type CreateOTPRequest struct {
- Name string `json:"name"`
- Issuer string `json:"issuer"`
- Secret string `json:"secret"`
- Algorithm string `json:"algorithm"`
- Digits int `json:"digits"`
- Period int `json:"period"`
+// Routes returns the routes for OTP operations
+func (h *OTPHandler) Routes() map[string]httprouter.Handle {
+ return map[string]httprouter.Handle{
+ "POST /api/otp": h.CreateOTP,
+ "GET /api/otps": h.ListOTPs,
+ "GET /api/otp/:id": h.GetOTP,
+ }
}
-// CreateOTP handles OTP creation
-func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request) {
- start := time.Now()
-
- // Limit request body size
- r.Body = http.MaxBytesReader(w, r.Body, 10*1024) // 10KB max for OTP creation
-
+// CreateOTP handles the creation of a new OTP
+func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// Get user ID from context
- userID, err := middleware.GetUserID(r)
- if err != nil {
+ userID, ok := r.Context().Value(middleware.UserIDKey).(string)
+ if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
- log.Printf("CreateOTP unauthorized attempt")
return
}
- // Parse request
- var req CreateOTPRequest
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- fmt.Sprintf("Invalid request body: %v", err))
- log.Printf("CreateOTP request parse error for user %s: %v", userID, err)
+ // Parse request body
+ var params models.OTPParams
+ if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
+ api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body"))
return
}
- // Validate OTP parameters
- if req.Secret == "" {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- "Secret is required")
- log.Printf("CreateOTP validation failed for user %s: empty secret", userID)
- return
- }
-
- // Validate algorithm
- supportedAlgos := map[string]bool{
- "SHA1": true,
- "SHA256": true,
- "SHA512": true,
- }
- if !supportedAlgos[strings.ToUpper(req.Algorithm)] {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- "Unsupported algorithm. Supported: SHA1, SHA256, SHA512")
- log.Printf("CreateOTP validation failed for user %s: unsupported algorithm %s",
- userID, req.Algorithm)
+ // Validate request
+ if err := api.Validate.Struct(params); err != nil {
+ api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
return
}
// Create OTP
- otp, err := h.otpService.CreateOTP(r.Context(), userID, models.OTPParams{
- Name: req.Name,
- Issuer: req.Issuer,
- Secret: req.Secret,
- Algorithm: req.Algorithm,
- Digits: req.Digits,
- Period: req.Period,
- })
-
+ otp, err := h.otpService.CreateOTP(r.Context(), userID, params)
if err != nil {
- api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
- log.Printf("CreateOTP failed for user %s: %v", userID, err)
+ api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
- // Log successful creation (mask secret in logs)
- log.Printf("OTP created for user %s (took %v): name=%s issuer=%s algo=%s digits=%d period=%d",
- userID, time.Since(start), req.Name, req.Issuer, req.Algorithm, req.Digits, req.Period)
-
+ // Return response
api.NewResponseWriter(w).WriteSuccess(otp)
}
// ListOTPs handles listing all OTPs for a user
-func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
+func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// Get user ID from context
- userID, err := middleware.GetUserID(r)
- if err != nil {
+ userID, ok := r.Context().Value(middleware.UserIDKey).(string)
+ if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
@@ -121,166 +82,33 @@ func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
return
}
+ // Return response
api.NewResponseWriter(w).WriteSuccess(otps)
}
-// GetOTPCode handles generating OTP code
-func (h *OTPHandler) GetOTPCode(w http.ResponseWriter, r *http.Request) {
- start := time.Now()
-
+// GetOTP handles getting a specific OTP
+func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// Get user ID from context
- userID, err := middleware.GetUserID(r)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
- log.Printf("GetOTPCode unauthorized attempt from IP %s", r.RemoteAddr)
- return
- }
-
- // Get OTP ID from URL
- otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
- otpID = strings.TrimSuffix(otpID, "/code")
-
- // Validate OTP ID format
- if len(otpID) != 36 { // Assuming UUID format
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
- "Invalid OTP ID format")
- log.Printf("GetOTPCode invalid OTP ID format: %s (user %s)", otpID, userID)
- return
- }
-
- // Rate limiting check could be added here
- // (would require redis or similar rate limiter)
-
- // Generate code
- code, expiresIn, err := h.otpService.GenerateCode(r.Context(), otpID, userID)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.InternalError(err))
- log.Printf("GetOTPCode failed for user %s OTP %s: %v", userID, otpID, err)
- return
- }
-
- // Log successful generation (without actual code)
- log.Printf("OTP code generated for user %s OTP %s (took %v, expires in %ds)",
- userID, otpID, time.Since(start), expiresIn)
-
- api.NewResponseWriter(w).WriteSuccess(map[string]interface{}{
- "code": code,
- "expires_in": expiresIn,
- })
-}
-
-// VerifyOTPRequest represents a request to verify an OTP code
-type VerifyOTPRequest struct {
- Code string `json:"code"`
-}
-
-// VerifyOTP handles OTP code verification
-func (h *OTPHandler) VerifyOTP(w http.ResponseWriter, r *http.Request) {
- // Get user ID from context
- userID, err := middleware.GetUserID(r)
- if err != nil {
+ userID, ok := r.Context().Value(middleware.UserIDKey).(string)
+ if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
- otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
- otpID = strings.TrimSuffix(otpID, "/verify")
-
- // Parse request
- var req VerifyOTPRequest
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
+ otpID := ps.ByName("id")
+ if otpID == "" {
+ api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID"))
return
}
- // Verify code
- valid, err := h.otpService.VerifyCode(r.Context(), otpID, userID, req.Code)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.InternalError(err))
- return
- }
-
- api.NewResponseWriter(w).WriteSuccess(map[string]bool{
- "valid": valid,
- })
-}
-
-// UpdateOTPRequest represents a request to update an OTP
-type UpdateOTPRequest struct {
- Name string `json:"name"`
- Issuer string `json:"issuer"`
- Algorithm string `json:"algorithm"`
- Digits int `json:"digits"`
- Period int `json:"period"`
-}
-
-// UpdateOTP handles OTP update
-func (h *OTPHandler) UpdateOTP(w http.ResponseWriter, r *http.Request) {
- // Get user ID from context
- userID, err := middleware.GetUserID(r)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
- return
- }
-
- // Get OTP ID from URL
- otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
-
- // Parse request
- var req UpdateOTPRequest
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
- return
- }
-
- // Update OTP
- otp, err := h.otpService.UpdateOTP(r.Context(), otpID, userID, models.OTPParams{
- Name: req.Name,
- Issuer: req.Issuer,
- Algorithm: req.Algorithm,
- Digits: req.Digits,
- Period: req.Period,
- })
-
+ // Get OTP
+ otp, err := h.otpService.GetOTP(r.Context(), otpID, userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
+ // Return response
api.NewResponseWriter(w).WriteSuccess(otp)
}
-
-// DeleteOTP handles OTP deletion
-func (h *OTPHandler) DeleteOTP(w http.ResponseWriter, r *http.Request) {
- // Get user ID from context
- userID, err := middleware.GetUserID(r)
- if err != nil {
- api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
- return
- }
-
- // Get OTP ID from URL
- otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
-
- // Delete OTP
- if err := h.otpService.DeleteOTP(r.Context(), otpID, userID); err != nil {
- api.NewResponseWriter(w).WriteError(api.InternalError(err))
- return
- }
-
- api.NewResponseWriter(w).WriteSuccess(map[string]string{
- "message": "OTP deleted successfully",
- })
-}
-
-// Routes returns all routes for the OTP handler
-func (h *OTPHandler) Routes() map[string]http.HandlerFunc {
- return map[string]http.HandlerFunc{
- "/otp": h.CreateOTP,
- "/otp/": h.ListOTPs,
- "/otp/{id}": h.UpdateOTP,
- "/otp/{id}/code": h.GetOTPCode,
- "/otp/{id}/verify": h.VerifyOTP,
- }
-}
diff --git a/models/otp.go b/models/otp.go
index a8e48bc..8eaab88 100644
--- a/models/otp.go
+++ b/models/otp.go
@@ -2,194 +2,65 @@ package models
import (
"context"
- "database/sql"
- "fmt"
"time"
-
- "github.com/jmoiron/sqlx"
)
// OTP represents a TOTP configuration
type OTP struct {
- ID string `db:"id" json:"id"`
- UserID string `db:"user_id" json:"user_id"`
- Name string `db:"name" json:"name"`
- Issuer string `db:"issuer" json:"issuer"`
- Secret string `db:"secret" json:"-"` // Never expose secret in JSON
- Algorithm string `db:"algorithm" json:"algorithm"`
- Digits int `db:"digits" json:"digits"`
- Period int `db:"period" json:"period"`
- CreatedAt time.Time `db:"created_at" json:"created_at"`
- UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
+ ID int64 `json:"id" db:"id"`
+ UserID string `json:"user_id" db:"user_id" validate:"required"`
+ OpenID string `json:"openid" db:"openid" validate:"required"`
+ Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"`
+ Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"`
+ Secret string `json:"secret" db:"secret" validate:"required,otpsecret"`
+ Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
+ Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"`
+ Period int `json:"period" db:"period" validate:"required,min=30,max=60"`
+ CreatedAt time.Time `json:"created_at" db:"created_at"`
+ UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// OTPParams represents common OTP parameters used in creation and update
type OTPParams struct {
- Name string
- Issuer string
- Secret string
- Algorithm string
- Digits int
- Period int
+ Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
+ Issuer string `json:"issuer" validate:"omitempty,issuer"`
+ Secret string `json:"secret" validate:"required,otpsecret"`
+ Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
+ Digits int `json:"digits" validate:"omitempty,min=6,max=8"`
+ Period int `json:"period" validate:"omitempty,min=30,max=60"`
}
-// OTPRepository handles OTP data operations
+// OTPRepository handles OTP data storage
type OTPRepository struct {
- db *sqlx.DB
+ // Add your database connection or ORM here
}
-// NewOTPRepository creates a new OTPRepository
-func NewOTPRepository(db *sqlx.DB) *OTPRepository {
- return &OTPRepository{db: db}
+// Create creates a new OTP record
+func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
+ // Implement database creation logic
+ return nil
}
// FindByID finds an OTP by ID and user ID
func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
- var otp OTP
- query := `SELECT * FROM otps WHERE id = ? AND user_id = ?`
- err := r.db.GetContext(ctx, &otp, query, id, userID)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, fmt.Errorf("otp not found: %w", err)
- }
- return nil, fmt.Errorf("failed to find otp: %w", err)
- }
- return &otp, nil
+ // Implement database lookup logic
+ return nil, nil
}
// FindAllByUserID finds all OTPs for a user
func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
- var otps []*OTP
- query := `SELECT * FROM otps WHERE user_id = ? ORDER BY created_at DESC`
- err := r.db.SelectContext(ctx, &otps, query, userID)
- if err != nil {
- return nil, fmt.Errorf("failed to find otps: %w", err)
- }
- return otps, nil
+ // Implement database query logic
+ return nil, nil
}
-// Create creates a new OTP
-func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
- query := `
- INSERT INTO otps (id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- `
- now := time.Now()
- otp.CreatedAt = now
- otp.UpdatedAt = now
-
- _, err := r.db.ExecContext(
- ctx,
- query,
- otp.ID,
- otp.UserID,
- otp.Name,
- otp.Issuer,
- otp.Secret,
- otp.Algorithm,
- otp.Digits,
- otp.Period,
- otp.CreatedAt,
- otp.UpdatedAt,
- )
- if err != nil {
- return fmt.Errorf("failed to create otp: %w", err)
- }
- return nil
-}
-
-// Update updates an existing OTP
+// Update updates an existing OTP record
func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
- query := `
- UPDATE otps
- SET name = ?, issuer = ?, algorithm = ?, digits = ?, period = ?, updated_at = ?
- WHERE id = ? AND user_id = ?
- `
- otp.UpdatedAt = time.Now()
-
- result, err := r.db.ExecContext(
- ctx,
- query,
- otp.Name,
- otp.Issuer,
- otp.Algorithm,
- otp.Digits,
- otp.Period,
- otp.UpdatedAt,
- otp.ID,
- otp.UserID,
- )
- if err != nil {
- return fmt.Errorf("failed to update otp: %w", err)
- }
-
- rows, err := result.RowsAffected()
- if err != nil {
- return fmt.Errorf("failed to get affected rows: %w", err)
- }
-
- if rows == 0 {
- return fmt.Errorf("otp not found or not owned by user")
- }
-
+ // Implement database update logic
return nil
}
-// Delete deletes an OTP
+// Delete deletes an OTP record
func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
- query := `DELETE FROM otps WHERE id = ? AND user_id = ?`
- result, err := r.db.ExecContext(ctx, query, id, userID)
- if err != nil {
- return fmt.Errorf("failed to delete otp: %w", err)
- }
-
- rows, err := result.RowsAffected()
- if err != nil {
- return fmt.Errorf("failed to get affected rows: %w", err)
- }
-
- if rows == 0 {
- return fmt.Errorf("otp not found or not owned by user")
- }
-
- return nil
-}
-
-// CountByUserID counts the number of OTPs for a user
-func (r *OTPRepository) CountByUserID(ctx context.Context, userID string) (int, error) {
- var count int
- query := `SELECT COUNT(*) FROM otps WHERE user_id = ?`
- err := r.db.GetContext(ctx, &count, query, userID)
- if err != nil {
- return 0, fmt.Errorf("failed to count otps: %w", err)
- }
- return count, nil
-}
-
-// Transaction executes a function within a transaction
-func (r *OTPRepository) Transaction(ctx context.Context, fn func(*sqlx.Tx) error) error {
- tx, err := r.db.BeginTxx(ctx, nil)
- if err != nil {
- return fmt.Errorf("failed to begin transaction: %w", err)
- }
-
- defer func() {
- if p := recover(); p != nil {
- tx.Rollback()
- panic(p)
- }
- }()
-
- if err := fn(tx); err != nil {
- if rbErr := tx.Rollback(); rbErr != nil {
- return fmt.Errorf("tx failed: %v, rollback failed: %v", err, rbErr)
- }
- return err
- }
-
- if err := tx.Commit(); err != nil {
- return fmt.Errorf("failed to commit transaction: %w", err)
- }
-
+ // Implement database deletion logic
return nil
}
diff --git a/server/server.go b/server/server.go
index b1cadf4..a10e990 100644
--- a/server/server.go
+++ b/server/server.go
@@ -13,18 +13,20 @@ import (
"otpm/config"
"otpm/middleware"
+
+ "github.com/julienschmidt/httprouter"
)
// Server represents the HTTP server
type Server struct {
server *http.Server
- router *http.ServeMux
+ router *httprouter.Router
config *config.Config
}
// New creates a new server
func New(cfg *config.Config) *Server {
- router := http.NewServeMux()
+ router := httprouter.New()
server := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
@@ -111,29 +113,46 @@ func (s *Server) Shutdown() error {
}
// Router returns the router
-func (s *Server) Router() *http.ServeMux {
+func (s *Server) Router() *httprouter.Router {
return s.router
}
// RegisterRoutes registers all routes
-func (s *Server) RegisterRoutes(routes map[string]http.Handler) {
+func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) {
for pattern, handler := range routes {
- s.router.Handle(pattern, handler)
+ s.router.Handle("GET", pattern, handler)
+ s.router.Handle("POST", pattern, handler)
+ s.router.Handle("PUT", pattern, handler)
+ s.router.Handle("DELETE", pattern, handler)
}
}
// RegisterAuthRoutes registers routes that require authentication
-func (s *Server) RegisterAuthRoutes(routes map[string]http.Handler) {
+func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) {
for pattern, handler := range routes {
// Apply authentication middleware
- authHandler := middleware.Auth(s.config.JWT.Secret)(handler)
- s.router.Handle(pattern, authHandler)
+ authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+ // Convert httprouter.Handle to http.HandlerFunc for middleware
+ wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Store params in request context
+ ctx := context.WithValue(r.Context(), "params", ps)
+ handler(w, r.WithContext(ctx), ps)
+ })
+
+ // Apply auth middleware
+ middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r)
+ }
+
+ s.router.Handle("GET", pattern, authHandler)
+ s.router.Handle("POST", pattern, authHandler)
+ s.router.Handle("PUT", pattern, authHandler)
+ s.router.Handle("DELETE", pattern, authHandler)
}
}
// RegisterHealthCheck registers an enhanced health check endpoint
func (s *Server) RegisterHealthCheck() {
- s.router.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
+ s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
response := map[string]interface{}{
"status": "ok",
"timestamp": time.Now().Format(time.RFC3339),
@@ -145,9 +164,8 @@ func (s *Server) RegisterHealthCheck() {
}
// Add database status if configured
- if s.config.Database.DSN != "" { // Changed from URL to DSN to match config
+ if s.config.Database.DSN != "" {
dbStatus := "ok"
- // Removed DB ping check since we don't have DB instance in config
response["database"] = dbStatus
}
diff --git a/validator/validator.go b/validator/validator.go
index 6fa8f1f..13cff58 100644
--- a/validator/validator.go
+++ b/validator/validator.go
@@ -13,11 +13,54 @@ import (
var (
validate *validator.Validate
+
// 自定义验证规则
customValidations = map[string]validator.Func{
- "otpsecret": validateOTPSecret,
- "password": validatePassword,
+ "otpsecret": validateOTPSecret,
+ "password": validatePassword,
+ "issuer": validateIssuer,
+ "otpauth_uri": validateOTPAuthURI,
+ "no_xss": validateNoXSS,
}
+
+ // 常见的弱密码列表(实际使用时应该使用更完整的列表)
+ commonPasswords = map[string]bool{
+ "password123": true,
+ "12345678": true,
+ "qwerty123": true,
+ "admin123": true,
+ "letmein": true,
+ "welcome": true,
+ "password": true,
+ "admin": true,
+ }
+
+ // 预编译的XSS检测正则表达式
+ xssPatterns = []*regexp.Regexp{
+ regexp.MustCompile(`(?i)`),
+ regexp.MustCompile(`(?i)javascript:`),
+ regexp.MustCompile(`(?i)data:text/html`),
+ regexp.MustCompile(`(?i)on\w+\s*=`),
+ regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
+ regexp.MustCompile(`(?i)<\s*iframe`),
+ regexp.MustCompile(`(?i)<\s*object`),
+ regexp.MustCompile(`(?i)<\s*embed`),
+ regexp.MustCompile(`(?i)<\s*style`),
+ regexp.MustCompile(`(?i)<\s*form`),
+ regexp.MustCompile(`(?i)<\s*applet`),
+ regexp.MustCompile(`(?i)<\s*meta`),
+ regexp.MustCompile(`(?i)expression\s*\(`),
+ regexp.MustCompile(`(?i)url\s*\(`),
+ }
+
+ // 预编译的正则表达式
+ base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)
+ issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`)
+ otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`)
+ upperRegex = regexp.MustCompile(`[A-Z]`)
+ lowerRegex = regexp.MustCompile(`[a-z]`)
+ numberRegex = regexp.MustCompile(`[0-9]`)
+ specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
)
func init() {
@@ -83,19 +126,37 @@ func NewValidationError(errors validator.ValidationErrors) *ValidationError {
func getErrorMessage(err validator.FieldError) string {
switch err.Tag() {
case "required":
- return "This field is required"
+ return "此字段为必填项"
case "email":
- return "Invalid email address"
+ return "请输入有效的电子邮件地址"
case "min":
- return fmt.Sprintf("Must be at least %s characters long", err.Param())
+ if err.Type().Kind() == reflect.String {
+ return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param())
+ }
+ return fmt.Sprintf("必须大于或等于 %s", err.Param())
case "max":
- return fmt.Sprintf("Must be at most %s characters long", err.Param())
+ if err.Type().Kind() == reflect.String {
+ return fmt.Sprintf("长度不能超过 %s 个字符", err.Param())
+ }
+ return fmt.Sprintf("必须小于或等于 %s", err.Param())
+ case "len":
+ return fmt.Sprintf("长度必须为 %s 个字符", err.Param())
+ case "oneof":
+ return fmt.Sprintf("必须是以下值之一: %s", err.Param())
case "otpsecret":
- return "Invalid OTP secret format"
+ return "OTP密钥格式无效,必须是有效的Base32编码"
case "password":
- return "Password must be at least 8 characters long and contain at least one uppercase letter, one lowercase letter, one number, and one special character"
+ return "密码必须至少10个字符,并包含大写字母、小写字母,以及数字或特殊字符"
+ case "issuer":
+ return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号"
+ case "otpauth_uri":
+ return "OTP认证URI格式无效"
+ case "no_xss":
+ return "输入包含潜在的不安全内容"
+ case "numeric":
+ return "必须是数字"
default:
- return fmt.Sprintf("Failed validation on tag: %s", err.Tag())
+ return fmt.Sprintf("验证失败: %s", err.Tag())
}
}
@@ -104,40 +165,136 @@ func getErrorMessage(err validator.FieldError) string {
// validateOTPSecret validates an OTP secret
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
+
+ if secret == "" {
+ return false
+ }
+
// OTP secret should be base32 encoded
- matched, _ := regexp.MatchString(`^[A-Z2-7]+=*$`, secret)
- return matched
+ if !base32Regex.MatchString(secret) {
+ return false
+ }
+
+ // Check length (typical OTP secrets are 16-64 characters)
+ validLength := len(secret) >= 16 && len(secret) <= 128
+
+ return validLength
}
// validatePassword validates a password
func validatePassword(fl validator.FieldLevel) bool {
password := fl.Field().String()
- // At least 8 characters long
- if len(password) < 8 {
+
+ // At least 10 characters long
+ if len(password) < 10 {
return false
}
- var (
- hasUpper = regexp.MustCompile(`[A-Z]`).MatchString(password)
- hasLower = regexp.MustCompile(`[a-z]`).MatchString(password)
- hasNumber = regexp.MustCompile(`[0-9]`).MatchString(password)
- hasSpecial = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password)
- )
+ // Check if it's a common password
+ if commonPasswords[strings.ToLower(password)] {
+ return false
+ }
- return hasUpper && hasLower && hasNumber && hasSpecial
+ // Check character types
+ hasUpper := upperRegex.MatchString(password)
+ hasLower := lowerRegex.MatchString(password)
+ hasNumber := numberRegex.MatchString(password)
+ hasSpecial := specialRegex.MatchString(password)
+
+ // Ensure password has enough complexity
+ complexity := 0
+ if hasUpper {
+ complexity++
+ }
+ if hasLower {
+ complexity++
+ }
+ if hasNumber {
+ complexity++
+ }
+ if hasSpecial {
+ complexity++
+ }
+
+ return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial)
+}
+
+// validateIssuer validates an issuer name
+func validateIssuer(fl validator.FieldLevel) bool {
+ issuer := fl.Field().String()
+
+ if issuer == "" {
+ return false
+ }
+
+ // Issuer should not contain special characters that could cause problems in URLs
+ if !issuerRegex.MatchString(issuer) {
+ return false
+ }
+
+ // Check length
+ validLength := len(issuer) >= 1 && len(issuer) <= 100
+
+ return validLength
+}
+
+// validateOTPAuthURI validates an otpauth URI
+func validateOTPAuthURI(fl validator.FieldLevel) bool {
+ uri := fl.Field().String()
+
+ if uri == "" {
+ return false
+ }
+
+ // Basic format check for otpauth URI
+ // Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD
+ return otpauthRegex.MatchString(uri)
+}
+
+// validateNoXSS checks if a string contains potential XSS payloads
+func validateNoXSS(fl validator.FieldLevel) bool {
+ value := fl.Field().String()
+
+ // 检查基本的HTML编码
+ if strings.Contains(value, "") ||
+ strings.Contains(value, "<") ||
+ strings.Contains(value, ">") {
+ return false
+ }
+
+ // 检查十六进制编码
+ if strings.Contains(strings.ToLower(value), "\\x3c") || // <
+ strings.Contains(strings.ToLower(value), "\\x3e") { // >
+ return false
+ }
+
+ // 检查Unicode编码
+ if strings.Contains(strings.ToLower(value), "\\u003c") || // <
+ strings.Contains(strings.ToLower(value), "\\u003e") { // >
+ return false
+ }
+
+ // 使用预编译的正则表达式检查XSS模式
+ for _, pattern := range xssPatterns {
+ if pattern.MatchString(value) {
+ return false
+ }
+ }
+
+ return true
}
// Request validation structs
// LoginRequest represents a login request
type LoginRequest struct {
- Code string `json:"code" validate:"required"`
+ Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}
// CreateOTPRequest represents a request to create an OTP
type CreateOTPRequest struct {
- Name string `json:"name" validate:"required,min=1,max=100"`
- Issuer string `json:"issuer" validate:"required,min=1,max=100"`
+ Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
+ Issuer string `json:"issuer" validate:"required,issuer,no_xss"`
Secret string `json:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"required,oneof=6 8"`
@@ -146,8 +303,8 @@ type CreateOTPRequest struct {
// UpdateOTPRequest represents a request to update an OTP
type UpdateOTPRequest struct {
- Name string `json:"name" validate:"omitempty,min=1,max=100"`
- Issuer string `json:"issuer" validate:"omitempty,min=1,max=100"`
+ Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"`
+ Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"`
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
Period int `json:"period" validate:"omitempty,oneof=30 60"`