From 975309639805decee0b64f5f8226393ba54752e7 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Sat, 26 Oct 2024 21:13:27 +0200 Subject: [PATCH] [no-relnote] Add app struct for nvidia-toolkit Signed-off-by: Evan Lezar --- tools/container/nvidia-toolkit/run.go | 94 ++++++++++++++++++--------- 1 file changed, 62 insertions(+), 32 deletions(-) diff --git a/tools/container/nvidia-toolkit/run.go b/tools/container/nvidia-toolkit/run.go index abc0a2cd..9ce50568 100644 --- a/tools/container/nvidia-toolkit/run.go +++ b/tools/container/nvidia-toolkit/run.go @@ -8,10 +8,10 @@ import ( "strings" "syscall" - log "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" "golang.org/x/sys/unix" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/tools/container/runtime" "github.com/NVIDIA/nvidia-container-toolkit/tools/container/toolkit" ) @@ -51,12 +51,44 @@ func (o options) toolkitRoot() string { var Version = "development" func main() { - remainingArgs, root, err := ParseArgs(os.Args) + logger := logger.New() + + remainingArgs, root, err := ParseArgs(logger, os.Args) if err != nil { - log.Errorf("Error: unable to parse arguments: %v", err) + logger.Errorf("Error: unable to parse arguments: %v", err) os.Exit(1) } + c := NewApp(logger, root) + + // Run the CLI + logger.Infof("Starting %v", c.Name) + if err := c.Run(remainingArgs); err != nil { + logger.Errorf("error running nvidia-toolkit: %v", err) + os.Exit(1) + } + + logger.Infof("Completed %v", c.Name) +} + +// An app represents the nvidia-ctk-installer. +type app struct { + logger logger.Interface + // defaultRoot stores the root to use if the --root flag is not specified. + defaultRoot string +} + +// NewApp creates the CLI app fro the specified options. +// defaultRoot is used as the root if not specified via the --root flag. +func NewApp(logger logger.Interface, defaultRoot string) *cli.App { + a := app{ + logger: logger, + defaultRoot: defaultRoot, + } + return a.build() +} + +func (a app) build() *cli.App { options := options{ toolkitOptions: toolkit.Options{}, } @@ -68,10 +100,10 @@ func main() { c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit" c.Version = Version c.Before = func(ctx *cli.Context) error { - return validateFlags(ctx, &options) + return a.Before(ctx, &options) } c.Action = func(ctx *cli.Context) error { - return Run(ctx, &options) + return a.Run(ctx, &options) } // Setup flags for the CLI @@ -102,7 +134,7 @@ func main() { }, &cli.StringFlag{ Name: "root", - Value: root, + Value: a.defaultRoot, Usage: "the folder where the NVIDIA Container Toolkit is to be installed. It will be installed to `ROOT`/toolkit", Destination: &options.root, EnvVars: []string{"ROOT"}, @@ -119,17 +151,14 @@ func main() { c.Flags = append(c.Flags, toolkit.Flags(&options.toolkitOptions)...) c.Flags = append(c.Flags, runtime.Flags(&options.runtimeOptions)...) - // Run the CLI - log.Infof("Starting %v", c.Name) - if err := c.Run(remainingArgs); err != nil { - log.Errorf("error running nvidia-toolkit: %v", err) - os.Exit(1) - } - - log.Infof("Completed %v", c.Name) + return c } -func validateFlags(_ *cli.Context, o *options) error { +func (a *app) Before(c *cli.Context, o *options) error { + return a.validateFlags(c, o) +} + +func (a *app) validateFlags(_ *cli.Context, o *options) error { if o.root == "" { return fmt.Errorf("the install root must be specified") } @@ -139,6 +168,7 @@ func validateFlags(_ *cli.Context, o *options) error { if filepath.Base(o.pidFile) != toolkitPidFilename { return fmt.Errorf("invalid toolkit.pid path %v", o.pidFile) } + if err := toolkit.ValidateOptions(&o.toolkitOptions, o.toolkitRoot()); err != nil { return err } @@ -149,12 +179,12 @@ func validateFlags(_ *cli.Context, o *options) error { } // Run runs the core logic of the CLI -func Run(c *cli.Context, o *options) error { - err := initialize(o.pidFile) +func (a *app) Run(c *cli.Context, o *options) error { + err := a.initialize(o.pidFile) if err != nil { return fmt.Errorf("unable to initialize: %v", err) } - defer shutdown(o.pidFile) + defer a.shutdown(o.pidFile) if len(o.toolkitOptions.ContainerRuntimeRuntimes.Value()) == 0 { lowlevelRuntimePaths, err := runtime.GetLowlevelRuntimePaths(&o.runtimeOptions, o.runtime) @@ -176,7 +206,7 @@ func Run(c *cli.Context, o *options) error { } if !o.noDaemon { - err = waitForSignal() + err = a.waitForSignal() if err != nil { return fmt.Errorf("unable to wait for signal: %v", err) } @@ -192,8 +222,8 @@ func Run(c *cli.Context, o *options) error { // ParseArgs checks if a single positional argument was defined and extracts this the root. // If no positional arguments are defined, it is assumed that the root is specified as a flag. -func ParseArgs(args []string) ([]string, string, error) { - log.Infof("Parsing arguments") +func ParseArgs(logger logger.Interface, args []string) ([]string, string, error) { + logger.Infof("Parsing arguments") if len(args) < 2 { return args, "", nil @@ -218,8 +248,8 @@ func ParseArgs(args []string) ([]string, string, error) { return nil, "", fmt.Errorf("unexpected positional argument(s) %v", args[2:lastPositionalArg+1]) } -func initialize(pidFile string) error { - log.Infof("Initializing") +func (a *app) initialize(pidFile string) error { + a.logger.Infof("Initializing") if dir := filepath.Dir(pidFile); dir != "" { err := os.MkdirAll(dir, 0755) @@ -235,8 +265,8 @@ func initialize(pidFile string) error { err = unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB) if err != nil { - log.Warningf("Unable to get exclusive lock on '%v'", pidFile) - log.Warningf("This normally means an instance of the NVIDIA toolkit Container is already running, aborting") + a.logger.Warningf("Unable to get exclusive lock on '%v'", pidFile) + a.logger.Warningf("This normally means an instance of the NVIDIA toolkit Container is already running, aborting") return fmt.Errorf("unable to get flock on pidfile: %v", err) } @@ -253,8 +283,8 @@ func initialize(pidFile string) error { case <-waitingForSignal: signalReceived <- true default: - log.Infof("Signal received, exiting early") - shutdown(pidFile) + a.logger.Infof("Signal received, exiting early") + a.shutdown(pidFile) os.Exit(0) } }() @@ -262,18 +292,18 @@ func initialize(pidFile string) error { return nil } -func waitForSignal() error { - log.Infof("Waiting for signal") +func (a *app) waitForSignal() error { + a.logger.Infof("Waiting for signal") waitingForSignal <- true <-signalReceived return nil } -func shutdown(pidFile string) { - log.Infof("Shutting Down") +func (a *app) shutdown(pidFile string) { + a.logger.Infof("Shutting Down") err := os.Remove(pidFile) if err != nil { - log.Warningf("Unable to remove pidfile: %v", err) + a.logger.Warningf("Unable to remove pidfile: %v", err) } }