diff --git a/tools/container/nvidia-toolkit/run.go b/tools/container/nvidia-toolkit/run.go index 405b8e32..aff9dd40 100644 --- a/tools/container/nvidia-toolkit/run.go +++ b/tools/container/nvidia-toolkit/run.go @@ -30,23 +30,35 @@ var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "container var waitingForSignal = make(chan bool, 1) var signalReceived = make(chan bool, 1) -var destinationArg string -var noDaemonFlag bool -var runtimeFlag string -var runtimeArgsFlag string +// options stores the command line arguments +type options struct { + noDaemon bool + runtime string + runtimeArgs string + root string +} // Version defines the CLI version. This is set at build time using LD FLAGS var Version = "development" func main() { + remainingArgs, root, err := ParseArgs(os.Args) + if err != nil { + log.Errorf("Error: unable to parse arguments: %v", err) + os.Exit(1) + } + + options := options{} // Create the top-level CLI c := cli.NewApp() c.Name = "nvidia-toolkit" c.Usage = "Install the nvidia-container-toolkit for use by a given runtime" - c.UsageText = "DESTINATION [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]" + c.UsageText = "[DESTINATION] [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]" c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit" c.Version = Version - c.Action = Run + c.Action = func(ctx *cli.Context) error { + return Run(ctx, &options) + } // Setup flags for the CLI c.Flags = []cli.Flag{ @@ -54,7 +66,7 @@ func main() { Name: "no-daemon", Aliases: []string{"n"}, Usage: "terminate immediatly after setting up the runtime. Note that no cleanup will be performed", - Destination: &noDaemonFlag, + Destination: &options.noDaemon, EnvVars: []string{"NO_DAEMON"}, }, &cli.StringFlag{ @@ -62,7 +74,7 @@ func main() { Aliases: []string{"r"}, Usage: "the runtime to setup on this node. One of {'docker', 'crio', 'containerd'}", Value: defaultRuntime, - Destination: &runtimeFlag, + Destination: &options.runtime, EnvVars: []string{"RUNTIME"}, }, &cli.StringFlag{ @@ -70,20 +82,20 @@ func main() { Aliases: []string{"u"}, Usage: "arguments to pass to 'docker', 'crio', or 'containerd' setup command", Value: defaultRuntimeArgs, - Destination: &runtimeArgsFlag, + Destination: &options.runtimeArgs, EnvVars: []string{"RUNTIME_ARGS"}, }, + &cli.StringFlag{ + Name: "root", + Value: root, + Usage: "the folder where the NVIDIA Container Toolkit is to be installed. It will be installed to `ROOT`/toolkit", + Destination: &options.root, + EnvVars: []string{"ROOT"}, + }, } // Run the CLI log.Infof("Starting %v", c.Name) - - remainingArgs, err := ParseArgs(os.Args) - if err != nil { - log.Errorf("Error: unable to parse arguments: %v", err) - os.Exit(1) - } - if err := c.Run(remainingArgs); err != nil { log.Errorf("error running nvidia-toolkit: %v", err) os.Exit(1) @@ -93,8 +105,8 @@ func main() { } // Run runs the core logic of the CLI -func Run(c *cli.Context) error { - err := verifyFlags() +func Run(c *cli.Context, o *options) error { + err := verifyFlags(o) if err != nil { return fmt.Errorf("unable to verify flags: %v", err) } @@ -105,23 +117,23 @@ func Run(c *cli.Context) error { } defer shutdown() - err = installToolkit() + err = installToolkit(o) if err != nil { return fmt.Errorf("unable to install toolkit: %v", err) } - err = setupRuntime() + err = setupRuntime(o) if err != nil { return fmt.Errorf("unable to setup runtime: %v", err) } - if !noDaemonFlag { + if !o.noDaemon { err = waitForSignal() if err != nil { return fmt.Errorf("unable to wait for signal: %v", err) } - err = cleanupRuntime() + err = cleanupRuntime(o) if err != nil { return fmt.Errorf("unable to cleanup runtime: %v", err) } @@ -130,46 +142,42 @@ func Run(c *cli.Context) error { return nil } -// ParseArgs parses the command line arguments and returns the remaining arguments -func ParseArgs(args []string) ([]string, error) { +// ParseArgs checks if a single positional argument was defined and extracts this the root. +// If no positional arguments are defined, the it is assumed that the root is specified as a flag. +func ParseArgs(args []string) ([]string, string, error) { log.Infof("Parsing arguments") - numPositionalArgs := 2 // Includes command itself - - if len(args) < numPositionalArgs { - return nil, fmt.Errorf("missing arguments") + if len(args) < 2 { + return args, "", nil } - for _, arg := range args { - if arg == "--help" || arg == "-h" { - return []string{args[0], arg}, nil - } - if arg == "--version" || arg == "-v" { - return []string{args[0], arg}, nil - } - } - - for _, arg := range args[:numPositionalArgs] { + var lastPositionalArg int + for i, arg := range args { if strings.HasPrefix(arg, "-") { - return nil, fmt.Errorf("unexpected flag where argument should be") + break } + lastPositionalArg = i } - for _, arg := range args[numPositionalArgs:] { - if !strings.HasPrefix(arg, "-") { - return nil, fmt.Errorf("unexpected argument where flag should be") - } + if lastPositionalArg == 0 { + return args, "", nil } - destinationArg = args[1] + if lastPositionalArg == 1 { + return append([]string{args[0]}, args[2:]...), args[1], nil + } - return append([]string{args[0]}, args[numPositionalArgs:]...), nil + return nil, "", fmt.Errorf("unexpected positional argument(s) %v", args[2:lastPositionalArg+1]) } -func verifyFlags() error { +func verifyFlags(o *options) error { log.Infof("Verifying Flags") - if _, exists := availableRuntimes[runtimeFlag]; !exists { - return fmt.Errorf("unknown runtime: %v", runtimeFlag) + if o.root == "" { + return fmt.Errorf("the install root must be specified") + } + + if _, exists := availableRuntimes[o.runtime]; !exists { + return fmt.Errorf("unknown runtime: %v", o.runtime) } return nil } @@ -211,14 +219,14 @@ func initialize() error { return nil } -func installToolkit() error { +func installToolkit(o *options) error { log.Infof("Installing toolkit") cmdline := []string{ toolkitCommand, "install", "--toolkit-root", - filepath.Join(destinationArg, toolkitSubDir), + filepath.Join(o.root, toolkitSubDir), } cmd := exec.Command("sh", "-c", strings.Join(cmdline, " ")) @@ -232,19 +240,19 @@ func installToolkit() error { return nil } -func setupRuntime() error { - toolkitDir := filepath.Join(destinationArg, toolkitSubDir) +func setupRuntime(o *options) error { + toolkitDir := filepath.Join(o.root, toolkitSubDir) log.Infof("Setting up runtime") - cmdline := fmt.Sprintf("%v setup %v %v\n", runtimeFlag, runtimeArgsFlag, toolkitDir) + cmdline := fmt.Sprintf("%v setup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir) cmd := exec.Command("sh", "-c", cmdline) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() if err != nil { - return fmt.Errorf("error running %v command: %v", runtimeFlag, err) + return fmt.Errorf("error running %v command: %v", o.runtime, err) } return nil @@ -257,19 +265,19 @@ func waitForSignal() error { return nil } -func cleanupRuntime() error { - toolkitDir := filepath.Join(destinationArg, toolkitSubDir) +func cleanupRuntime(o *options) error { + toolkitDir := filepath.Join(o.root, toolkitSubDir) log.Infof("Cleaning up Runtime") - cmdline := fmt.Sprintf("%v cleanup %v %v\n", runtimeFlag, runtimeArgsFlag, toolkitDir) + cmdline := fmt.Sprintf("%v cleanup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir) cmd := exec.Command("sh", "-c", cmdline) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() if err != nil { - return fmt.Errorf("error running %v command: %v", runtimeFlag, err) + return fmt.Errorf("error running %v command: %v", o.runtime, err) } return nil diff --git a/tools/container/nvidia-toolkit/run_test.go b/tools/container/nvidia-toolkit/run_test.go new file mode 100644 index 00000000..8a0bb50e --- /dev/null +++ b/tools/container/nvidia-toolkit/run_test.go @@ -0,0 +1,84 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package main + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseArgs(t *testing.T) { + testCases := []struct { + args []string + expectedRemaining []string + expectedRoot string + expectedError error + }{ + { + args: []string{}, + expectedRemaining: []string{}, + expectedRoot: "", + expectedError: nil, + }, + { + args: []string{"app"}, + expectedRemaining: []string{"app"}, + }, + { + args: []string{"app", "root"}, + expectedRemaining: []string{"app"}, + expectedRoot: "root", + }, + { + args: []string{"app", "--flag"}, + expectedRemaining: []string{"app", "--flag"}, + }, + { + args: []string{"app", "root", "--flag"}, + expectedRemaining: []string{"app", "--flag"}, + expectedRoot: "root", + }, + { + args: []string{"app", "root", "not-root", "--flag"}, + expectedError: fmt.Errorf("unexpected positional argument(s) [not-root]"), + }, + { + args: []string{"app", "root", "not-root"}, + expectedError: fmt.Errorf("unexpected positional argument(s) [not-root]"), + }, + { + args: []string{"app", "root", "not-root", "also"}, + expectedError: fmt.Errorf("unexpected positional argument(s) [not-root also]"), + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + remaining, root, err := ParseArgs(tc.args) + if tc.expectedError != nil { + require.EqualError(t, err, tc.expectedError.Error()) + } else { + require.NoError(t, err) + } + + require.ElementsMatch(t, tc.expectedRemaining, remaining) + require.Equal(t, tc.expectedRoot, root) + }) + } +}