diff --git a/tools/container/nvidia-toolkit/run.go b/tools/container/nvidia-toolkit/run.go index c184ba3c..aff9dd40 100644 --- a/tools/container/nvidia-toolkit/run.go +++ b/tools/container/nvidia-toolkit/run.go @@ -35,22 +35,25 @@ type options struct { noDaemon bool runtime string runtimeArgs string - - root string + root string } // Version defines the CLI version. This is set at build time using LD FLAGS var Version = "development" func main() { + remainingArgs, root, err := ParseArgs(os.Args) + if err != nil { + log.Errorf("Error: unable to parse arguments: %v", err) + os.Exit(1) + } options := options{} - // Create the top-level CLI c := cli.NewApp() c.Name = "nvidia-toolkit" c.Usage = "Install the nvidia-container-toolkit for use by a given runtime" - c.UsageText = "DESTINATION [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]" + c.UsageText = "[DESTINATION] [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]" c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit" c.Version = Version c.Action = func(ctx *cli.Context) error { @@ -82,17 +85,17 @@ func main() { Destination: &options.runtimeArgs, EnvVars: []string{"RUNTIME_ARGS"}, }, + &cli.StringFlag{ + Name: "root", + Value: root, + Usage: "the folder where the NVIDIA Container Toolkit is to be installed. It will be installed to `ROOT`/toolkit", + Destination: &options.root, + EnvVars: []string{"ROOT"}, + }, } // Run the CLI log.Infof("Starting %v", c.Name) - - remainingArgs, err := ParseArgs(os.Args, &options) - if err != nil { - log.Errorf("Error: unable to parse arguments: %v", err) - os.Exit(1) - } - if err := c.Run(remainingArgs); err != nil { log.Errorf("error running nvidia-toolkit: %v", err) os.Exit(1) @@ -139,44 +142,40 @@ func Run(c *cli.Context, o *options) error { return nil } -// ParseArgs parses the command line arguments and returns the remaining arguments -func ParseArgs(args []string, o *options) ([]string, error) { +// ParseArgs checks if a single positional argument was defined and extracts this the root. +// If no positional arguments are defined, the it is assumed that the root is specified as a flag. +func ParseArgs(args []string) ([]string, string, error) { log.Infof("Parsing arguments") - numPositionalArgs := 2 // Includes command itself - - if len(args) < numPositionalArgs { - return nil, fmt.Errorf("missing arguments") + if len(args) < 2 { + return args, "", nil } - for _, arg := range args { - if arg == "--help" || arg == "-h" { - return []string{args[0], arg}, nil - } - if arg == "--version" || arg == "-v" { - return []string{args[0], arg}, nil - } - } - - for _, arg := range args[:numPositionalArgs] { + var lastPositionalArg int + for i, arg := range args { if strings.HasPrefix(arg, "-") { - return nil, fmt.Errorf("unexpected flag where argument should be") + break } + lastPositionalArg = i } - for _, arg := range args[numPositionalArgs:] { - if !strings.HasPrefix(arg, "-") { - return nil, fmt.Errorf("unexpected argument where flag should be") - } + if lastPositionalArg == 0 { + return args, "", nil } - o.root = args[1] + if lastPositionalArg == 1 { + return append([]string{args[0]}, args[2:]...), args[1], nil + } - return append([]string{args[0]}, args[numPositionalArgs:]...), nil + return nil, "", fmt.Errorf("unexpected positional argument(s) %v", args[2:lastPositionalArg+1]) } func verifyFlags(o *options) error { log.Infof("Verifying Flags") + if o.root == "" { + return fmt.Errorf("the install root must be specified") + } + if _, exists := availableRuntimes[o.runtime]; !exists { return fmt.Errorf("unknown runtime: %v", o.runtime) } diff --git a/tools/container/nvidia-toolkit/run_test.go b/tools/container/nvidia-toolkit/run_test.go new file mode 100644 index 00000000..8a0bb50e --- /dev/null +++ b/tools/container/nvidia-toolkit/run_test.go @@ -0,0 +1,84 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package main + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseArgs(t *testing.T) { + testCases := []struct { + args []string + expectedRemaining []string + expectedRoot string + expectedError error + }{ + { + args: []string{}, + expectedRemaining: []string{}, + expectedRoot: "", + expectedError: nil, + }, + { + args: []string{"app"}, + expectedRemaining: []string{"app"}, + }, + { + args: []string{"app", "root"}, + expectedRemaining: []string{"app"}, + expectedRoot: "root", + }, + { + args: []string{"app", "--flag"}, + expectedRemaining: []string{"app", "--flag"}, + }, + { + args: []string{"app", "root", "--flag"}, + expectedRemaining: []string{"app", "--flag"}, + expectedRoot: "root", + }, + { + args: []string{"app", "root", "not-root", "--flag"}, + expectedError: fmt.Errorf("unexpected positional argument(s) [not-root]"), + }, + { + args: []string{"app", "root", "not-root"}, + expectedError: fmt.Errorf("unexpected positional argument(s) [not-root]"), + }, + { + args: []string{"app", "root", "not-root", "also"}, + expectedError: fmt.Errorf("unexpected positional argument(s) [not-root also]"), + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + remaining, root, err := ParseArgs(tc.args) + if tc.expectedError != nil { + require.EqualError(t, err, tc.expectedError.Error()) + } else { + require.NoError(t, err) + } + + require.ElementsMatch(t, tc.expectedRemaining, remaining) + require.Equal(t, tc.expectedRoot, root) + }) + } +}