diff --git a/cmd/nvidia-container-runtime/nvcr.go b/cmd/nvidia-container-runtime/nvcr.go index 1eb23a26..6c4185b3 100644 --- a/cmd/nvidia-container-runtime/nvcr.go +++ b/cmd/nvidia-container-runtime/nvcr.go @@ -66,22 +66,9 @@ func (r nvidiaContainerRuntime) Exec(args []string) error { // modificationRequired checks the intput arguments to determine whether a modification // to the OCI spec is required. func (r nvidiaContainerRuntime) modificationRequired(args []string) bool { - var previousWasBundle bool - for _, a := range args { - // We check for '--bundle create' explicitly to ensure that we - // don't inadvertently trigger a modification if the bundle directory - // is specified as `create` - if !previousWasBundle && isBundleFlag(a) { - previousWasBundle = true - continue - } - - if !previousWasBundle && a == "create" { - r.logger.Infof("'create' command detected; modification required") - return true - } - - previousWasBundle = false + if oci.HasCreateSubcommand(args) { + r.logger.Infof("'create' command detected; modification required") + return true } r.logger.Infof("No modification required") diff --git a/cmd/nvidia-container-runtime/nvcr_test.go b/cmd/nvidia-container-runtime/nvcr_test.go index 32abc498..16e1b447 100644 --- a/cmd/nvidia-container-runtime/nvcr_test.go +++ b/cmd/nvidia-container-runtime/nvcr_test.go @@ -18,7 +18,6 @@ package main import ( "fmt" - "os" "strings" "testing" @@ -28,35 +27,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestArgsGetConfigFilePath(t *testing.T) { - wd, err := os.Getwd() - require.NoError(t, err) - - testCases := []struct { - bundleDir string - ociSpecPath string - }{ - { - ociSpecPath: fmt.Sprintf("%v/config.json", wd), - }, - { - bundleDir: "/foo/bar", - ociSpecPath: "/foo/bar/config.json", - }, - { - bundleDir: "/foo/bar/", - ociSpecPath: "/foo/bar/config.json", - }, - } - - for i, tc := range testCases { - cp, err := getOCISpecFilePath(tc.bundleDir) - - require.NoErrorf(t, err, "%d: %v", i, tc) - require.Equalf(t, tc.ociSpecPath, cp, "%d: %v", i, tc) - } -} - func TestAddNvidiaHook(t *testing.T) { logger, logHook := testlog.NewNullLogger() shim := nvidiaContainerRuntime{ diff --git a/cmd/nvidia-container-runtime/runtime_factory.go b/cmd/nvidia-container-runtime/runtime_factory.go index d93b5044..0af37f6e 100644 --- a/cmd/nvidia-container-runtime/runtime_factory.go +++ b/cmd/nvidia-container-runtime/runtime_factory.go @@ -18,9 +18,6 @@ package main import ( "fmt" - "os" - "path/filepath" - "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) @@ -53,15 +50,15 @@ func newRuntime(argv []string) (oci.Runtime, error) { // newOCISpec constructs an OCI spec for the provided arguments func newOCISpec(argv []string) (oci.Spec, error) { - bundlePath, err := getBundlePath(argv) + bundleDir, err := oci.GetBundleDir(argv) if err != nil { return nil, fmt.Errorf("error parsing command line arguments: %v", err) } + logger.Infof("Using bundle directory: %v", bundleDir) + + ociSpecPath := oci.GetSpecFilePath(bundleDir) + logger.Infof("Using OCI specification file path: %v", ociSpecPath) - ociSpecPath, err := getOCISpecFilePath(bundlePath) - if err != nil { - return nil, fmt.Errorf("error getting OCI specification file path: %v", err) - } ociSpec := oci.NewSpecFromFile(ociSpecPath) return ociSpec, nil @@ -75,70 +72,3 @@ func newRuncRuntime() (oci.Runtime, error) { runcExecutableName, ) } - -// getBundlePath checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc. -// The following are supported: -// --bundle{{SEP}}BUNDLE_PATH -// -bundle{{SEP}}BUNDLE_PATH -// -b{{SEP}}BUNDLE_PATH -// where {{SEP}} is either ' ' or '=' -func getBundlePath(argv []string) (string, error) { - var bundlePath string - - for i := 0; i < len(argv); i++ { - param := argv[i] - - parts := strings.SplitN(param, "=", 2) - if !isBundleFlag(parts[0]) { - continue - } - - // The flag has the format --bundle=/path - if len(parts) == 2 { - bundlePath = parts[1] - continue - } - - // The flag has the format --bundle /path - if i+1 < len(argv) { - bundlePath = argv[i+1] - i++ - continue - } - - // --bundle / -b was the last element of argv - return "", fmt.Errorf("bundle option requires an argument") - } - - return bundlePath, nil -} - -func isBundleFlag(arg string) bool { - if !strings.HasPrefix(arg, "-") { - return false - } - - trimmed := strings.TrimLeft(arg, "-") - return trimmed == "b" || trimmed == "bundle" -} - -// getOCISpecFilePath returns the expected path to the OCI specification file for the given -// bundle directory or the current working directory if not specified. -func getOCISpecFilePath(bundleDir string) (string, error) { - if bundleDir == "" { - logger.Infof("Bundle directory path is empty, using working directory.") - workingDirectory, err := os.Getwd() - if err != nil { - return "", fmt.Errorf("error getting working directory: %v", err) - } - bundleDir = workingDirectory - } - - logger.Infof("Using bundle directory: %v", bundleDir) - - OCISpecFilePath := filepath.Join(bundleDir, ociSpecFileName) - - logger.Infof("Using OCI specification file path: %v", OCISpecFilePath) - - return OCISpecFilePath, nil -} diff --git a/cmd/nvidia-container-runtime/runtime_factory_test.go b/cmd/nvidia-container-runtime/runtime_factory_test.go index 257e6aa8..07d02232 100644 --- a/cmd/nvidia-container-runtime/runtime_factory_test.go +++ b/cmd/nvidia-container-runtime/runtime_factory_test.go @@ -28,111 +28,3 @@ func TestConstructor(t *testing.T) { require.NoError(t, err) require.NotNil(t, shim) } - -func TestGetBundlePath(t *testing.T) { - type expected struct { - bundle string - isError bool - } - testCases := []struct { - argv []string - expected expected - }{ - { - argv: []string{}, - }, - { - argv: []string{"create"}, - }, - { - argv: []string{"--bundle"}, - expected: expected{ - isError: true, - }, - }, - { - argv: []string{"-b"}, - expected: expected{ - isError: true, - }, - }, - { - argv: []string{"--bundle", "/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"--not-bundle", "/foo/bar"}, - }, - { - argv: []string{"--"}, - }, - { - argv: []string{"-bundle", "/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"--bundle=/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"-b=/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"-b=/foo/=bar"}, - expected: expected{ - bundle: "/foo/=bar", - }, - }, - { - argv: []string{"-b", "/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"create", "-b", "/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"-b", "create", "create"}, - expected: expected{ - bundle: "create", - }, - }, - { - argv: []string{"-b=create", "create"}, - expected: expected{ - bundle: "create", - }, - }, - { - argv: []string{"-b", "create"}, - expected: expected{ - bundle: "create", - }, - }, - } - - for i, tc := range testCases { - bundle, err := getBundlePath(tc.argv) - - if tc.expected.isError { - require.Errorf(t, err, "%d: %v", i, tc) - } else { - require.NoErrorf(t, err, "%d: %v", i, tc) - } - - require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc) - } -}