mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Import cmd/nvidia-container-runtime from experimental branch
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
		
							parent
							
								
									d234077780
								
							
						
					
					
						commit
						ec8a6d978d
					
				| @ -66,24 +66,11 @@ func (r nvidiaContainerRuntime) Exec(args []string) error { | ||||
| // modificationRequired checks the intput arguments to determine whether a modification
 | ||||
| // to the OCI spec is required.
 | ||||
| func (r nvidiaContainerRuntime) modificationRequired(args []string) bool { | ||||
| 	var previousWasBundle bool | ||||
| 	for _, a := range args { | ||||
| 		// We check for '--bundle create' explicitly to ensure that we
 | ||||
| 		// don't inadvertently trigger a modification if the bundle directory
 | ||||
| 		// is specified as `create`
 | ||||
| 		if !previousWasBundle && isBundleFlag(a) { | ||||
| 			previousWasBundle = true | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		if !previousWasBundle && a == "create" { | ||||
| 	if oci.HasCreateSubcommand(args) { | ||||
| 		r.logger.Infof("'create' command detected; modification required") | ||||
| 		return true | ||||
| 	} | ||||
| 
 | ||||
| 		previousWasBundle = false | ||||
| 	} | ||||
| 
 | ||||
| 	r.logger.Infof("No modification required") | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| @ -27,32 +27,6 @@ import ( | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| func TestArgsGetConfigFilePath(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		bundleDir   string | ||||
| 		ociSpecPath string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			ociSpecPath: "config.json", | ||||
| 		}, | ||||
| 		{ | ||||
| 			bundleDir:   "/foo/bar", | ||||
| 			ociSpecPath: "/foo/bar/config.json", | ||||
| 		}, | ||||
| 		{ | ||||
| 			bundleDir:   "/foo/bar/", | ||||
| 			ociSpecPath: "/foo/bar/config.json", | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for i, tc := range testCases { | ||||
| 		cp, err := getOCISpecFilePath(tc.bundleDir) | ||||
| 
 | ||||
| 		require.NoErrorf(t, err, "%d: %v", i, tc) | ||||
| 		require.Equalf(t, tc.ociSpecPath, cp, "%d: %v", i, tc) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestAddNvidiaHook(t *testing.T) { | ||||
| 	logger, logHook := testlog.NewNullLogger() | ||||
| 	shim := nvidiaContainerRuntime{ | ||||
| @ -181,9 +155,14 @@ func TestNvidiaContainerRuntime(t *testing.T) { | ||||
| 		tc.shim.logger = logger | ||||
| 		hook.Reset() | ||||
| 
 | ||||
| 		spec := &specs.Spec{} | ||||
| 		ociMock := oci.NewMockSpec(spec, tc.writeError, tc.modifyError) | ||||
| 
 | ||||
| 		ociMock := &oci.SpecMock{ | ||||
| 			ModifyFunc: func(specModifier oci.SpecModifier) error { | ||||
| 				return tc.modifyError | ||||
| 			}, | ||||
| 			FlushFunc: func() error { | ||||
| 				return tc.writeError | ||||
| 			}, | ||||
| 		} | ||||
| 		require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc) | ||||
| 
 | ||||
| 		tc.shim.ociSpec = ociMock | ||||
| @ -197,18 +176,16 @@ func TestNvidiaContainerRuntime(t *testing.T) { | ||||
| 		} | ||||
| 
 | ||||
| 		if tc.shouldModify { | ||||
| 			require.Equal(t, 1, ociMock.MockModify.Callcount, "%d: %v", i, tc) | ||||
| 			require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "%d: %v", i, tc) | ||||
| 			require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc) | ||||
| 		} else { | ||||
| 			require.Equal(t, 0, ociMock.MockModify.Callcount, "%d: %v", i, tc) | ||||
| 			require.Nil(t, spec.Hooks, "%d: %v", i, tc) | ||||
| 			require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc) | ||||
| 		} | ||||
| 
 | ||||
| 		writeExpected := tc.shouldModify && tc.modifyError == nil | ||||
| 		if writeExpected { | ||||
| 			require.Equal(t, 1, ociMock.MockFlush.Callcount, "%d: %v", i, tc) | ||||
| 			require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc) | ||||
| 		} else { | ||||
| 			require.Equal(t, 0, ociMock.MockFlush.Callcount, "%d: %v", i, tc) | ||||
| 			require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -18,9 +18,6 @@ package main | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os/exec" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" | ||||
| ) | ||||
| @ -53,15 +50,15 @@ func newRuntime(argv []string) (oci.Runtime, error) { | ||||
| 
 | ||||
| // newOCISpec constructs an OCI spec for the provided arguments
 | ||||
| func newOCISpec(argv []string) (oci.Spec, error) { | ||||
| 	bundlePath, err := getBundlePath(argv) | ||||
| 	bundleDir, err := oci.GetBundleDir(argv) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error parsing command line arguments: %v", err) | ||||
| 	} | ||||
| 	logger.Infof("Using bundle directory: %v", bundleDir) | ||||
| 
 | ||||
| 	ociSpecPath := oci.GetSpecFilePath(bundleDir) | ||||
| 	logger.Infof("Using OCI specification file path: %v", ociSpecPath) | ||||
| 
 | ||||
| 	ociSpecPath, err := getOCISpecFilePath(bundlePath) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error getting OCI specification file path: %v", err) | ||||
| 	} | ||||
| 	ociSpec := oci.NewSpecFromFile(ociSpecPath) | ||||
| 
 | ||||
| 	return ociSpec, nil | ||||
| @ -69,98 +66,9 @@ func newOCISpec(argv []string) (oci.Spec, error) { | ||||
| 
 | ||||
| // newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime
 | ||||
| func newRuncRuntime() (oci.Runtime, error) { | ||||
| 	runtimePath, err := findRunc() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error locating runtime: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	runc, err := oci.NewSyscallExecRuntimeWithLogger(logger.Logger, runtimePath) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error constructing runtime: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return runc, nil | ||||
| } | ||||
| 
 | ||||
| // getBundlePath checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc.
 | ||||
| // The following are supported:
 | ||||
| // --bundle{{SEP}}BUNDLE_PATH
 | ||||
| // -bundle{{SEP}}BUNDLE_PATH
 | ||||
| // -b{{SEP}}BUNDLE_PATH
 | ||||
| // where {{SEP}} is either ' ' or '='
 | ||||
| func getBundlePath(argv []string) (string, error) { | ||||
| 	var bundlePath string | ||||
| 
 | ||||
| 	for i := 0; i < len(argv); i++ { | ||||
| 		param := argv[i] | ||||
| 
 | ||||
| 		parts := strings.SplitN(param, "=", 2) | ||||
| 		if !isBundleFlag(parts[0]) { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		// The flag has the format --bundle=/path
 | ||||
| 		if len(parts) == 2 { | ||||
| 			bundlePath = parts[1] | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		// The flag has the format --bundle /path
 | ||||
| 		if i+1 < len(argv) { | ||||
| 			bundlePath = argv[i+1] | ||||
| 			i++ | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		// --bundle / -b was the last element of argv
 | ||||
| 		return "", fmt.Errorf("bundle option requires an argument") | ||||
| 	} | ||||
| 
 | ||||
| 	return bundlePath, nil | ||||
| } | ||||
| 
 | ||||
| // findRunc locates runc in the path, returning the full path to the
 | ||||
| // binary or an error.
 | ||||
| func findRunc() (string, error) { | ||||
| 	runtimeCandidates := []string{ | ||||
| 	return oci.NewLowLevelRuntimeWithLogger( | ||||
| 		logger.Logger, | ||||
| 		dockerRuncExecutableName, | ||||
| 		runcExecutableName, | ||||
| 	} | ||||
| 
 | ||||
| 	return findRuntime(runtimeCandidates) | ||||
| } | ||||
| 
 | ||||
| func findRuntime(runtimeCandidates []string) (string, error) { | ||||
| 	for _, candidate := range runtimeCandidates { | ||||
| 		logger.Infof("Looking for runtime binary '%v'", candidate) | ||||
| 		runcPath, err := exec.LookPath(candidate) | ||||
| 		if err == nil { | ||||
| 			logger.Infof("Found runtime binary '%v'", runcPath) | ||||
| 			return runcPath, nil | ||||
| 		} | ||||
| 		logger.Warnf("Runtime binary '%v' not found: %v", candidate, err) | ||||
| 	} | ||||
| 
 | ||||
| 	return "", fmt.Errorf("no runtime binary found from candidate list: %v", runtimeCandidates) | ||||
| } | ||||
| 
 | ||||
| func isBundleFlag(arg string) bool { | ||||
| 	if !strings.HasPrefix(arg, "-") { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	trimmed := strings.TrimLeft(arg, "-") | ||||
| 	return trimmed == "b" || trimmed == "bundle" | ||||
| } | ||||
| 
 | ||||
| // getOCISpecFilePath returns the expected path to the OCI specification file for the given
 | ||||
| // bundle directory. If the bundle directory is empty, only `config.json` is returned.
 | ||||
| func getOCISpecFilePath(bundleDir string) (string, error) { | ||||
| 	logger.Infof("Using bundle directory: %v", bundleDir) | ||||
| 
 | ||||
| 	OCISpecFilePath := filepath.Join(bundleDir, ociSpecFileName) | ||||
| 
 | ||||
| 	logger.Infof("Using OCI specification file path: %v", OCISpecFilePath) | ||||
| 
 | ||||
| 	return OCISpecFilePath, nil | ||||
| 	) | ||||
| } | ||||
|  | ||||
| @ -17,10 +17,8 @@ | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"path/filepath" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	testlog "github.com/sirupsen/logrus/hooks/test" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| @ -30,163 +28,3 @@ func TestConstructor(t *testing.T) { | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, shim) | ||||
| } | ||||
| 
 | ||||
| func TestGetBundlePath(t *testing.T) { | ||||
| 	type expected struct { | ||||
| 		bundle  string | ||||
| 		isError bool | ||||
| 	} | ||||
| 	testCases := []struct { | ||||
| 		argv     []string | ||||
| 		expected expected | ||||
| 	}{ | ||||
| 		{ | ||||
| 			argv: []string{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"create"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"--bundle"}, | ||||
| 			expected: expected{ | ||||
| 				isError: true, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"-b"}, | ||||
| 			expected: expected{ | ||||
| 				isError: true, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"--bundle", "/foo/bar"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "/foo/bar", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"--not-bundle", "/foo/bar"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"--"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"-bundle", "/foo/bar"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "/foo/bar", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"--bundle=/foo/bar"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "/foo/bar", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"-b=/foo/bar"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "/foo/bar", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"-b=/foo/=bar"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "/foo/=bar", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"-b", "/foo/bar"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "/foo/bar", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"create", "-b", "/foo/bar"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "/foo/bar", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"-b", "create", "create"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "create", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"-b=create", "create"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "create", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			argv: []string{"-b", "create"}, | ||||
| 			expected: expected{ | ||||
| 				bundle: "create", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for i, tc := range testCases { | ||||
| 		bundle, err := getBundlePath(tc.argv) | ||||
| 
 | ||||
| 		if tc.expected.isError { | ||||
| 			require.Errorf(t, err, "%d: %v", i, tc) | ||||
| 		} else { | ||||
| 			require.NoErrorf(t, err, "%d: %v", i, tc) | ||||
| 		} | ||||
| 
 | ||||
| 		require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFindRunc(t *testing.T) { | ||||
| 	testLogger, _ := testlog.NewNullLogger() | ||||
| 	logger.Logger = testLogger | ||||
| 
 | ||||
| 	runcPath, err := findRunc() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, filepath.Join(cfg.binPath, runcExecutableName), runcPath) | ||||
| } | ||||
| 
 | ||||
| func TestFindRuntime(t *testing.T) { | ||||
| 	testLogger, _ := testlog.NewNullLogger() | ||||
| 	logger.Logger = testLogger | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		candidates   []string | ||||
| 		expectedPath string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			candidates: []string{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			candidates: []string{"not-runc"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			candidates: []string{"not-runc", "also-not-runc"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			candidates:   []string{runcExecutableName}, | ||||
| 			expectedPath: filepath.Join(cfg.binPath, runcExecutableName), | ||||
| 		}, | ||||
| 		{ | ||||
| 			candidates:   []string{runcExecutableName, "not-runc"}, | ||||
| 			expectedPath: filepath.Join(cfg.binPath, runcExecutableName), | ||||
| 		}, | ||||
| 		{ | ||||
| 			candidates:   []string{"not-runc", runcExecutableName}, | ||||
| 			expectedPath: filepath.Join(cfg.binPath, runcExecutableName), | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for i, tc := range testCases { | ||||
| 		runcPath, err := findRuntime(tc.candidates) | ||||
| 		if tc.expectedPath == "" { | ||||
| 			require.Error(t, err, "%d: %v", i, tc) | ||||
| 		} else { | ||||
| 			require.NoError(t, err, "%d: %v", i, tc) | ||||
| 		} | ||||
| 		require.Equal(t, tc.expectedPath, runcPath, "%d: %v", i, tc) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user