diff --git a/tools/container/nvidia-toolkit/run.go b/tools/container/nvidia-toolkit/run.go index 405b8e32..fb337e74 100644 --- a/tools/container/nvidia-toolkit/run.go +++ b/tools/container/nvidia-toolkit/run.go @@ -31,14 +31,21 @@ var waitingForSignal = make(chan bool, 1) var signalReceived = make(chan bool, 1) var destinationArg string -var noDaemonFlag bool -var runtimeFlag string -var runtimeArgsFlag string + +// options stores the command line arguments +type options struct { + noDaemon bool + runtime string + runtimeArgs string +} // Version defines the CLI version. This is set at build time using LD FLAGS var Version = "development" func main() { + + options := options{} + // Create the top-level CLI c := cli.NewApp() c.Name = "nvidia-toolkit" @@ -46,7 +53,9 @@ func main() { c.UsageText = "DESTINATION [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]" c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit" c.Version = Version - c.Action = Run + c.Action = func(ctx *cli.Context) error { + return Run(ctx, &options) + } // Setup flags for the CLI c.Flags = []cli.Flag{ @@ -54,7 +63,7 @@ func main() { Name: "no-daemon", Aliases: []string{"n"}, Usage: "terminate immediatly after setting up the runtime. Note that no cleanup will be performed", - Destination: &noDaemonFlag, + Destination: &options.noDaemon, EnvVars: []string{"NO_DAEMON"}, }, &cli.StringFlag{ @@ -62,7 +71,7 @@ func main() { Aliases: []string{"r"}, Usage: "the runtime to setup on this node. One of {'docker', 'crio', 'containerd'}", Value: defaultRuntime, - Destination: &runtimeFlag, + Destination: &options.runtime, EnvVars: []string{"RUNTIME"}, }, &cli.StringFlag{ @@ -70,7 +79,7 @@ func main() { Aliases: []string{"u"}, Usage: "arguments to pass to 'docker', 'crio', or 'containerd' setup command", Value: defaultRuntimeArgs, - Destination: &runtimeArgsFlag, + Destination: &options.runtimeArgs, EnvVars: []string{"RUNTIME_ARGS"}, }, } @@ -93,8 +102,8 @@ func main() { } // Run runs the core logic of the CLI -func Run(c *cli.Context) error { - err := verifyFlags() +func Run(c *cli.Context, o *options) error { + err := verifyFlags(o) if err != nil { return fmt.Errorf("unable to verify flags: %v", err) } @@ -110,18 +119,18 @@ func Run(c *cli.Context) error { return fmt.Errorf("unable to install toolkit: %v", err) } - err = setupRuntime() + err = setupRuntime(o) if err != nil { return fmt.Errorf("unable to setup runtime: %v", err) } - if !noDaemonFlag { + if !o.noDaemon { err = waitForSignal() if err != nil { return fmt.Errorf("unable to wait for signal: %v", err) } - err = cleanupRuntime() + err = cleanupRuntime(o) if err != nil { return fmt.Errorf("unable to cleanup runtime: %v", err) } @@ -166,10 +175,10 @@ func ParseArgs(args []string) ([]string, error) { return append([]string{args[0]}, args[numPositionalArgs:]...), nil } -func verifyFlags() error { +func verifyFlags(o *options) error { log.Infof("Verifying Flags") - if _, exists := availableRuntimes[runtimeFlag]; !exists { - return fmt.Errorf("unknown runtime: %v", runtimeFlag) + if _, exists := availableRuntimes[o.runtime]; !exists { + return fmt.Errorf("unknown runtime: %v", o.runtime) } return nil } @@ -232,19 +241,19 @@ func installToolkit() error { return nil } -func setupRuntime() error { +func setupRuntime(o *options) error { toolkitDir := filepath.Join(destinationArg, toolkitSubDir) log.Infof("Setting up runtime") - cmdline := fmt.Sprintf("%v setup %v %v\n", runtimeFlag, runtimeArgsFlag, toolkitDir) + cmdline := fmt.Sprintf("%v setup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir) cmd := exec.Command("sh", "-c", cmdline) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() if err != nil { - return fmt.Errorf("error running %v command: %v", runtimeFlag, err) + return fmt.Errorf("error running %v command: %v", o.runtime, err) } return nil @@ -257,19 +266,19 @@ func waitForSignal() error { return nil } -func cleanupRuntime() error { +func cleanupRuntime(o *options) error { toolkitDir := filepath.Join(destinationArg, toolkitSubDir) log.Infof("Cleaning up Runtime") - cmdline := fmt.Sprintf("%v cleanup %v %v\n", runtimeFlag, runtimeArgsFlag, toolkitDir) + cmdline := fmt.Sprintf("%v cleanup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir) cmd := exec.Command("sh", "-c", cmdline) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() if err != nil { - return fmt.Errorf("error running %v command: %v", runtimeFlag, err) + return fmt.Errorf("error running %v command: %v", o.runtime, err) } return nil