This commit is contained in:
“xHuPo” 2025-05-27 17:44:24 +08:00
parent 44500afd3f
commit 5d370e1077
13 changed files with 529 additions and 519 deletions

152
api/validator.go Normal file
View file

@ -0,0 +1,152 @@
package api
import (
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
// Validate is a global validator instance
var Validate = validator.New()
// RegisterCustomValidations registers custom validation functions
func RegisterCustomValidations() {
// Register custom validation for issuer
Validate.RegisterValidation("issuer", validateIssuer)
// Register custom validation for XSS prevention
Validate.RegisterValidation("no_xss", validateNoXSS)
// Register custom validation for OTP secret
Validate.RegisterValidation("otpsecret", validateOTPSecret)
}
// validateOTPSecret validates that the OTP secret is in valid base32 format
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
// Check if the secret is not empty
if secret == "" {
return false
}
// Check if the secret is in base32 format (A-Z, 2-7)
base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
if !base32Regex.MatchString(secret) {
return false
}
// Check if the length is valid (must be at least 16 characters)
if len(secret) < 16 || len(secret) > 128 {
return false
}
return true
}
// validateIssuer validates that the issuer field contains only allowed characters
func validateIssuer(fl validator.FieldLevel) bool {
issuer := fl.Field().String()
// Empty issuer is valid (since it's optional)
if issuer == "" {
return true
}
// Allow alphanumeric characters, spaces, and common punctuation
issuerRegex := regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,:;!?()[\]{}'"]+package api
import (
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
// Validate is a global validator instance
var Validate = validator.New()
// RegisterCustomValidations registers custom validation functions
func RegisterCustomValidations() {
// Register custom validation for issuer
Validate.RegisterValidation("issuer", validateIssuer)
// Register custom validation for XSS prevention
Validate.RegisterValidation("no_xss", validateNoXSS)
// Register custom validation for OTP secret
Validate.RegisterValidation("otpsecret", validateOTPSecret)
}
// validateOTPSecret validates that the OTP secret is in valid base32 format
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
// Check if the secret is not empty
if secret == "" {
return false
}
// Check if the secret is in base32 format (A-Z, 2-7)
base32Regex := regexp.MustCompile(`^[A-Z2-7]+=*$`)
if !base32Regex.MatchString(secret) {
return false
}
// Check if the length is valid (must be at least 16 characters)
if len(secret) < 16 || len(secret) > 128 {
return false
}
return true
}
)
if !issuerRegex.MatchString(issuer) {
return false
}
// Check length
if len(issuer) > 100 {
return false
}
return true
}
// validateNoXSS validates that the field doesn't contain potential XSS payloads
func validateNoXSS(fl validator.FieldLevel) bool {
value := fl.Field().String()
// Check for HTML encoding
if strings.Contains(value, "&#") ||
strings.Contains(value, "&lt;") ||
strings.Contains(value, "&gt;") {
return false
}
// Check for common XSS patterns
suspiciousPatterns := []*regexp.Regexp{
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)data:text/html`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
regexp.MustCompile(`(?i)<\s*iframe`),
regexp.MustCompile(`(?i)<\s*object`),
regexp.MustCompile(`(?i)<\s*embed`),
regexp.MustCompile(`(?i)<\s*style`),
regexp.MustCompile(`(?i)<\s*form`),
regexp.MustCompile(`(?i)<\s*applet`),
regexp.MustCompile(`(?i)<\s*meta`),
}
for _, pattern := range suspiciousPatterns {
if pattern.MatchString(value) {
return false
}
}
return true
}

View file

@ -11,6 +11,7 @@ import (
"github.com/spf13/viper"
"otpm/api"
"otpm/config"
"otpm/database"
"otpm/handlers"
@ -90,6 +91,9 @@ func Execute() error {
authService := services.NewAuthService(cfg, userRepo)
otpService := services.NewOTPService(otpRepo)
// Register custom validations
api.RegisterCustomValidations()
// Initialize handlers
authHandler := handlers.NewAuthHandler(authService)
otpHandler := handlers.NewOTPHandler(otpService)

View file

@ -85,9 +85,9 @@ func setDefaults() {
// Database defaults
viper.SetDefault("database.driver", "sqlite3")
viper.SetDefault("database.max_open_conns", 25)
viper.SetDefault("database.max_idle_conns", 25)
viper.SetDefault("database.max_lifetime", "5m")
viper.SetDefault("database.max_open_conns", 1) // SQLite only needs 1 connection
viper.SetDefault("database.max_idle_conns", 1) // SQLite only needs 1 connection
viper.SetDefault("database.max_lifetime", "0") // SQLite doesn't benefit from connection recycling
viper.SetDefault("database.skip_migration", false)
// JWT defaults

View file

@ -25,11 +25,20 @@ func New(cfg *config.DatabaseConfig) (*DB, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Configure connection pool with optimized settings
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
db.SetConnMaxLifetime(30 * time.Minute) // Longer lifetime to reduce connection churn
db.SetConnMaxIdleTime(5 * time.Minute) // Close idle connections after 5 minutes
// Configure connection pool based on database type
if cfg.Driver == "sqlite3" {
// SQLite is a file-based database - simpler connection settings
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(0) // Connections don't need to be recycled
db.SetConnMaxIdleTime(0)
} else {
// For other databases (MySQL, PostgreSQL etc.)
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(max(1, cfg.MaxOpenConns/2)) // 50% of max open connections
db.SetConnMaxLifetime(30 * time.Minute)
db.SetConnMaxIdleTime(5 * time.Minute)
}
// Verify connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@ -45,9 +54,16 @@ func New(cfg *config.DatabaseConfig) (*DB, error) {
// WithTx executes a function within a transaction with retry logic
func (db *DB) WithTx(ctx context.Context, fn func(*sqlx.Tx) error) error {
const maxRetries = 3
var maxRetries int
var lastErr error
// Adjust retry settings based on database type
if db.DriverName() == "sqlite3" {
maxRetries = 5 // SQLite needs more retries due to busy timeouts
} else {
maxRetries = 3
}
// Default transaction options
opts := &sql.TxOptions{
Isolation: sql.LevelReadCommitted,

View file

@ -1,6 +1,26 @@
CREATE TABLE IF NOT EXISTS otp (
id SERIAL PRIMARY KEY,
openid VARCHAR(255) UNIQUE NOT NULL,
token VARCHAR(255),
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id VARCHAR(255) NOT NULL,
openid VARCHAR(255) NOT NULL,
name VARCHAR(100) NOT NULL,
issuer VARCHAR(255),
secret VARCHAR(255) NOT NULL,
algorithm VARCHAR(10) NOT NULL DEFAULT 'SHA1',
digits INTEGER NOT NULL DEFAULT 6,
period INTEGER NOT NULL DEFAULT 30,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, name),
UNIQUE(openid)
);
-- Add index for faster lookups
CREATE INDEX IF NOT EXISTS idx_otp_user_id ON otp(user_id);
CREATE INDEX IF NOT EXISTS idx_otp_openid ON otp(openid);
-- Trigger to update the updated_at timestamp
CREATE TRIGGER IF NOT EXISTS update_otp_timestamp
AFTER UPDATE ON otp
BEGIN
UPDATE otp SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END;

View file

@ -1,5 +1,5 @@
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
id INTEGER PRIMARY KEY AUTOINCREMENT,
openid VARCHAR(255) UNIQUE NOT NULL,
session_key VARCHAR(255) UNIQUE NOT NULL
);

27
go.mod
View file

@ -5,42 +5,32 @@ go 1.23.0
toolchain go1.23.9
require (
github.com/go-sql-driver/mysql v1.8.1
github.com/go-playground/validator/v10 v10.26.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/jmoiron/sqlx v1.4.0
github.com/julienschmidt/httprouter v1.3.0
github.com/lib/pq v1.10.9
github.com/spf13/cobra v1.8.1
github.com/prometheus/client_golang v1.22.0
github.com/spf13/viper v1.19.0
modernc.org/sqlite v1.32.0
golang.org/x/crypto v0.38.0
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.26.0 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/prometheus/client_golang v1.22.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
@ -50,7 +40,6 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/crypto v0.38.0 // indirect
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sys v0.33.0 // indirect
@ -58,10 +47,4 @@ require (
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
modernc.org/libc v1.55.3 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.8.0 // indirect
modernc.org/strutil v1.2.0 // indirect
modernc.org/token v1.1.0 // indirect
)

74
go.sum
View file

@ -4,19 +4,18 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
@ -27,43 +26,36 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@ -77,11 +69,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
@ -92,8 +81,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
@ -106,8 +93,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
@ -118,55 +106,19 @@ golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ=
modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ=
modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y=
modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s=
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw=
modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU=
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI=
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4=
modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U=
modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w=
modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo=
modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E=
modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU=
modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc=
modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss=
modernc.org/sqlite v1.32.0 h1:6BM4uGza7bWypsw4fdLRsLxut6bHe4c58VeqjRgST8s=
modernc.org/sqlite v1.32.0/go.mod h1:UqoylwmTb9F+IqXERT8bW9zzOWN8qwAIcLdzeBZs4hA=
modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA=
modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

View file

@ -11,6 +11,7 @@ import (
"otpm/services"
"github.com/golang-jwt/jwt"
"github.com/julienschmidt/httprouter"
)
// AuthHandler handles authentication related requests
@ -27,7 +28,7 @@ func NewAuthHandler(authService *services.AuthService) *AuthHandler {
// LoginRequest represents a login request
type LoginRequest struct {
Code string `json:"code"`
Code string `json:"code" validate:"required,min=32,max=128"`
}
// LoginResponse represents a login response
@ -36,14 +37,19 @@ type LoginResponse struct {
OpenID string `json:"openid"`
}
// TokenRequest represents a token verification request
type TokenRequest struct {
Token string `validate:"required,min=32"`
}
// Login handles WeChat login
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
start := time.Now()
// Limit request body size to prevent DOS
r.Body = http.MaxBytesReader(w, r.Body, 1024) // 1KB max for login request
// Parse request
// Parse and validate request
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
@ -52,11 +58,11 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
return
}
// Validate request
if req.Code == "" {
// Validate using validator
if err := api.Validate.Struct(req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Code is required")
log.Printf("Login request validation failed: empty code")
fmt.Sprintf("Invalid request parameters: %v", err))
log.Printf("Login request validation failed: %v", err)
return
}
@ -79,7 +85,7 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
}
// VerifyToken handles token verification
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
start := time.Now()
// Get token from Authorization header
@ -100,10 +106,13 @@ func (h *AuthHandler) VerifyToken(w http.ResponseWriter, r *http.Request) {
}
token := authHeader[7:]
if len(token) < 32 { // Basic length check
// Validate token using validator
tokenReq := TokenRequest{Token: token}
if err := api.Validate.Struct(tokenReq); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid token length")
log.Printf("Token verification failed: token too short")
"Invalid token format")
log.Printf("Token verification failed: %v", err)
return
}
@ -139,9 +148,9 @@ func maskToken(token string) string {
}
// Routes returns all routes for the auth handler
func (h *AuthHandler) Routes() map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/login": h.Login,
"/verify-token": h.VerifyToken,
func (h *AuthHandler) Routes() map[string]httprouter.Handle {
return map[string]httprouter.Handle{
"/api/login": h.Login,
"/api/verify-token": h.VerifyToken,
}
}

View file

@ -2,11 +2,9 @@ package handlers
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"time"
"github.com/julienschmidt/httprouter"
"otpm/api"
"otpm/middleware"
@ -14,7 +12,7 @@ import (
"otpm/services"
)
// OTPHandler handles OTP related requests
// OTPHandler handles OTP-related HTTP requests
type OTPHandler struct {
otpService *services.OTPService
}
@ -26,90 +24,53 @@ func NewOTPHandler(otpService *services.OTPService) *OTPHandler {
}
}
// CreateOTPRequest represents a request to create an OTP
type CreateOTPRequest struct {
Name string `json:"name"`
Issuer string `json:"issuer"`
Secret string `json:"secret"`
Algorithm string `json:"algorithm"`
Digits int `json:"digits"`
Period int `json:"period"`
// Routes returns the routes for OTP operations
func (h *OTPHandler) Routes() map[string]httprouter.Handle {
return map[string]httprouter.Handle{
"POST /api/otp": h.CreateOTP,
"GET /api/otps": h.ListOTPs,
"GET /api/otp/:id": h.GetOTP,
}
}
// CreateOTP handles OTP creation
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Limit request body size
r.Body = http.MaxBytesReader(w, r.Body, 10*1024) // 10KB max for OTP creation
// CreateOTP handles the creation of a new OTP
func (h *OTPHandler) CreateOTP(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("CreateOTP unauthorized attempt")
return
}
// Parse request
var req CreateOTPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
fmt.Sprintf("Invalid request body: %v", err))
log.Printf("CreateOTP request parse error for user %s: %v", userID, err)
// Parse request body
var params models.OTPParams
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
api.NewResponseWriter(w).WriteError(api.ValidationError("Invalid request body"))
return
}
// Validate OTP parameters
if req.Secret == "" {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Secret is required")
log.Printf("CreateOTP validation failed for user %s: empty secret", userID)
return
}
// Validate algorithm
supportedAlgos := map[string]bool{
"SHA1": true,
"SHA256": true,
"SHA512": true,
}
if !supportedAlgos[strings.ToUpper(req.Algorithm)] {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Unsupported algorithm. Supported: SHA1, SHA256, SHA512")
log.Printf("CreateOTP validation failed for user %s: unsupported algorithm %s",
userID, req.Algorithm)
// Validate request
if err := api.Validate.Struct(params); err != nil {
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
return
}
// Create OTP
otp, err := h.otpService.CreateOTP(r.Context(), userID, models.OTPParams{
Name: req.Name,
Issuer: req.Issuer,
Secret: req.Secret,
Algorithm: req.Algorithm,
Digits: req.Digits,
Period: req.Period,
})
otp, err := h.otpService.CreateOTP(r.Context(), userID, params)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ValidationError(err.Error()))
log.Printf("CreateOTP failed for user %s: %v", userID, err)
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
// Log successful creation (mask secret in logs)
log.Printf("OTP created for user %s (took %v): name=%s issuer=%s algo=%s digits=%d period=%d",
userID, time.Since(start), req.Name, req.Issuer, req.Algorithm, req.Digits, req.Period)
// Return response
api.NewResponseWriter(w).WriteSuccess(otp)
}
// ListOTPs handles listing all OTPs for a user
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
@ -121,166 +82,33 @@ func (h *OTPHandler) ListOTPs(w http.ResponseWriter, r *http.Request) {
return
}
// Return response
api.NewResponseWriter(w).WriteSuccess(otps)
}
// GetOTPCode handles generating OTP code
func (h *OTPHandler) GetOTPCode(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// GetOTP handles getting a specific OTP
func (h *OTPHandler) GetOTP(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
log.Printf("GetOTPCode unauthorized attempt from IP %s", r.RemoteAddr)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
otpID = strings.TrimSuffix(otpID, "/code")
// Validate OTP ID format
if len(otpID) != 36 { // Assuming UUID format
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams,
"Invalid OTP ID format")
log.Printf("GetOTPCode invalid OTP ID format: %s (user %s)", otpID, userID)
return
}
// Rate limiting check could be added here
// (would require redis or similar rate limiter)
// Generate code
code, expiresIn, err := h.otpService.GenerateCode(r.Context(), otpID, userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
log.Printf("GetOTPCode failed for user %s OTP %s: %v", userID, otpID, err)
return
}
// Log successful generation (without actual code)
log.Printf("OTP code generated for user %s OTP %s (took %v, expires in %ds)",
userID, otpID, time.Since(start), expiresIn)
api.NewResponseWriter(w).WriteSuccess(map[string]interface{}{
"code": code,
"expires_in": expiresIn,
})
}
// VerifyOTPRequest represents a request to verify an OTP code
type VerifyOTPRequest struct {
Code string `json:"code"`
}
// VerifyOTP handles OTP code verification
func (h *OTPHandler) VerifyOTP(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
userID, ok := r.Context().Value(middleware.UserIDKey).(string)
if !ok {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
otpID = strings.TrimSuffix(otpID, "/verify")
// Parse request
var req VerifyOTPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
otpID := ps.ByName("id")
if otpID == "" {
api.NewResponseWriter(w).WriteError(api.ValidationError("Missing OTP ID"))
return
}
// Verify code
valid, err := h.otpService.VerifyCode(r.Context(), otpID, userID, req.Code)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(map[string]bool{
"valid": valid,
})
}
// UpdateOTPRequest represents a request to update an OTP
type UpdateOTPRequest struct {
Name string `json:"name"`
Issuer string `json:"issuer"`
Algorithm string `json:"algorithm"`
Digits int `json:"digits"`
Period int `json:"period"`
}
// UpdateOTP handles OTP update
func (h *OTPHandler) UpdateOTP(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
// Parse request
var req UpdateOTPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.NewResponseWriter(w).WriteErrorWithCode(api.CodeInvalidParams, "Invalid request body")
return
}
// Update OTP
otp, err := h.otpService.UpdateOTP(r.Context(), otpID, userID, models.OTPParams{
Name: req.Name,
Issuer: req.Issuer,
Algorithm: req.Algorithm,
Digits: req.Digits,
Period: req.Period,
})
// Get OTP
otp, err := h.otpService.GetOTP(r.Context(), otpID, userID)
if err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
// Return response
api.NewResponseWriter(w).WriteSuccess(otp)
}
// DeleteOTP handles OTP deletion
func (h *OTPHandler) DeleteOTP(w http.ResponseWriter, r *http.Request) {
// Get user ID from context
userID, err := middleware.GetUserID(r)
if err != nil {
api.NewResponseWriter(w).WriteError(api.ErrUnauthorized)
return
}
// Get OTP ID from URL
otpID := strings.TrimPrefix(r.URL.Path, "/otp/")
// Delete OTP
if err := h.otpService.DeleteOTP(r.Context(), otpID, userID); err != nil {
api.NewResponseWriter(w).WriteError(api.InternalError(err))
return
}
api.NewResponseWriter(w).WriteSuccess(map[string]string{
"message": "OTP deleted successfully",
})
}
// Routes returns all routes for the OTP handler
func (h *OTPHandler) Routes() map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/otp": h.CreateOTP,
"/otp/": h.ListOTPs,
"/otp/{id}": h.UpdateOTP,
"/otp/{id}/code": h.GetOTPCode,
"/otp/{id}/verify": h.VerifyOTP,
}
}

View file

@ -2,194 +2,65 @@ package models
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/jmoiron/sqlx"
)
// OTP represents a TOTP configuration
type OTP struct {
ID string `db:"id" json:"id"`
UserID string `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Issuer string `db:"issuer" json:"issuer"`
Secret string `db:"secret" json:"-"` // Never expose secret in JSON
Algorithm string `db:"algorithm" json:"algorithm"`
Digits int `db:"digits" json:"digits"`
Period int `db:"period" json:"period"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ID int64 `json:"id" db:"id"`
UserID string `json:"user_id" db:"user_id" validate:"required"`
OpenID string `json:"openid" db:"openid" validate:"required"`
Name string `json:"name" db:"name" validate:"required,min=1,max=100,no_xss"`
Issuer string `json:"issuer" db:"issuer" validate:"omitempty,issuer"`
Secret string `json:"secret" db:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" db:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" db:"digits" validate:"required,min=6,max=8"`
Period int `json:"period" db:"period" validate:"required,min=30,max=60"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// OTPParams represents common OTP parameters used in creation and update
type OTPParams struct {
Name string
Issuer string
Secret string
Algorithm string
Digits int
Period int
Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
Issuer string `json:"issuer" validate:"omitempty,issuer"`
Secret string `json:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"omitempty,min=6,max=8"`
Period int `json:"period" validate:"omitempty,min=30,max=60"`
}
// OTPRepository handles OTP data operations
// OTPRepository handles OTP data storage
type OTPRepository struct {
db *sqlx.DB
// Add your database connection or ORM here
}
// NewOTPRepository creates a new OTPRepository
func NewOTPRepository(db *sqlx.DB) *OTPRepository {
return &OTPRepository{db: db}
// Create creates a new OTP record
func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
// Implement database creation logic
return nil
}
// FindByID finds an OTP by ID and user ID
func (r *OTPRepository) FindByID(ctx context.Context, id, userID string) (*OTP, error) {
var otp OTP
query := `SELECT * FROM otps WHERE id = ? AND user_id = ?`
err := r.db.GetContext(ctx, &otp, query, id, userID)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("otp not found: %w", err)
}
return nil, fmt.Errorf("failed to find otp: %w", err)
}
return &otp, nil
// Implement database lookup logic
return nil, nil
}
// FindAllByUserID finds all OTPs for a user
func (r *OTPRepository) FindAllByUserID(ctx context.Context, userID string) ([]*OTP, error) {
var otps []*OTP
query := `SELECT * FROM otps WHERE user_id = ? ORDER BY created_at DESC`
err := r.db.SelectContext(ctx, &otps, query, userID)
if err != nil {
return nil, fmt.Errorf("failed to find otps: %w", err)
}
return otps, nil
// Implement database query logic
return nil, nil
}
// Create creates a new OTP
func (r *OTPRepository) Create(ctx context.Context, otp *OTP) error {
query := `
INSERT INTO otps (id, user_id, name, issuer, secret, algorithm, digits, period, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
now := time.Now()
otp.CreatedAt = now
otp.UpdatedAt = now
_, err := r.db.ExecContext(
ctx,
query,
otp.ID,
otp.UserID,
otp.Name,
otp.Issuer,
otp.Secret,
otp.Algorithm,
otp.Digits,
otp.Period,
otp.CreatedAt,
otp.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create otp: %w", err)
}
return nil
}
// Update updates an existing OTP
// Update updates an existing OTP record
func (r *OTPRepository) Update(ctx context.Context, otp *OTP) error {
query := `
UPDATE otps
SET name = ?, issuer = ?, algorithm = ?, digits = ?, period = ?, updated_at = ?
WHERE id = ? AND user_id = ?
`
otp.UpdatedAt = time.Now()
result, err := r.db.ExecContext(
ctx,
query,
otp.Name,
otp.Issuer,
otp.Algorithm,
otp.Digits,
otp.Period,
otp.UpdatedAt,
otp.ID,
otp.UserID,
)
if err != nil {
return fmt.Errorf("failed to update otp: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return fmt.Errorf("otp not found or not owned by user")
}
// Implement database update logic
return nil
}
// Delete deletes an OTP
// Delete deletes an OTP record
func (r *OTPRepository) Delete(ctx context.Context, id, userID string) error {
query := `DELETE FROM otps WHERE id = ? AND user_id = ?`
result, err := r.db.ExecContext(ctx, query, id, userID)
if err != nil {
return fmt.Errorf("failed to delete otp: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return fmt.Errorf("otp not found or not owned by user")
}
return nil
}
// CountByUserID counts the number of OTPs for a user
func (r *OTPRepository) CountByUserID(ctx context.Context, userID string) (int, error) {
var count int
query := `SELECT COUNT(*) FROM otps WHERE user_id = ?`
err := r.db.GetContext(ctx, &count, query, userID)
if err != nil {
return 0, fmt.Errorf("failed to count otps: %w", err)
}
return count, nil
}
// Transaction executes a function within a transaction
func (r *OTPRepository) Transaction(ctx context.Context, fn func(*sqlx.Tx) error) error {
tx, err := r.db.BeginTxx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("tx failed: %v, rollback failed: %v", err, rbErr)
}
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
// Implement database deletion logic
return nil
}

View file

@ -13,18 +13,20 @@ import (
"otpm/config"
"otpm/middleware"
"github.com/julienschmidt/httprouter"
)
// Server represents the HTTP server
type Server struct {
server *http.Server
router *http.ServeMux
router *httprouter.Router
config *config.Config
}
// New creates a new server
func New(cfg *config.Config) *Server {
router := http.NewServeMux()
router := httprouter.New()
server := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
@ -111,29 +113,46 @@ func (s *Server) Shutdown() error {
}
// Router returns the router
func (s *Server) Router() *http.ServeMux {
func (s *Server) Router() *httprouter.Router {
return s.router
}
// RegisterRoutes registers all routes
func (s *Server) RegisterRoutes(routes map[string]http.Handler) {
func (s *Server) RegisterRoutes(routes map[string]httprouter.Handle) {
for pattern, handler := range routes {
s.router.Handle(pattern, handler)
s.router.Handle("GET", pattern, handler)
s.router.Handle("POST", pattern, handler)
s.router.Handle("PUT", pattern, handler)
s.router.Handle("DELETE", pattern, handler)
}
}
// RegisterAuthRoutes registers routes that require authentication
func (s *Server) RegisterAuthRoutes(routes map[string]http.Handler) {
func (s *Server) RegisterAuthRoutes(routes map[string]httprouter.Handle) {
for pattern, handler := range routes {
// Apply authentication middleware
authHandler := middleware.Auth(s.config.JWT.Secret)(handler)
s.router.Handle(pattern, authHandler)
authHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// Convert httprouter.Handle to http.HandlerFunc for middleware
wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Store params in request context
ctx := context.WithValue(r.Context(), "params", ps)
handler(w, r.WithContext(ctx), ps)
})
// Apply auth middleware
middleware.Auth(s.config.JWT.Secret)(wrappedHandler).ServeHTTP(w, r)
}
s.router.Handle("GET", pattern, authHandler)
s.router.Handle("POST", pattern, authHandler)
s.router.Handle("PUT", pattern, authHandler)
s.router.Handle("DELETE", pattern, authHandler)
}
}
// RegisterHealthCheck registers an enhanced health check endpoint
func (s *Server) RegisterHealthCheck() {
s.router.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
s.router.GET("/health", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
response := map[string]interface{}{
"status": "ok",
"timestamp": time.Now().Format(time.RFC3339),
@ -145,9 +164,8 @@ func (s *Server) RegisterHealthCheck() {
}
// Add database status if configured
if s.config.Database.DSN != "" { // Changed from URL to DSN to match config
if s.config.Database.DSN != "" {
dbStatus := "ok"
// Removed DB ping check since we don't have DB instance in config
response["database"] = dbStatus
}

View file

@ -13,11 +13,54 @@ import (
var (
validate *validator.Validate
// 自定义验证规则
customValidations = map[string]validator.Func{
"otpsecret": validateOTPSecret,
"password": validatePassword,
"otpsecret": validateOTPSecret,
"password": validatePassword,
"issuer": validateIssuer,
"otpauth_uri": validateOTPAuthURI,
"no_xss": validateNoXSS,
}
// 常见的弱密码列表(实际使用时应该使用更完整的列表)
commonPasswords = map[string]bool{
"password123": true,
"12345678": true,
"qwerty123": true,
"admin123": true,
"letmein": true,
"welcome": true,
"password": true,
"admin": true,
}
// 预编译的XSS检测正则表达式
xssPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)data:text/html`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)<\s*img[^>]*src\s*=`),
regexp.MustCompile(`(?i)<\s*iframe`),
regexp.MustCompile(`(?i)<\s*object`),
regexp.MustCompile(`(?i)<\s*embed`),
regexp.MustCompile(`(?i)<\s*style`),
regexp.MustCompile(`(?i)<\s*form`),
regexp.MustCompile(`(?i)<\s*applet`),
regexp.MustCompile(`(?i)<\s*meta`),
regexp.MustCompile(`(?i)expression\s*\(`),
regexp.MustCompile(`(?i)url\s*\(`),
}
// 预编译的正则表达式
base32Regex = regexp.MustCompile(`^[A-Z2-7]+=*$`)
issuerRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`)
otpauthRegex = regexp.MustCompile(`^otpauth://totp/[^:]+:[^?]+\?secret=[A-Z2-7]+=*&`)
upperRegex = regexp.MustCompile(`[A-Z]`)
lowerRegex = regexp.MustCompile(`[a-z]`)
numberRegex = regexp.MustCompile(`[0-9]`)
specialRegex = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`)
)
func init() {
@ -83,19 +126,37 @@ func NewValidationError(errors validator.ValidationErrors) *ValidationError {
func getErrorMessage(err validator.FieldError) string {
switch err.Tag() {
case "required":
return "This field is required"
return "此字段为必填项"
case "email":
return "Invalid email address"
return "请输入有效的电子邮件地址"
case "min":
return fmt.Sprintf("Must be at least %s characters long", err.Param())
if err.Type().Kind() == reflect.String {
return fmt.Sprintf("长度必须至少为 %s 个字符", err.Param())
}
return fmt.Sprintf("必须大于或等于 %s", err.Param())
case "max":
return fmt.Sprintf("Must be at most %s characters long", err.Param())
if err.Type().Kind() == reflect.String {
return fmt.Sprintf("长度不能超过 %s 个字符", err.Param())
}
return fmt.Sprintf("必须小于或等于 %s", err.Param())
case "len":
return fmt.Sprintf("长度必须为 %s 个字符", err.Param())
case "oneof":
return fmt.Sprintf("必须是以下值之一: %s", err.Param())
case "otpsecret":
return "Invalid OTP secret format"
return "OTP密钥格式无效必须是有效的Base32编码"
case "password":
return "Password must be at least 8 characters long and contain at least one uppercase letter, one lowercase letter, one number, and one special character"
return "密码必须至少10个字符并包含大写字母、小写字母以及数字或特殊字符"
case "issuer":
return "发行者名称包含无效字符,只允许字母、数字、空格和常见标点符号"
case "otpauth_uri":
return "OTP认证URI格式无效"
case "no_xss":
return "输入包含潜在的不安全内容"
case "numeric":
return "必须是数字"
default:
return fmt.Sprintf("Failed validation on tag: %s", err.Tag())
return fmt.Sprintf("验证失败: %s", err.Tag())
}
}
@ -104,40 +165,136 @@ func getErrorMessage(err validator.FieldError) string {
// validateOTPSecret validates an OTP secret
func validateOTPSecret(fl validator.FieldLevel) bool {
secret := fl.Field().String()
if secret == "" {
return false
}
// OTP secret should be base32 encoded
matched, _ := regexp.MatchString(`^[A-Z2-7]+=*$`, secret)
return matched
if !base32Regex.MatchString(secret) {
return false
}
// Check length (typical OTP secrets are 16-64 characters)
validLength := len(secret) >= 16 && len(secret) <= 128
return validLength
}
// validatePassword validates a password
func validatePassword(fl validator.FieldLevel) bool {
password := fl.Field().String()
// At least 8 characters long
if len(password) < 8 {
// At least 10 characters long
if len(password) < 10 {
return false
}
var (
hasUpper = regexp.MustCompile(`[A-Z]`).MatchString(password)
hasLower = regexp.MustCompile(`[a-z]`).MatchString(password)
hasNumber = regexp.MustCompile(`[0-9]`).MatchString(password)
hasSpecial = regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password)
)
// Check if it's a common password
if commonPasswords[strings.ToLower(password)] {
return false
}
return hasUpper && hasLower && hasNumber && hasSpecial
// Check character types
hasUpper := upperRegex.MatchString(password)
hasLower := lowerRegex.MatchString(password)
hasNumber := numberRegex.MatchString(password)
hasSpecial := specialRegex.MatchString(password)
// Ensure password has enough complexity
complexity := 0
if hasUpper {
complexity++
}
if hasLower {
complexity++
}
if hasNumber {
complexity++
}
if hasSpecial {
complexity++
}
return complexity >= 3 && hasUpper && hasLower && (hasNumber || hasSpecial)
}
// validateIssuer validates an issuer name
func validateIssuer(fl validator.FieldLevel) bool {
issuer := fl.Field().String()
if issuer == "" {
return false
}
// Issuer should not contain special characters that could cause problems in URLs
if !issuerRegex.MatchString(issuer) {
return false
}
// Check length
validLength := len(issuer) >= 1 && len(issuer) <= 100
return validLength
}
// validateOTPAuthURI validates an otpauth URI
func validateOTPAuthURI(fl validator.FieldLevel) bool {
uri := fl.Field().String()
if uri == "" {
return false
}
// Basic format check for otpauth URI
// Format: otpauth://totp/ISSUER:ACCOUNT?secret=SECRET&issuer=ISSUER&algorithm=ALGORITHM&digits=DIGITS&period=PERIOD
return otpauthRegex.MatchString(uri)
}
// validateNoXSS checks if a string contains potential XSS payloads
func validateNoXSS(fl validator.FieldLevel) bool {
value := fl.Field().String()
// 检查基本的HTML编码
if strings.Contains(value, "&#") ||
strings.Contains(value, "&lt;") ||
strings.Contains(value, "&gt;") {
return false
}
// 检查十六进制编码
if strings.Contains(strings.ToLower(value), "\\x3c") || // <
strings.Contains(strings.ToLower(value), "\\x3e") { // >
return false
}
// 检查Unicode编码
if strings.Contains(strings.ToLower(value), "\\u003c") || // <
strings.Contains(strings.ToLower(value), "\\u003e") { // >
return false
}
// 使用预编译的正则表达式检查XSS模式
for _, pattern := range xssPatterns {
if pattern.MatchString(value) {
return false
}
}
return true
}
// Request validation structs
// LoginRequest represents a login request
type LoginRequest struct {
Code string `json:"code" validate:"required"`
Code string `json:"code" validate:"required,len=6|len=8,numeric"`
}
// CreateOTPRequest represents a request to create an OTP
type CreateOTPRequest struct {
Name string `json:"name" validate:"required,min=1,max=100"`
Issuer string `json:"issuer" validate:"required,min=1,max=100"`
Name string `json:"name" validate:"required,min=1,max=100,no_xss"`
Issuer string `json:"issuer" validate:"required,issuer,no_xss"`
Secret string `json:"secret" validate:"required,otpsecret"`
Algorithm string `json:"algorithm" validate:"required,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"required,oneof=6 8"`
@ -146,8 +303,8 @@ type CreateOTPRequest struct {
// UpdateOTPRequest represents a request to update an OTP
type UpdateOTPRequest struct {
Name string `json:"name" validate:"omitempty,min=1,max=100"`
Issuer string `json:"issuer" validate:"omitempty,min=1,max=100"`
Name string `json:"name" validate:"omitempty,min=1,max=100,no_xss"`
Issuer string `json:"issuer" validate:"omitempty,issuer,no_xss"`
Algorithm string `json:"algorithm" validate:"omitempty,oneof=SHA1 SHA256 SHA512"`
Digits int `json:"digits" validate:"omitempty,oneof=6 8"`
Period int `json:"period" validate:"omitempty,oneof=30 60"`