diff --git a/tools/container/nvidia-toolkit/run.go b/tools/container/nvidia-toolkit/run.go index 87d5b111..0eea1e68 100644 --- a/tools/container/nvidia-toolkit/run.go +++ b/tools/container/nvidia-toolkit/run.go @@ -15,10 +15,10 @@ import ( ) const ( - runDir = "/run/nvidia" - pidFile = runDir + "/toolkit.pid" - toolkitCommand = "toolkit" - toolkitSubDir = "toolkit" + toolkitPidFilename = "toolkit.pid" + defaultPidFile = "/run/nvidia/toolkit/" + toolkitPidFilename + toolkitCommand = "toolkit" + toolkitSubDir = "toolkit" defaultToolkitArgs = "" defaultRuntime = "docker" @@ -36,6 +36,7 @@ type options struct { runtime string runtimeArgs string root string + pidFile string } // Version defines the CLI version. This is set at build time using LD FLAGS @@ -56,6 +57,9 @@ func main() { c.UsageText = "[DESTINATION] [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]" c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit" c.Version = Version + c.Before = func(ctx *cli.Context) error { + return validateFlags(ctx, &options) + } c.Action = func(ctx *cli.Context) error { return Run(ctx, &options) } @@ -92,6 +96,13 @@ func main() { Destination: &options.root, EnvVars: []string{"ROOT"}, }, + &cli.StringFlag{ + Name: "pid-file", + Value: defaultPidFile, + Usage: "the path to a toolkit.pid file to ensure that only a single configuration instance is running", + Destination: &options.pidFile, + EnvVars: []string{"TOOLKIT_PID_FILE", "PID_FILE"}, + }, } // Run the CLI @@ -104,6 +115,14 @@ func main() { log.Infof("Completed %v", c.Name) } +func validateFlags(_ *cli.Context, o *options) error { + if filepath.Base(o.pidFile) != toolkitPidFilename { + return fmt.Errorf("invalid toolkit.pid path %v", o.pidFile) + } + + return nil +} + // Run runs the core logic of the CLI func Run(c *cli.Context, o *options) error { err := verifyFlags(o) @@ -111,11 +130,11 @@ func Run(c *cli.Context, o *options) error { return fmt.Errorf("unable to verify flags: %v", err) } - err = initialize() + err = initialize(o.pidFile) if err != nil { return fmt.Errorf("unable to initialize: %v", err) } - defer shutdown() + defer shutdown(o.pidFile) err = installToolkit(o) if err != nil { @@ -182,9 +201,16 @@ func verifyFlags(o *options) error { return nil } -func initialize() error { +func initialize(pidFile string) error { log.Infof("Initializing") + if dir := filepath.Dir(pidFile); dir != "" { + err := os.MkdirAll(dir, 0755) + if err != nil { + return fmt.Errorf("unable to create folder for pidfile: %w", err) + } + } + f, err := os.Create(pidFile) if err != nil { return fmt.Errorf("unable to create pidfile: %v", err) @@ -211,7 +237,7 @@ func initialize() error { signalReceived <- true default: log.Infof("Signal received, exiting early") - shutdown() + shutdown(pidFile) os.Exit(0) } }() @@ -286,7 +312,7 @@ func cleanupRuntime(o *options) error { return nil } -func shutdown() { +func shutdown(pidFile string) { log.Infof("Shutting Down") err := os.Remove(pidFile) diff --git a/tools/container/toolkit/toolkit.go b/tools/container/toolkit/toolkit.go index 9e161d63..8175ed4e 100644 --- a/tools/container/toolkit/toolkit.go +++ b/tools/container/toolkit/toolkit.go @@ -17,6 +17,7 @@ package main import ( + "errors" "fmt" "io" "os" @@ -44,6 +45,8 @@ const ( nvidiaContainerToolkitConfigSource = "/etc/nvidia-container-runtime/config.toml" configFilename = "config.toml" + + toolkitPidFilename = "toolkit.pid" ) type options struct { @@ -111,7 +114,7 @@ func main() { return validateOptions(c, &opts) } delete.Action = func(c *cli.Context) error { - return Delete(c, &opts) + return TryDelete(c, &opts) } // Register the subcommand with the top-level CLI @@ -301,12 +304,29 @@ func validateOptions(c *cli.Context, opts *options) error { return nil } -// Delete removes the NVIDIA container toolkit -func Delete(cli *cli.Context, opts *options) error { - log.Infof("Deleting NVIDIA container toolkit from '%v'", opts.toolkitRoot) - err := os.RemoveAll(opts.toolkitRoot) - if err != nil { - return fmt.Errorf("error deleting toolkit directory: %v", err) +// TryDelete attempts to remove the specified toolkit folder. +// A toolkit.pid file -- if present -- is skipped. +func TryDelete(cli *cli.Context, opts *options) error { + log.Infof("Attempting to delete NVIDIA container toolkit from '%v'", opts.toolkitRoot) + + contents, err := os.ReadDir(opts.toolkitRoot) + if err != nil && errors.Is(err, os.ErrNotExist) { + return nil + } else if err != nil { + return fmt.Errorf("failed to read the contents of %v: %w", opts.toolkitRoot, err) + } + + for _, content := range contents { + if content.Name() == toolkitPidFilename { + continue + } + name := filepath.Join(opts.toolkitRoot, content.Name()) + if err := os.RemoveAll(name); err != nil { + log.Warningf("could not remove %v: %v", name, err) + } + } + if err := os.RemoveAll(opts.toolkitRoot); err != nil { + log.Warningf("could not remove %v: %v", opts.toolkitRoot, err) } return nil }