Return unmodified runtime if specModifier is nil

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2022-04-08 07:38:19 +02:00
parent 99baea9d51
commit 67602b28f9
3 changed files with 55 additions and 6 deletions

View File

@ -48,10 +48,6 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err) return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
} }
if specModifier == nil {
logger.Infof("Using low-level runtime with no modification")
return lowLevelRuntime, nil
}
// Create the wrapping runtime with the specified modifier // Create the wrapping runtime with the specified modifier
r := runtime.NewModifyingRuntimeWrapper( r := runtime.NewModifyingRuntimeWrapper(

View File

@ -33,8 +33,13 @@ type modifyingRuntimeWrapper struct {
var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil) var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil)
// NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification // NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
// before invoking the wrapped runtime. // before invoking the wrapped runtime. If the modifier is nil, the input runtime is returned.
func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime { func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime {
if modifier == nil {
logger.Infof("Using low-level runtime with no modification")
return runtime
}
rt := modifyingRuntimeWrapper{ rt := modifyingRuntimeWrapper{
logger: logger, logger: logger,
runtime: runtime, runtime: runtime,

View File

@ -21,6 +21,7 @@ import (
"testing" "testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/opencontainers/runtime-spec/specs-go"
testlog "github.com/sirupsen/logrus/hooks/test" testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -30,41 +31,60 @@ func TestExec(t *testing.T) {
testCases := []struct { testCases := []struct {
description string description string
shouldLoad bool
shouldModify bool shouldModify bool
shouldFlush bool shouldFlush bool
shouldForward bool shouldForward bool
args []string args []string
modifyError error modifyError error
writeError error writeError error
modifer oci.SpecModifier
}{ }{
{ {
description: "no args forwards", description: "no args forwards",
shouldLoad: false,
shouldModify: false, shouldModify: false,
shouldFlush: false, shouldFlush: false,
shouldForward: true, shouldForward: true,
modifer: &modiferMock{},
}, },
{ {
description: "create modifies", description: "create modifies",
args: []string{"create"}, args: []string{"create"},
shouldLoad: true,
shouldModify: true, shouldModify: true,
shouldFlush: true, shouldFlush: true,
shouldForward: true, shouldForward: true,
modifer: &modiferMock{},
}, },
{ {
description: "modify error does not write or forward", description: "modify error does not write or forward",
args: []string{"create"}, args: []string{"create"},
modifyError: fmt.Errorf("error modifying"), modifyError: fmt.Errorf("error modifying"),
shouldLoad: true,
shouldModify: true, shouldModify: true,
shouldFlush: false, shouldFlush: false,
shouldForward: false, shouldForward: false,
modifer: &modiferMock{},
}, },
{ {
description: "write error does not forward", description: "write error does not forward",
args: []string{"create"}, args: []string{"create"},
writeError: fmt.Errorf("error writing"), writeError: fmt.Errorf("error writing"),
shouldLoad: true,
shouldModify: true, shouldModify: true,
shouldFlush: true, shouldFlush: true,
shouldForward: false, shouldForward: false,
modifer: &modiferMock{},
},
{
description: "nil modifier forwards on create",
args: []string{"create"},
shouldLoad: false,
shouldModify: false,
shouldFlush: false,
shouldForward: true,
modifer: nil,
}, },
} }
@ -87,7 +107,8 @@ func TestExec(t *testing.T) {
runtimeMock, runtimeMock,
specMock, specMock,
// TODO: We should test the interactions with the SpecModifier too // TODO: We should test the interactions with the SpecModifier too
nil) tc.modifer,
)
err := shim.Exec(tc.args) err := shim.Exec(tc.args)
if tc.modifyError != nil || tc.writeError != nil { if tc.modifyError != nil || tc.writeError != nil {
@ -96,6 +117,11 @@ func TestExec(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
if tc.shouldLoad {
require.Equal(t, 1, len(specMock.LoadCalls()))
} else {
require.Equal(t, 0, len(specMock.LoadCalls()))
}
if tc.shouldModify { if tc.shouldModify {
require.Equal(t, 1, len(specMock.ModifyCalls())) require.Equal(t, 1, len(specMock.ModifyCalls()))
} else { } else {
@ -114,3 +140,25 @@ func TestExec(t *testing.T) {
}) })
} }
} }
func TestNilModiferReturnsRuntime(t *testing.T) {
logger, _ := testlog.NewNullLogger()
runtimeMock := &oci.RuntimeMock{}
specMock := &oci.SpecMock{}
shim := NewModifyingRuntimeWrapper(
logger,
runtimeMock,
specMock,
nil,
)
require.Equal(t, runtimeMock, shim)
}
type modiferMock struct{}
func (m modiferMock) Modify(*specs.Spec) error {
return nil
}