Allow toolkit.pid path to be specified

This commit makes the following changes:
* Allows the toolkit.pid path to be specified
* Creates the toolkit.pid file at /run/nvidia/toolkit/toolkit.pid by default
* Handles failures to remove the /run/nvidia/toolkit folder

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2024-06-14 14:15:54 +02:00
parent ac90b7963d
commit 876d479308
2 changed files with 62 additions and 16 deletions

View File

@ -15,8 +15,8 @@ import (
) )
const ( const (
runDir = "/run/nvidia" toolkitPidFilename = "toolkit.pid"
pidFile = runDir + "/toolkit.pid" defaultPidFile = "/run/nvidia/toolkit/" + toolkitPidFilename
toolkitCommand = "toolkit" toolkitCommand = "toolkit"
toolkitSubDir = "toolkit" toolkitSubDir = "toolkit"
@ -36,6 +36,7 @@ type options struct {
runtime string runtime string
runtimeArgs string runtimeArgs string
root string root string
pidFile string
} }
// Version defines the CLI version. This is set at build time using LD FLAGS // Version defines the CLI version. This is set at build time using LD FLAGS
@ -56,6 +57,9 @@ func main() {
c.UsageText = "[DESTINATION] [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]" c.UsageText = "[DESTINATION] [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]"
c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit" c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit"
c.Version = Version c.Version = Version
c.Before = func(ctx *cli.Context) error {
return validateFlags(ctx, &options)
}
c.Action = func(ctx *cli.Context) error { c.Action = func(ctx *cli.Context) error {
return Run(ctx, &options) return Run(ctx, &options)
} }
@ -92,6 +96,13 @@ func main() {
Destination: &options.root, Destination: &options.root,
EnvVars: []string{"ROOT"}, EnvVars: []string{"ROOT"},
}, },
&cli.StringFlag{
Name: "pid-file",
Value: defaultPidFile,
Usage: "the path to a toolkit.pid file to ensure that only a single configuration instance is running",
Destination: &options.pidFile,
EnvVars: []string{"TOOLKIT_PID_FILE", "PID_FILE"},
},
} }
// Run the CLI // Run the CLI
@ -104,6 +115,14 @@ func main() {
log.Infof("Completed %v", c.Name) log.Infof("Completed %v", c.Name)
} }
// validateFlags checks the parsed command-line options before the main action
// runs. The only check performed is that the final path element of o.pidFile
// equals toolkitPidFilename; any other filename is rejected with an error.
// The cli context argument is unused.
func validateFlags(_ *cli.Context, o *options) error {
	// Only the base name is compared, so the pid file may live in any directory.
	pidFilename := filepath.Base(o.pidFile)
	if pidFilename == toolkitPidFilename {
		return nil
	}
	return fmt.Errorf("invalid toolkit.pid path %v", o.pidFile)
}
// Run runs the core logic of the CLI // Run runs the core logic of the CLI
func Run(c *cli.Context, o *options) error { func Run(c *cli.Context, o *options) error {
err := verifyFlags(o) err := verifyFlags(o)
@ -111,11 +130,11 @@ func Run(c *cli.Context, o *options) error {
return fmt.Errorf("unable to verify flags: %v", err) return fmt.Errorf("unable to verify flags: %v", err)
} }
err = initialize() err = initialize(o.pidFile)
if err != nil { if err != nil {
return fmt.Errorf("unable to initialize: %v", err) return fmt.Errorf("unable to initialize: %v", err)
} }
defer shutdown() defer shutdown(o.pidFile)
err = installToolkit(o) err = installToolkit(o)
if err != nil { if err != nil {
@ -182,9 +201,16 @@ func verifyFlags(o *options) error {
return nil return nil
} }
func initialize() error { func initialize(pidFile string) error {
log.Infof("Initializing") log.Infof("Initializing")
if dir := filepath.Dir(pidFile); dir != "" {
err := os.MkdirAll(dir, 0755)
if err != nil {
return fmt.Errorf("unable to create folder for pidfile: %w", err)
}
}
f, err := os.Create(pidFile) f, err := os.Create(pidFile)
if err != nil { if err != nil {
return fmt.Errorf("unable to create pidfile: %v", err) return fmt.Errorf("unable to create pidfile: %v", err)
@ -211,7 +237,7 @@ func initialize() error {
signalReceived <- true signalReceived <- true
default: default:
log.Infof("Signal received, exiting early") log.Infof("Signal received, exiting early")
shutdown() shutdown(pidFile)
os.Exit(0) os.Exit(0)
} }
}() }()
@ -286,7 +312,7 @@ func cleanupRuntime(o *options) error {
return nil return nil
} }
func shutdown() { func shutdown(pidFile string) {
log.Infof("Shutting Down") log.Infof("Shutting Down")
err := os.Remove(pidFile) err := os.Remove(pidFile)

View File

@ -17,6 +17,7 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -44,6 +45,8 @@ const (
nvidiaContainerToolkitConfigSource = "/etc/nvidia-container-runtime/config.toml" nvidiaContainerToolkitConfigSource = "/etc/nvidia-container-runtime/config.toml"
configFilename = "config.toml" configFilename = "config.toml"
toolkitPidFilename = "toolkit.pid"
) )
type options struct { type options struct {
@ -111,7 +114,7 @@ func main() {
return validateOptions(c, &opts) return validateOptions(c, &opts)
} }
delete.Action = func(c *cli.Context) error { delete.Action = func(c *cli.Context) error {
return Delete(c, &opts) return TryDelete(c, &opts)
} }
// Register the subcommand with the top-level CLI // Register the subcommand with the top-level CLI
@ -301,12 +304,29 @@ func validateOptions(c *cli.Context, opts *options) error {
return nil return nil
} }
// Delete removes the NVIDIA container toolkit // TryDelete attempts to remove the specified toolkit folder.
func Delete(cli *cli.Context, opts *options) error { // A toolkit.pid file -- if present -- is skipped.
log.Infof("Deleting NVIDIA container toolkit from '%v'", opts.toolkitRoot) func TryDelete(cli *cli.Context, opts *options) error {
err := os.RemoveAll(opts.toolkitRoot) log.Infof("Attempting to delete NVIDIA container toolkit from '%v'", opts.toolkitRoot)
if err != nil {
return fmt.Errorf("error deleting toolkit directory: %v", err) contents, err := os.ReadDir(opts.toolkitRoot)
if err != nil && errors.Is(err, os.ErrNotExist) {
return nil
} else if err != nil {
return fmt.Errorf("failed to read the contents of %v: %w", opts.toolkitRoot, err)
}
for _, content := range contents {
if content.Name() == toolkitPidFilename {
continue
}
name := filepath.Join(opts.toolkitRoot, content.Name())
if err := os.RemoveAll(name); err != nil {
log.Warningf("could not remove %v: %v", name, err)
}
}
if err := os.RemoveAll(opts.toolkitRoot); err != nil {
log.Warningf("could not remove %v: %v", opts.toolkitRoot, err)
} }
return nil return nil
} }