package cmd

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/jmoiron/sqlx"
	"github.com/julienschmidt/httprouter"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"

	"otpm/database"
	"otpm/handlers"
	"otpm/utils"
)

// rootCmd is the single otpm command; running it starts the HTTP server.
var rootCmd = &cobra.Command{
	Use:   "otpm",
	Short: "otp backend for microapp on wechat",
	Run: func(cmd *cobra.Command, args []string) {
		startApp()
	},
}

// Execute runs the root command and exits the process with status 1 if it
// returns an error.
func Execute() {
	if err := rootCmd.Execute(); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
}

// init registers the persistent flags, binds them to viper keys, and arranges
// for initConfig to run before the command executes.
func init() {
	cobra.OnInitialize(initConfig)

	flags := rootCmd.PersistentFlags()
	flags.StringP("config", "c", "", "config file (default searches ./config.yaml, then $HOME/config.yaml)")
	flags.StringP("driver", "d", "sqlite3", "database driver (sqlite3, postgres, mysql)")
	flags.StringP("dsn", "s", "", "database connection string")
	flags.StringP("port", "p", "8080", "port to listen on")

	// Expose the flags under viper keys so that a flag set on the command line
	// takes precedence over the config file. The "config" flag is intentionally
	// not bound: initConfig reads it straight from the flag set.
	bindings := []struct{ key, flag string }{
		{"database.driver", "driver"},
		{"database.dsn", "dsn"},
		{"port", "port"},
	}
	for _, b := range bindings {
		if err := viper.BindPFlag(b.key, flags.Lookup(b.flag)); err != nil {
			// Only possible if a flag name above is misspelled: a programmer error.
			panic(fmt.Sprintf("binding flag %q to key %q: %v", b.flag, b.key, err))
		}
	}
}

// initConfig loads the YAML configuration into viper.
//
// If --config is given, that exact file must exist and parse, otherwise the
// process exits. Without --config, "config.yaml" is looked up in the current
// directory and then the user's home directory; if none is found the program
// continues with flag values and defaults only.
func initConfig() {
	// The flag is defined in init, so GetString cannot fail here.
	cfgFile, _ := rootCmd.PersistentFlags().GetString("config")

	if cfgFile != "" {
		viper.SetConfigFile(cfgFile)
	} else {
		viper.AddConfigPath(".")
		if home, err := os.UserHomeDir(); err == nil {
			viper.AddConfigPath(home)
		}
		viper.SetConfigName("config")
		viper.SetConfigType("yaml")
	}

	if err := viper.ReadInConfig(); err != nil {
		// viper reports a failed search-path lookup as ConfigFileNotFoundError;
		// a missing explicit file surfaces as an OS error and stays fatal.
		var notFound viper.ConfigFileNotFoundError
		if cfgFile == "" && errors.As(err, &notFound) {
			log.Println("No config file found; using flags and defaults")
			return
		}
		log.Fatalf("Error reading config file: %v", err)
	}
}

// App bundles the long-lived resources of a running server.
type App struct {
	db     *sqlx.DB     // open database handle; closed by startApp on exit
	router http.Handler // routes built by setupRouter
	port   int          // TCP port to listen on, validated by NewApp
}

// NewApp validates the configured port, connects to the database, applies
// migrations, and builds the HTTP router.
//
// On error no resources are left open: a database opened here is closed
// again if migrations fail.
func NewApp() (*App, error) {
	// Validate before touching the database so a bad port fails fast.
	// viper.GetInt yields 0 for values it cannot parse, and ":0" would make
	// the server silently bind a random port.
	port := viper.GetInt("port")
	if port < 1 || port > 65535 {
		return nil, fmt.Errorf("invalid port %q: must be between 1 and 65535", viper.GetString("port"))
	}

	db, err := connectDB()
	if err != nil {
		return nil, err
	}
	if err := runMigrations(db); err != nil {
		// Best effort: the migration error is the one worth reporting.
		_ = db.Close()
		return nil, err
	}

	return &App{
		db:     db,
		router: setupRouter(db),
		port:   port,
	}, nil
}

// connectDB opens the database via database.InitDB, wrapping any failure so
// callers can still inspect the cause with errors.Is / errors.As.
func connectDB() (*sqlx.DB, error) {
	db, err := database.InitDB()
	if err != nil {
		return nil, fmt.Errorf("failed to connect to the database: %w", err)
	}
	return db, nil
}

// runMigrations applies database.MigrateDB to db and returns a wrapped error
// on failure. It does not exit the process; the caller decides how to react.
func runMigrations(db *sqlx.DB) error {
	if err := database.MigrateDB(db); err != nil {
		return fmt.Errorf("error migrating the database: %w", err)
	}
	return nil
}

// setupRouter wires the HTTP endpoints to a handlers.Handler backed by db.
// /login is public; /refresh, /set and /get are wrapped in utils.AuthMiddleware.
func setupRouter(db *sqlx.DB) http.Handler {
	handler := &handlers.Handler{DB: db}
	router := httprouter.New()

	router.POST("/login", utils.AdaptHandler(handler.Login))
	router.POST("/refresh", utils.AdaptHandler(utils.AuthMiddleware(handler.RefreshToken)))
	router.POST("/set", utils.AdaptHandler(utils.AuthMiddleware(handler.UpdateOrCreateOtp)))
	router.GET("/get", utils.AdaptHandler(utils.AuthMiddleware(handler.GetOtp)))

	return router
}

// Start serves HTTP on a.port and blocks until either the listener fails or
// SIGINT/SIGTERM arrives.
//
// A listener failure (e.g. port already in use) is returned as an error
// instead of exiting the process, so the caller can still release resources.
// On a signal the server is shut down gracefully, giving in-flight requests
// up to 5 seconds to complete; a shutdown failure is returned.
func (a *App) Start() error {
	server := &http.Server{
		Addr:    fmt.Sprintf(":%d", a.port),
		Handler: a.router,
		// Bound header reads so slow clients cannot hold connections open forever.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Register for signals before serving so an early signal is not missed.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(quit)

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even after Start has stopped listening on the channel.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- server.ListenAndServe()
	}()

	select {
	case err := <-serveErr:
		if errors.Is(err, http.ErrServerClosed) {
			return nil
		}
		return fmt.Errorf("failed to start server: %w", err)
	case <-quit:
	}

	log.Println("Shutting down server...")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := server.Shutdown(ctx); err != nil {
		return fmt.Errorf("shutting down server: %w", err)
	}
	return nil
}

// startApp builds the App, runs it until shutdown, and always closes the
// database before the process exits. Any fatal error terminates the process.
func startApp() {
	app, err := NewApp()
	if err != nil {
		log.Fatalf("Failed to initialize application: %v", err)
	}

	// Run inside a closure so the deferred Close fires before log.Fatalf
	// below: log.Fatalf calls os.Exit, which skips pending deferred calls.
	err = func() error {
		defer func() {
			if cerr := app.db.Close(); cerr != nil {
				log.Printf("Failed to close database connection: %v", cerr)
			}
		}()
		return app.Start()
	}()
	if err != nil {
		log.Fatalf("Failed to start application: %v", err)
	}
}