diff --git a/cmd/nvidia-container-runtime/runtime_factory.go b/cmd/nvidia-container-runtime/runtime_factory.go index 20025e45..43d03769 100644 --- a/cmd/nvidia-container-runtime/runtime_factory.go +++ b/cmd/nvidia-container-runtime/runtime_factory.go @@ -48,10 +48,6 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [ if err != nil { return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err) } - if specModifier == nil { - logger.Infof("Using low-level runtime with no modification") - return lowLevelRuntime, nil - } // Create the wrapping runtime with the specified modifier r := runtime.NewModifyingRuntimeWrapper( diff --git a/internal/runtime/runtime_modifier.go b/internal/runtime/runtime_modifier.go index 86accba0..f385d9c3 100644 --- a/internal/runtime/runtime_modifier.go +++ b/internal/runtime/runtime_modifier.go @@ -33,8 +33,13 @@ type modifyingRuntimeWrapper struct { var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil) // NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification -// before invoking the wrapped runtime. +// before invoking the wrapped runtime. If the modifier is nil, the input runtime is returned. func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime { + if modifier == nil { + logger.Infof("Using low-level runtime with no modification") + return runtime + } + rt := modifyingRuntimeWrapper{ logger: logger, runtime: runtime, diff --git a/internal/runtime/runtime_modifier_test.go b/internal/runtime/runtime_modifier_test.go index 5ab77a04..bd07bf24 100644 --- a/internal/runtime/runtime_modifier_test.go +++ b/internal/runtime/runtime_modifier_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/opencontainers/runtime-spec/specs-go" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" ) @@ -30,41 +31,60 @@ func TestExec(t *testing.T) { testCases := []struct { description string + shouldLoad bool shouldModify bool shouldFlush bool shouldForward bool args []string modifyError error writeError error + modifer oci.SpecModifier }{ { description: "no args forwards", + shouldLoad: false, shouldModify: false, shouldFlush: false, shouldForward: true, + modifer: &modiferMock{}, }, { description: "create modifies", args: []string{"create"}, + shouldLoad: true, shouldModify: true, shouldFlush: true, shouldForward: true, + modifer: &modiferMock{}, }, { description: "modify error does not write or forward", args: []string{"create"}, modifyError: fmt.Errorf("error modifying"), + shouldLoad: true, shouldModify: true, shouldFlush: false, shouldForward: false, + modifer: &modiferMock{}, }, { description: "write error does not forward", args: []string{"create"}, writeError: fmt.Errorf("error writing"), + shouldLoad: true, shouldModify: true, shouldFlush: true, shouldForward: false, + modifer: &modiferMock{}, + }, + { + description: "nil modifier forwards on create", + args: []string{"create"}, + shouldLoad: false, + shouldModify: false, + shouldFlush: false, + shouldForward: true, + modifer: nil, }, } @@ -87,7 +107,8 @@ func TestExec(t *testing.T) { runtimeMock, specMock, // TODO: We should test the interactions with the SpecModifier too - nil) + tc.modifer, + ) err := shim.Exec(tc.args) if tc.modifyError != nil || tc.writeError != nil { @@ -96,6 +117,11 @@ func TestExec(t *testing.T) { require.NoError(t, err) } + if tc.shouldLoad { + require.Equal(t, 1, len(specMock.LoadCalls())) + } else { + require.Equal(t, 0, len(specMock.LoadCalls())) + } if tc.shouldModify { require.Equal(t, 1, len(specMock.ModifyCalls())) } else { @@ -114,3 +140,25 @@ func TestExec(t *testing.T) { }) } } + +func TestNilModiferReturnsRuntime(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + runtimeMock := &oci.RuntimeMock{} + specMock := &oci.SpecMock{} + + shim := NewModifyingRuntimeWrapper( + logger, + runtimeMock, + specMock, + nil, + ) + + require.Equal(t, runtimeMock, shim) +} + +type modiferMock struct{} + +func (m modiferMock) Modify(*specs.Spec) error { + return nil +}