Merge branch 'remove-positional-arguments' into 'main'

Allow install root to be set as positional argument OR flag

See merge request nvidia/container-toolkit/container-toolkit!212
This commit is contained in:
Evan Lezar 2022-09-16 09:36:17 +00:00
commit a9fb7a4a88
2 changed files with 150 additions and 58 deletions
tools/container/nvidia-toolkit

View File

@ -30,23 +30,35 @@ var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "container
var waitingForSignal = make(chan bool, 1) var waitingForSignal = make(chan bool, 1)
var signalReceived = make(chan bool, 1) var signalReceived = make(chan bool, 1)
var destinationArg string // options stores the command line arguments
var noDaemonFlag bool type options struct {
var runtimeFlag string noDaemon bool
var runtimeArgsFlag string runtime string
runtimeArgs string
root string
}
// Version defines the CLI version. This is set at build time using LD FLAGS // Version defines the CLI version. This is set at build time using LD FLAGS
var Version = "development" var Version = "development"
func main() { func main() {
remainingArgs, root, err := ParseArgs(os.Args)
if err != nil {
log.Errorf("Error: unable to parse arguments: %v", err)
os.Exit(1)
}
options := options{}
// Create the top-level CLI // Create the top-level CLI
c := cli.NewApp() c := cli.NewApp()
c.Name = "nvidia-toolkit" c.Name = "nvidia-toolkit"
c.Usage = "Install the nvidia-container-toolkit for use by a given runtime" c.Usage = "Install the nvidia-container-toolkit for use by a given runtime"
c.UsageText = "DESTINATION [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]" c.UsageText = "[DESTINATION] [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]"
c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit" c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit"
c.Version = Version c.Version = Version
c.Action = Run c.Action = func(ctx *cli.Context) error {
return Run(ctx, &options)
}
// Setup flags for the CLI // Setup flags for the CLI
c.Flags = []cli.Flag{ c.Flags = []cli.Flag{
@ -54,7 +66,7 @@ func main() {
Name: "no-daemon", Name: "no-daemon",
Aliases: []string{"n"}, Aliases: []string{"n"},
Usage: "terminate immediatly after setting up the runtime. Note that no cleanup will be performed", Usage: "terminate immediatly after setting up the runtime. Note that no cleanup will be performed",
Destination: &noDaemonFlag, Destination: &options.noDaemon,
EnvVars: []string{"NO_DAEMON"}, EnvVars: []string{"NO_DAEMON"},
}, },
&cli.StringFlag{ &cli.StringFlag{
@ -62,7 +74,7 @@ func main() {
Aliases: []string{"r"}, Aliases: []string{"r"},
Usage: "the runtime to setup on this node. One of {'docker', 'crio', 'containerd'}", Usage: "the runtime to setup on this node. One of {'docker', 'crio', 'containerd'}",
Value: defaultRuntime, Value: defaultRuntime,
Destination: &runtimeFlag, Destination: &options.runtime,
EnvVars: []string{"RUNTIME"}, EnvVars: []string{"RUNTIME"},
}, },
&cli.StringFlag{ &cli.StringFlag{
@ -70,20 +82,20 @@ func main() {
Aliases: []string{"u"}, Aliases: []string{"u"},
Usage: "arguments to pass to 'docker', 'crio', or 'containerd' setup command", Usage: "arguments to pass to 'docker', 'crio', or 'containerd' setup command",
Value: defaultRuntimeArgs, Value: defaultRuntimeArgs,
Destination: &runtimeArgsFlag, Destination: &options.runtimeArgs,
EnvVars: []string{"RUNTIME_ARGS"}, EnvVars: []string{"RUNTIME_ARGS"},
}, },
&cli.StringFlag{
Name: "root",
Value: root,
Usage: "the folder where the NVIDIA Container Toolkit is to be installed. It will be installed to `ROOT`/toolkit",
Destination: &options.root,
EnvVars: []string{"ROOT"},
},
} }
// Run the CLI // Run the CLI
log.Infof("Starting %v", c.Name) log.Infof("Starting %v", c.Name)
remainingArgs, err := ParseArgs(os.Args)
if err != nil {
log.Errorf("Error: unable to parse arguments: %v", err)
os.Exit(1)
}
if err := c.Run(remainingArgs); err != nil { if err := c.Run(remainingArgs); err != nil {
log.Errorf("error running nvidia-toolkit: %v", err) log.Errorf("error running nvidia-toolkit: %v", err)
os.Exit(1) os.Exit(1)
@ -93,8 +105,8 @@ func main() {
} }
// Run runs the core logic of the CLI // Run runs the core logic of the CLI
func Run(c *cli.Context) error { func Run(c *cli.Context, o *options) error {
err := verifyFlags() err := verifyFlags(o)
if err != nil { if err != nil {
return fmt.Errorf("unable to verify flags: %v", err) return fmt.Errorf("unable to verify flags: %v", err)
} }
@ -105,23 +117,23 @@ func Run(c *cli.Context) error {
} }
defer shutdown() defer shutdown()
err = installToolkit() err = installToolkit(o)
if err != nil { if err != nil {
return fmt.Errorf("unable to install toolkit: %v", err) return fmt.Errorf("unable to install toolkit: %v", err)
} }
err = setupRuntime() err = setupRuntime(o)
if err != nil { if err != nil {
return fmt.Errorf("unable to setup runtime: %v", err) return fmt.Errorf("unable to setup runtime: %v", err)
} }
if !noDaemonFlag { if !o.noDaemon {
err = waitForSignal() err = waitForSignal()
if err != nil { if err != nil {
return fmt.Errorf("unable to wait for signal: %v", err) return fmt.Errorf("unable to wait for signal: %v", err)
} }
err = cleanupRuntime() err = cleanupRuntime(o)
if err != nil { if err != nil {
return fmt.Errorf("unable to cleanup runtime: %v", err) return fmt.Errorf("unable to cleanup runtime: %v", err)
} }
@ -130,46 +142,42 @@ func Run(c *cli.Context) error {
return nil return nil
} }
// ParseArgs parses the command line arguments and returns the remaining arguments // ParseArgs checks if a single positional argument was defined and extracts this the root.
func ParseArgs(args []string) ([]string, error) { // If no positional arguments are defined, the it is assumed that the root is specified as a flag.
func ParseArgs(args []string) ([]string, string, error) {
log.Infof("Parsing arguments") log.Infof("Parsing arguments")
numPositionalArgs := 2 // Includes command itself if len(args) < 2 {
return args, "", nil
if len(args) < numPositionalArgs {
return nil, fmt.Errorf("missing arguments")
} }
for _, arg := range args { var lastPositionalArg int
if arg == "--help" || arg == "-h" { for i, arg := range args {
return []string{args[0], arg}, nil
}
if arg == "--version" || arg == "-v" {
return []string{args[0], arg}, nil
}
}
for _, arg := range args[:numPositionalArgs] {
if strings.HasPrefix(arg, "-") { if strings.HasPrefix(arg, "-") {
return nil, fmt.Errorf("unexpected flag where argument should be") break
} }
lastPositionalArg = i
} }
for _, arg := range args[numPositionalArgs:] { if lastPositionalArg == 0 {
if !strings.HasPrefix(arg, "-") { return args, "", nil
return nil, fmt.Errorf("unexpected argument where flag should be")
}
} }
destinationArg = args[1] if lastPositionalArg == 1 {
return append([]string{args[0]}, args[2:]...), args[1], nil
}
return append([]string{args[0]}, args[numPositionalArgs:]...), nil return nil, "", fmt.Errorf("unexpected positional argument(s) %v", args[2:lastPositionalArg+1])
} }
func verifyFlags() error { func verifyFlags(o *options) error {
log.Infof("Verifying Flags") log.Infof("Verifying Flags")
if _, exists := availableRuntimes[runtimeFlag]; !exists { if o.root == "" {
return fmt.Errorf("unknown runtime: %v", runtimeFlag) return fmt.Errorf("the install root must be specified")
}
if _, exists := availableRuntimes[o.runtime]; !exists {
return fmt.Errorf("unknown runtime: %v", o.runtime)
} }
return nil return nil
} }
@ -211,14 +219,14 @@ func initialize() error {
return nil return nil
} }
func installToolkit() error { func installToolkit(o *options) error {
log.Infof("Installing toolkit") log.Infof("Installing toolkit")
cmdline := []string{ cmdline := []string{
toolkitCommand, toolkitCommand,
"install", "install",
"--toolkit-root", "--toolkit-root",
filepath.Join(destinationArg, toolkitSubDir), filepath.Join(o.root, toolkitSubDir),
} }
cmd := exec.Command("sh", "-c", strings.Join(cmdline, " ")) cmd := exec.Command("sh", "-c", strings.Join(cmdline, " "))
@ -232,19 +240,19 @@ func installToolkit() error {
return nil return nil
} }
func setupRuntime() error { func setupRuntime(o *options) error {
toolkitDir := filepath.Join(destinationArg, toolkitSubDir) toolkitDir := filepath.Join(o.root, toolkitSubDir)
log.Infof("Setting up runtime") log.Infof("Setting up runtime")
cmdline := fmt.Sprintf("%v setup %v %v\n", runtimeFlag, runtimeArgsFlag, toolkitDir) cmdline := fmt.Sprintf("%v setup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir)
cmd := exec.Command("sh", "-c", cmdline) cmd := exec.Command("sh", "-c", cmdline)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
err := cmd.Run() err := cmd.Run()
if err != nil { if err != nil {
return fmt.Errorf("error running %v command: %v", runtimeFlag, err) return fmt.Errorf("error running %v command: %v", o.runtime, err)
} }
return nil return nil
@ -257,19 +265,19 @@ func waitForSignal() error {
return nil return nil
} }
func cleanupRuntime() error { func cleanupRuntime(o *options) error {
toolkitDir := filepath.Join(destinationArg, toolkitSubDir) toolkitDir := filepath.Join(o.root, toolkitSubDir)
log.Infof("Cleaning up Runtime") log.Infof("Cleaning up Runtime")
cmdline := fmt.Sprintf("%v cleanup %v %v\n", runtimeFlag, runtimeArgsFlag, toolkitDir) cmdline := fmt.Sprintf("%v cleanup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir)
cmd := exec.Command("sh", "-c", cmdline) cmd := exec.Command("sh", "-c", cmdline)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
err := cmd.Run() err := cmd.Run()
if err != nil { if err != nil {
return fmt.Errorf("error running %v command: %v", runtimeFlag, err) return fmt.Errorf("error running %v command: %v", o.runtime, err)
} }
return nil return nil

View File

@ -0,0 +1,84 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package main
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestParseArgs(t *testing.T) {
testCases := []struct {
args []string
expectedRemaining []string
expectedRoot string
expectedError error
}{
{
args: []string{},
expectedRemaining: []string{},
expectedRoot: "",
expectedError: nil,
},
{
args: []string{"app"},
expectedRemaining: []string{"app"},
},
{
args: []string{"app", "root"},
expectedRemaining: []string{"app"},
expectedRoot: "root",
},
{
args: []string{"app", "--flag"},
expectedRemaining: []string{"app", "--flag"},
},
{
args: []string{"app", "root", "--flag"},
expectedRemaining: []string{"app", "--flag"},
expectedRoot: "root",
},
{
args: []string{"app", "root", "not-root", "--flag"},
expectedError: fmt.Errorf("unexpected positional argument(s) [not-root]"),
},
{
args: []string{"app", "root", "not-root"},
expectedError: fmt.Errorf("unexpected positional argument(s) [not-root]"),
},
{
args: []string{"app", "root", "not-root", "also"},
expectedError: fmt.Errorf("unexpected positional argument(s) [not-root also]"),
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
remaining, root, err := ParseArgs(tc.args)
if tc.expectedError != nil {
require.EqualError(t, err, tc.expectedError.Error())
} else {
require.NoError(t, err)
}
require.ElementsMatch(t, tc.expectedRemaining, remaining)
require.Equal(t, tc.expectedRoot, root)
})
}
}