From 1a83552aa67d2ff8633f8f2f6dc25c15e5d10bf9 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 30 Oct 2024 13:36:11 +0100 Subject: [PATCH] [no-relnote] Use urfave for nvidia-container-runtime-hook CLI Signed-off-by: Evan Lezar --- .../hook_config.go | 10 +- .../hook_config_test.go | 7 +- cmd/nvidia-container-runtime-hook/main.go | 112 ++++++++++-------- 3 files changed, 71 insertions(+), 58 deletions(-) diff --git a/cmd/nvidia-container-runtime-hook/hook_config.go b/cmd/nvidia-container-runtime-hook/hook_config.go index afdb8bbc..beb8089f 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config.go +++ b/cmd/nvidia-container-runtime-hook/hook_config.go @@ -30,11 +30,11 @@ func getDefaultHookConfig() (HookConfig, error) { } // loadConfig loads the required paths for the hook config. -func loadConfig() (*config.Config, error) { +func (a *app) loadConfig() (*config.Config, error) { var configPaths []string var required bool - if len(*configflag) != 0 { - configPaths = append(configPaths, *configflag) + if len(a.configFile) != 0 { + configPaths = append(configPaths, a.configFile) required = true } else { configPaths = append(configPaths, path.Join(driverPath, configPath), configPath) @@ -56,8 +56,8 @@ func loadConfig() (*config.Config, error) { return config.GetDefault() } -func getHookConfig() (*HookConfig, error) { - cfg, err := loadConfig() +func (a *app) getHookConfig() (*HookConfig, error) { + cfg, err := a.loadConfig() if err != nil { return nil, fmt.Errorf("failed to load config: %v", err) } diff --git a/cmd/nvidia-container-runtime-hook/hook_config_test.go b/cmd/nvidia-container-runtime-hook/hook_config_test.go index 7c50ec12..f37c6916 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config_test.go +++ b/cmd/nvidia-container-runtime-hook/hook_config_test.go @@ -72,16 +72,17 @@ func TestGetHookConfig(t *testing.T) { if len(filename) > 0 { os.Remove(filename) } - configflag = nil }() + a := &app{} + if tc.lines != nil { configFile, err := os.CreateTemp("", "*.toml") require.NoError(t, err) defer configFile.Close() filename = configFile.Name() - configflag = &filename + a.configFile = filename for _, line := range tc.lines { _, err := configFile.WriteString(fmt.Sprintf("%s\n", line)) @@ -91,7 +92,7 @@ func TestGetHookConfig(t *testing.T) { var config HookConfig getHookConfig := func() { - c, _ := getHookConfig() + c, _ := a.getHookConfig() config = *c } diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index a0cfe3ec..cbaf9e23 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -1,7 +1,7 @@ package main import ( - "flag" + "errors" "fmt" "log" "os" @@ -13,29 +13,26 @@ import ( "strings" "syscall" + cli "github.com/urfave/cli/v2" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" ) -var ( - debugflag = flag.Bool("debug", false, "enable debug output") - versionflag = flag.Bool("version", false, "enable version output") - configflag = flag.String("config", "", "configuration file") -) - -func exit() { +func (a *app) recoverIfRequired() error { if err := recover(); err != nil { - if _, ok := err.(runtime.Error); ok { + rerr, ok := err.(runtime.Error) + if ok { log.Println(err) } - if *debugflag { + if a.isDebug { log.Printf("%s", debug.Stack()) } - os.Exit(1) + return rerr } - os.Exit(0) + return nil } func getCLIPath(config config.ContainerCLIConfig) string { @@ -63,15 +60,15 @@ func getRootfsPath(config containerConfig) string { return rootfs } -func doPrestart() { - var err error - - defer exit() +func (a *app) doPrestart() (rerr error) { + defer func() { + rerr = errors.Join(rerr, a.recoverIfRequired()) + }() log.SetFlags(0) - hook, err := getHookConfig() + hook, err := a.getHookConfig() if err != nil || hook == nil { - log.Panicln("error getting hook config:", err) + return fmt.Errorf("error getting hook config: %w", err) } cli := hook.NVIDIAContainerCLIConfig @@ -79,11 +76,11 @@ func doPrestart() { nvidia := container.Nvidia if nvidia == nil { // Not a GPU container, nothing to do. - return + return nil } if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" { - log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.") + return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead") } rootfs := getRootfsPath(container) @@ -101,7 +98,7 @@ func doPrestart() { if cli.NoPivot { args = append(args, "--no-pivot") } - if *debugflag { + if a.isDebug { args = append(args, "--debug=/dev/stderr") } else if cli.Debug != "" { args = append(args, fmt.Sprintf("--debug=%s", cli.Debug)) @@ -149,45 +146,60 @@ func doPrestart() { env := append(os.Environ(), cli.Environment...) //nolint:gosec // TODO: Can we harden this so that there is less risk of command injection? - err = syscall.Exec(args[0], args, env) - log.Panicln("exec failed:", err) + return syscall.Exec(args[0], args, env) } -func usage() { - fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, "\nCommands:\n") - fmt.Fprintf(os.Stderr, " prestart\n run the prestart hook\n") - fmt.Fprintf(os.Stderr, " poststart\n no-op\n") - fmt.Fprintf(os.Stderr, " poststop\n no-op\n") +type options struct { + isDebug bool + configFile string +} +type app struct { + options } func main() { - flag.Usage = usage - flag.Parse() + a := &app{} + // Create the top-level CLI + c := cli.NewApp() + c.Name = "NVIDIA Container Runtime Hook" + c.Version = info.GetVersionString() - if *versionflag { - fmt.Printf("%v version %v\n", "NVIDIA Container Runtime Hook", info.GetVersionString()) - return + c.Flags = []cli.Flag{ + &cli.BoolFlag{ + Name: "debug", + Destination: &a.isDebug, + Usage: "Enabled debug output", + }, + &cli.StringFlag{ + Name: "config", + Destination: &a.configFile, + Usage: "The path to the configuration file to use", + }, } - args := flag.Args() - if len(args) == 0 { - flag.Usage() - os.Exit(2) + c.Commands = []*cli.Command{ + { + Name: "prestart", + Usage: "run the prestart hook", + Action: func(ctx *cli.Context) error { + return a.doPrestart() + }, + }, + { + Name: "poststart", + Aliases: []string{"poststop"}, + Usage: "no-op", + Action: func(ctx *cli.Context) error { + return nil + }, + }, } + c.DefaultCommand = "prestart" - switch args[0] { - case "prestart": - doPrestart() - os.Exit(0) - case "poststart": - fallthrough - case "poststop": - os.Exit(0) - default: - flag.Usage() - os.Exit(2) + // Run the CLI + err := c.Run(os.Args) + if err != nil { + os.Exit(1) } }