mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-06-09 16:18:01 +00:00
Import cmd/nvidia-container-runtime from experimental branch
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
d234077780
commit
ec8a6d978d
@ -66,24 +66,11 @@ func (r nvidiaContainerRuntime) Exec(args []string) error {
|
|||||||
// modificationRequired checks the intput arguments to determine whether a modification
|
// modificationRequired checks the intput arguments to determine whether a modification
|
||||||
// to the OCI spec is required.
|
// to the OCI spec is required.
|
||||||
func (r nvidiaContainerRuntime) modificationRequired(args []string) bool {
|
func (r nvidiaContainerRuntime) modificationRequired(args []string) bool {
|
||||||
var previousWasBundle bool
|
if oci.HasCreateSubcommand(args) {
|
||||||
for _, a := range args {
|
|
||||||
// We check for '--bundle create' explicitly to ensure that we
|
|
||||||
// don't inadvertently trigger a modification if the bundle directory
|
|
||||||
// is specified as `create`
|
|
||||||
if !previousWasBundle && isBundleFlag(a) {
|
|
||||||
previousWasBundle = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !previousWasBundle && a == "create" {
|
|
||||||
r.logger.Infof("'create' command detected; modification required")
|
r.logger.Infof("'create' command detected; modification required")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
previousWasBundle = false
|
|
||||||
}
|
|
||||||
|
|
||||||
r.logger.Infof("No modification required")
|
r.logger.Infof("No modification required")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -27,32 +27,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestArgsGetConfigFilePath(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
bundleDir string
|
|
||||||
ociSpecPath string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
ociSpecPath: "config.json",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
bundleDir: "/foo/bar",
|
|
||||||
ociSpecPath: "/foo/bar/config.json",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
bundleDir: "/foo/bar/",
|
|
||||||
ociSpecPath: "/foo/bar/config.json",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tc := range testCases {
|
|
||||||
cp, err := getOCISpecFilePath(tc.bundleDir)
|
|
||||||
|
|
||||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
|
||||||
require.Equalf(t, tc.ociSpecPath, cp, "%d: %v", i, tc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddNvidiaHook(t *testing.T) {
|
func TestAddNvidiaHook(t *testing.T) {
|
||||||
logger, logHook := testlog.NewNullLogger()
|
logger, logHook := testlog.NewNullLogger()
|
||||||
shim := nvidiaContainerRuntime{
|
shim := nvidiaContainerRuntime{
|
||||||
@ -181,9 +155,14 @@ func TestNvidiaContainerRuntime(t *testing.T) {
|
|||||||
tc.shim.logger = logger
|
tc.shim.logger = logger
|
||||||
hook.Reset()
|
hook.Reset()
|
||||||
|
|
||||||
spec := &specs.Spec{}
|
ociMock := &oci.SpecMock{
|
||||||
ociMock := oci.NewMockSpec(spec, tc.writeError, tc.modifyError)
|
ModifyFunc: func(specModifier oci.SpecModifier) error {
|
||||||
|
return tc.modifyError
|
||||||
|
},
|
||||||
|
FlushFunc: func() error {
|
||||||
|
return tc.writeError
|
||||||
|
},
|
||||||
|
}
|
||||||
require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc)
|
require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc)
|
||||||
|
|
||||||
tc.shim.ociSpec = ociMock
|
tc.shim.ociSpec = ociMock
|
||||||
@ -197,18 +176,16 @@ func TestNvidiaContainerRuntime(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tc.shouldModify {
|
if tc.shouldModify {
|
||||||
require.Equal(t, 1, ociMock.MockModify.Callcount, "%d: %v", i, tc)
|
require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
|
||||||
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "%d: %v", i, tc)
|
|
||||||
} else {
|
} else {
|
||||||
require.Equal(t, 0, ociMock.MockModify.Callcount, "%d: %v", i, tc)
|
require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
|
||||||
require.Nil(t, spec.Hooks, "%d: %v", i, tc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
writeExpected := tc.shouldModify && tc.modifyError == nil
|
writeExpected := tc.shouldModify && tc.modifyError == nil
|
||||||
if writeExpected {
|
if writeExpected {
|
||||||
require.Equal(t, 1, ociMock.MockFlush.Callcount, "%d: %v", i, tc)
|
require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc)
|
||||||
} else {
|
} else {
|
||||||
require.Equal(t, 0, ociMock.MockFlush.Callcount, "%d: %v", i, tc)
|
require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,9 +18,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||||
)
|
)
|
||||||
@ -53,15 +50,15 @@ func newRuntime(argv []string) (oci.Runtime, error) {
|
|||||||
|
|
||||||
// newOCISpec constructs an OCI spec for the provided arguments
|
// newOCISpec constructs an OCI spec for the provided arguments
|
||||||
func newOCISpec(argv []string) (oci.Spec, error) {
|
func newOCISpec(argv []string) (oci.Spec, error) {
|
||||||
bundlePath, err := getBundlePath(argv)
|
bundleDir, err := oci.GetBundleDir(argv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing command line arguments: %v", err)
|
return nil, fmt.Errorf("error parsing command line arguments: %v", err)
|
||||||
}
|
}
|
||||||
|
logger.Infof("Using bundle directory: %v", bundleDir)
|
||||||
|
|
||||||
|
ociSpecPath := oci.GetSpecFilePath(bundleDir)
|
||||||
|
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
|
||||||
|
|
||||||
ociSpecPath, err := getOCISpecFilePath(bundlePath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error getting OCI specification file path: %v", err)
|
|
||||||
}
|
|
||||||
ociSpec := oci.NewSpecFromFile(ociSpecPath)
|
ociSpec := oci.NewSpecFromFile(ociSpecPath)
|
||||||
|
|
||||||
return ociSpec, nil
|
return ociSpec, nil
|
||||||
@ -69,98 +66,9 @@ func newOCISpec(argv []string) (oci.Spec, error) {
|
|||||||
|
|
||||||
// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime
|
// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime
|
||||||
func newRuncRuntime() (oci.Runtime, error) {
|
func newRuncRuntime() (oci.Runtime, error) {
|
||||||
runtimePath, err := findRunc()
|
return oci.NewLowLevelRuntimeWithLogger(
|
||||||
if err != nil {
|
logger.Logger,
|
||||||
return nil, fmt.Errorf("error locating runtime: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
runc, err := oci.NewSyscallExecRuntimeWithLogger(logger.Logger, runtimePath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error constructing runtime: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return runc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getBundlePath checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc.
|
|
||||||
// The following are supported:
|
|
||||||
// --bundle{{SEP}}BUNDLE_PATH
|
|
||||||
// -bundle{{SEP}}BUNDLE_PATH
|
|
||||||
// -b{{SEP}}BUNDLE_PATH
|
|
||||||
// where {{SEP}} is either ' ' or '='
|
|
||||||
func getBundlePath(argv []string) (string, error) {
|
|
||||||
var bundlePath string
|
|
||||||
|
|
||||||
for i := 0; i < len(argv); i++ {
|
|
||||||
param := argv[i]
|
|
||||||
|
|
||||||
parts := strings.SplitN(param, "=", 2)
|
|
||||||
if !isBundleFlag(parts[0]) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// The flag has the format --bundle=/path
|
|
||||||
if len(parts) == 2 {
|
|
||||||
bundlePath = parts[1]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// The flag has the format --bundle /path
|
|
||||||
if i+1 < len(argv) {
|
|
||||||
bundlePath = argv[i+1]
|
|
||||||
i++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// --bundle / -b was the last element of argv
|
|
||||||
return "", fmt.Errorf("bundle option requires an argument")
|
|
||||||
}
|
|
||||||
|
|
||||||
return bundlePath, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// findRunc locates runc in the path, returning the full path to the
|
|
||||||
// binary or an error.
|
|
||||||
func findRunc() (string, error) {
|
|
||||||
runtimeCandidates := []string{
|
|
||||||
dockerRuncExecutableName,
|
dockerRuncExecutableName,
|
||||||
runcExecutableName,
|
runcExecutableName,
|
||||||
}
|
)
|
||||||
|
|
||||||
return findRuntime(runtimeCandidates)
|
|
||||||
}
|
|
||||||
|
|
||||||
func findRuntime(runtimeCandidates []string) (string, error) {
|
|
||||||
for _, candidate := range runtimeCandidates {
|
|
||||||
logger.Infof("Looking for runtime binary '%v'", candidate)
|
|
||||||
runcPath, err := exec.LookPath(candidate)
|
|
||||||
if err == nil {
|
|
||||||
logger.Infof("Found runtime binary '%v'", runcPath)
|
|
||||||
return runcPath, nil
|
|
||||||
}
|
|
||||||
logger.Warnf("Runtime binary '%v' not found: %v", candidate, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("no runtime binary found from candidate list: %v", runtimeCandidates)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isBundleFlag(arg string) bool {
|
|
||||||
if !strings.HasPrefix(arg, "-") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
trimmed := strings.TrimLeft(arg, "-")
|
|
||||||
return trimmed == "b" || trimmed == "bundle"
|
|
||||||
}
|
|
||||||
|
|
||||||
// getOCISpecFilePath returns the expected path to the OCI specification file for the given
|
|
||||||
// bundle directory. If the bundle directory is empty, only `config.json` is returned.
|
|
||||||
func getOCISpecFilePath(bundleDir string) (string, error) {
|
|
||||||
logger.Infof("Using bundle directory: %v", bundleDir)
|
|
||||||
|
|
||||||
OCISpecFilePath := filepath.Join(bundleDir, ociSpecFileName)
|
|
||||||
|
|
||||||
logger.Infof("Using OCI specification file path: %v", OCISpecFilePath)
|
|
||||||
|
|
||||||
return OCISpecFilePath, nil
|
|
||||||
}
|
}
|
||||||
|
@ -17,10 +17,8 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,163 +28,3 @@ func TestConstructor(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, shim)
|
require.NotNil(t, shim)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetBundlePath(t *testing.T) {
|
|
||||||
type expected struct {
|
|
||||||
bundle string
|
|
||||||
isError bool
|
|
||||||
}
|
|
||||||
testCases := []struct {
|
|
||||||
argv []string
|
|
||||||
expected expected
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
argv: []string{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"create"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"--bundle"},
|
|
||||||
expected: expected{
|
|
||||||
isError: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"-b"},
|
|
||||||
expected: expected{
|
|
||||||
isError: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"--bundle", "/foo/bar"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "/foo/bar",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"--not-bundle", "/foo/bar"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"--"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"-bundle", "/foo/bar"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "/foo/bar",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"--bundle=/foo/bar"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "/foo/bar",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"-b=/foo/bar"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "/foo/bar",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"-b=/foo/=bar"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "/foo/=bar",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"-b", "/foo/bar"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "/foo/bar",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"create", "-b", "/foo/bar"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "/foo/bar",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"-b", "create", "create"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "create",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"-b=create", "create"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "create",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
argv: []string{"-b", "create"},
|
|
||||||
expected: expected{
|
|
||||||
bundle: "create",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tc := range testCases {
|
|
||||||
bundle, err := getBundlePath(tc.argv)
|
|
||||||
|
|
||||||
if tc.expected.isError {
|
|
||||||
require.Errorf(t, err, "%d: %v", i, tc)
|
|
||||||
} else {
|
|
||||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFindRunc(t *testing.T) {
|
|
||||||
testLogger, _ := testlog.NewNullLogger()
|
|
||||||
logger.Logger = testLogger
|
|
||||||
|
|
||||||
runcPath, err := findRunc()
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, filepath.Join(cfg.binPath, runcExecutableName), runcPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFindRuntime(t *testing.T) {
|
|
||||||
testLogger, _ := testlog.NewNullLogger()
|
|
||||||
logger.Logger = testLogger
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
candidates []string
|
|
||||||
expectedPath string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
candidates: []string{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
candidates: []string{"not-runc"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
candidates: []string{"not-runc", "also-not-runc"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
candidates: []string{runcExecutableName},
|
|
||||||
expectedPath: filepath.Join(cfg.binPath, runcExecutableName),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
candidates: []string{runcExecutableName, "not-runc"},
|
|
||||||
expectedPath: filepath.Join(cfg.binPath, runcExecutableName),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
candidates: []string{"not-runc", runcExecutableName},
|
|
||||||
expectedPath: filepath.Join(cfg.binPath, runcExecutableName),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tc := range testCases {
|
|
||||||
runcPath, err := findRuntime(tc.candidates)
|
|
||||||
if tc.expectedPath == "" {
|
|
||||||
require.Error(t, err, "%d: %v", i, tc)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err, "%d: %v", i, tc)
|
|
||||||
}
|
|
||||||
require.Equal(t, tc.expectedPath, runcPath, "%d: %v", i, tc)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user