mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-25 13:35:00 +00:00
Return unmodified runtime if specModifier is nil
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
99baea9d51
commit
67602b28f9
@ -48,10 +48,6 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
|
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
|
||||||
}
|
}
|
||||||
if specModifier == nil {
|
|
||||||
logger.Infof("Using low-level runtime with no modification")
|
|
||||||
return lowLevelRuntime, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the wrapping runtime with the specified modifier
|
// Create the wrapping runtime with the specified modifier
|
||||||
r := runtime.NewModifyingRuntimeWrapper(
|
r := runtime.NewModifyingRuntimeWrapper(
|
||||||
|
@ -33,8 +33,13 @@ type modifyingRuntimeWrapper struct {
|
|||||||
var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil)
|
var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil)
|
||||||
|
|
||||||
// NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
|
// NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
|
||||||
// before invoking the wrapped runtime.
|
// before invoking the wrapped runtime. If the modifier is nil, the input runtime is returned.
|
||||||
func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime {
|
func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime {
|
||||||
|
if modifier == nil {
|
||||||
|
logger.Infof("Using low-level runtime with no modification")
|
||||||
|
return runtime
|
||||||
|
}
|
||||||
|
|
||||||
rt := modifyingRuntimeWrapper{
|
rt := modifyingRuntimeWrapper{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
runtime: runtime,
|
runtime: runtime,
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||||
|
"github.com/opencontainers/runtime-spec/specs-go"
|
||||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@ -30,41 +31,60 @@ func TestExec(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
description string
|
description string
|
||||||
|
shouldLoad bool
|
||||||
shouldModify bool
|
shouldModify bool
|
||||||
shouldFlush bool
|
shouldFlush bool
|
||||||
shouldForward bool
|
shouldForward bool
|
||||||
args []string
|
args []string
|
||||||
modifyError error
|
modifyError error
|
||||||
writeError error
|
writeError error
|
||||||
|
modifer oci.SpecModifier
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
description: "no args forwards",
|
description: "no args forwards",
|
||||||
|
shouldLoad: false,
|
||||||
shouldModify: false,
|
shouldModify: false,
|
||||||
shouldFlush: false,
|
shouldFlush: false,
|
||||||
shouldForward: true,
|
shouldForward: true,
|
||||||
|
modifer: &modiferMock{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "create modifies",
|
description: "create modifies",
|
||||||
args: []string{"create"},
|
args: []string{"create"},
|
||||||
|
shouldLoad: true,
|
||||||
shouldModify: true,
|
shouldModify: true,
|
||||||
shouldFlush: true,
|
shouldFlush: true,
|
||||||
shouldForward: true,
|
shouldForward: true,
|
||||||
|
modifer: &modiferMock{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "modify error does not write or forward",
|
description: "modify error does not write or forward",
|
||||||
args: []string{"create"},
|
args: []string{"create"},
|
||||||
modifyError: fmt.Errorf("error modifying"),
|
modifyError: fmt.Errorf("error modifying"),
|
||||||
|
shouldLoad: true,
|
||||||
shouldModify: true,
|
shouldModify: true,
|
||||||
shouldFlush: false,
|
shouldFlush: false,
|
||||||
shouldForward: false,
|
shouldForward: false,
|
||||||
|
modifer: &modiferMock{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "write error does not forward",
|
description: "write error does not forward",
|
||||||
args: []string{"create"},
|
args: []string{"create"},
|
||||||
writeError: fmt.Errorf("error writing"),
|
writeError: fmt.Errorf("error writing"),
|
||||||
|
shouldLoad: true,
|
||||||
shouldModify: true,
|
shouldModify: true,
|
||||||
shouldFlush: true,
|
shouldFlush: true,
|
||||||
shouldForward: false,
|
shouldForward: false,
|
||||||
|
modifer: &modiferMock{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "nil modifier forwards on create",
|
||||||
|
args: []string{"create"},
|
||||||
|
shouldLoad: false,
|
||||||
|
shouldModify: false,
|
||||||
|
shouldFlush: false,
|
||||||
|
shouldForward: true,
|
||||||
|
modifer: nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,7 +107,8 @@ func TestExec(t *testing.T) {
|
|||||||
runtimeMock,
|
runtimeMock,
|
||||||
specMock,
|
specMock,
|
||||||
// TODO: We should test the interactions with the SpecModifier too
|
// TODO: We should test the interactions with the SpecModifier too
|
||||||
nil)
|
tc.modifer,
|
||||||
|
)
|
||||||
|
|
||||||
err := shim.Exec(tc.args)
|
err := shim.Exec(tc.args)
|
||||||
if tc.modifyError != nil || tc.writeError != nil {
|
if tc.modifyError != nil || tc.writeError != nil {
|
||||||
@ -96,6 +117,11 @@ func TestExec(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tc.shouldLoad {
|
||||||
|
require.Equal(t, 1, len(specMock.LoadCalls()))
|
||||||
|
} else {
|
||||||
|
require.Equal(t, 0, len(specMock.LoadCalls()))
|
||||||
|
}
|
||||||
if tc.shouldModify {
|
if tc.shouldModify {
|
||||||
require.Equal(t, 1, len(specMock.ModifyCalls()))
|
require.Equal(t, 1, len(specMock.ModifyCalls()))
|
||||||
} else {
|
} else {
|
||||||
@ -114,3 +140,25 @@ func TestExec(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNilModiferReturnsRuntime(t *testing.T) {
|
||||||
|
logger, _ := testlog.NewNullLogger()
|
||||||
|
|
||||||
|
runtimeMock := &oci.RuntimeMock{}
|
||||||
|
specMock := &oci.SpecMock{}
|
||||||
|
|
||||||
|
shim := NewModifyingRuntimeWrapper(
|
||||||
|
logger,
|
||||||
|
runtimeMock,
|
||||||
|
specMock,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Equal(t, runtimeMock, shim)
|
||||||
|
}
|
||||||
|
|
||||||
|
type modiferMock struct{}
|
||||||
|
|
||||||
|
func (m modiferMock) Modify(*specs.Spec) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user