diff --git a/tools/container/toolkit/toolkit.go b/tools/container/toolkit/toolkit.go index 4f215942..4b477e69 100644 --- a/tools/container/toolkit/toolkit.go +++ b/tools/container/toolkit/toolkit.go @@ -40,12 +40,18 @@ const ( ) var toolkitDirArg string -var nvidiaDriverRootFlag string -var nvidiaContainerRuntimeDebugFlag string -var nvidiaContainerRuntimeLogLevelFlag string -var nvidiaContainerCLIDebugFlag string + +type options struct { + DriverRoot string + ContainerRuntimeDebug string + ContainerRuntimeLogLevel string + ContainerCLIDebug string +} func main() { + + opts := options{} + // Create the top-level CLI c := cli.NewApp() c.Name = "toolkit" @@ -58,7 +64,9 @@ func main() { install.Usage = "Install the components of the NVIDIA container toolkit" install.ArgsUsage = "" install.Before = parseArgs - install.Action = Install + install.Action = func(c *cli.Context) error { + return Install(c, &opts) + } // Create the 'delete' command delete := cli.Command{} @@ -78,24 +86,24 @@ func main() { &cli.StringFlag{ Name: "nvidia-driver-root", Value: DefaultNvidiaDriverRoot, - Destination: &nvidiaDriverRootFlag, + Destination: &opts.DriverRoot, EnvVars: []string{"NVIDIA_DRIVER_ROOT"}, }, &cli.StringFlag{ Name: "nvidia-container-runtime-debug", Usage: "Specify the location of the debug log file for the NVIDIA Container Runtime", - Destination: &nvidiaContainerRuntimeDebugFlag, + Destination: &opts.ContainerRuntimeDebug, EnvVars: []string{"NVIDIA_CONTAINER_RUNTIME_DEBUG"}, }, &cli.StringFlag{ Name: "nvidia-container-runtime-debug-log-level", - Destination: &nvidiaContainerRuntimeLogLevelFlag, + Destination: &opts.ContainerRuntimeLogLevel, EnvVars: []string{"NVIDIA_CONTAINER_RUNTIME_LOG_LEVEL"}, }, &cli.StringFlag{ Name: "nvidia-container-cli-debug", Usage: "Specify the location of the debug log file for the NVIDIA Container CLI", - Destination: &nvidiaContainerCLIDebugFlag, + Destination: &opts.ContainerCLIDebug, EnvVars: []string{"NVIDIA_CONTAINER_CLI_DEBUG"}, }, } @@ -135,7 +143,7 @@ func Delete(cli *cli.Context) error { // Install installs the components of the NVIDIA container toolkit. // Any existing installation is removed. -func Install(cli *cli.Context) error { +func Install(cli *cli.Context, opts *options) error { log.Infof("Installing NVIDIA container toolkit to '%v'", toolkitDirArg) log.Infof("Removing existing NVIDIA container toolkit installation") @@ -157,7 +165,7 @@ func Install(cli *cli.Context) error { return fmt.Errorf("error installing NVIDIA container library: %v", err) } - err = installContainerRuntimes(toolkitDirArg, nvidiaDriverRootFlag) + err = installContainerRuntimes(toolkitDirArg, opts.DriverRoot) if err != nil { return fmt.Errorf("error installing NVIDIA container runtime: %v", err) } @@ -172,7 +180,7 @@ func Install(cli *cli.Context) error { return fmt.Errorf("error installing NVIDIA container runtime hook: %v", err) } - err = installToolkitConfig(toolkitConfigPath, nvidiaDriverRootFlag, nvidiaContainerCliExecutable) + err = installToolkitConfig(toolkitConfigPath, nvidiaContainerCliExecutable, opts) if err != nil { return fmt.Errorf("error installing NVIDIA container toolkit config: %v", err) } @@ -230,7 +238,7 @@ func installLibrary(libName string, toolkitDir string) error { // installToolkitConfig installs the config file for the NVIDIA container toolkit ensuring // that the settings are updated to match the desired install and nvidia driver directories. -func installToolkitConfig(toolkitConfigPath string, nvidiaDriverDir string, nvidiaContainerCliExecutablePath string) error { +func installToolkitConfig(toolkitConfigPath string, nvidiaContainerCliExecutablePath string, opts *options) error { log.Infof("Installing NVIDIA container toolkit config '%v'", toolkitConfigPath) config, err := toml.LoadFile(nvidiaContainerToolkitConfigSource) @@ -253,17 +261,17 @@ func installToolkitConfig(toolkitConfigPath string, nvidiaDriverDir string, nvid ldconfigPath := fmt.Sprintf("%s", config.GetPath(nvidiaContainerCliKey("ldconfig"))) // Use the driver run root as the root: - driverLdconfigPath := "@" + filepath.Join(nvidiaDriverDir, strings.TrimPrefix(ldconfigPath, "@/")) + driverLdconfigPath := "@" + filepath.Join(opts.DriverRoot, strings.TrimPrefix(ldconfigPath, "@/")) - config.SetPath(nvidiaContainerCliKey("root"), nvidiaDriverDir) + config.SetPath(nvidiaContainerCliKey("root"), opts.DriverRoot) config.SetPath(nvidiaContainerCliKey("path"), nvidiaContainerCliExecutablePath) config.SetPath(nvidiaContainerCliKey("ldconfig"), driverLdconfigPath) // Set the debug options if selected debugOptions := map[string]string{ - "nvidia-container-runtime.debug": nvidiaContainerRuntimeDebugFlag, - "nvidia-container-runtime.log-level": nvidiaContainerRuntimeLogLevelFlag, - "nvidia-container-cli.debug": nvidiaContainerCLIDebugFlag, + "nvidia-container-runtime.debug": opts.ContainerRuntimeDebug, + "nvidia-container-runtime.log-level": opts.ContainerRuntimeLogLevel, + "nvidia-container-cli.debug": opts.ContainerCLIDebug, } for key, value := range debugOptions { if value == "" {