mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Merge branch 'check-for-nil-modifier' into 'master'
Return unmodified runtime if specModifier is nil See merge request nvidia/container-toolkit/container-toolkit!127
This commit is contained in:
		
						commit
						0de7491ce3
					
				| @ -48,10 +48,6 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [ | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err) | 		return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err) | ||||||
| 	} | 	} | ||||||
| 	if specModifier == nil { |  | ||||||
| 		logger.Infof("Using low-level runtime with no modification") |  | ||||||
| 		return lowLevelRuntime, nil |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	// Create the wrapping runtime with the specified modifier
 | 	// Create the wrapping runtime with the specified modifier
 | ||||||
| 	r := runtime.NewModifyingRuntimeWrapper( | 	r := runtime.NewModifyingRuntimeWrapper( | ||||||
|  | |||||||
| @ -33,8 +33,13 @@ type modifyingRuntimeWrapper struct { | |||||||
| var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil) | var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil) | ||||||
| 
 | 
 | ||||||
| // NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
 | // NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
 | ||||||
| // before invoking the wrapped runtime.
 | // before invoking the wrapped runtime. If the modifier is nil, the input runtime is returned.
 | ||||||
| func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime { | func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime { | ||||||
|  | 	if modifier == nil { | ||||||
|  | 		logger.Infof("Using low-level runtime with no modification") | ||||||
|  | 		return runtime | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	rt := modifyingRuntimeWrapper{ | 	rt := modifyingRuntimeWrapper{ | ||||||
| 		logger:   logger, | 		logger:   logger, | ||||||
| 		runtime:  runtime, | 		runtime:  runtime, | ||||||
|  | |||||||
| @ -21,6 +21,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" | 	"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" | ||||||
|  | 	"github.com/opencontainers/runtime-spec/specs-go" | ||||||
| 	testlog "github.com/sirupsen/logrus/hooks/test" | 	testlog "github.com/sirupsen/logrus/hooks/test" | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
| ) | ) | ||||||
| @ -30,41 +31,60 @@ func TestExec(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	testCases := []struct { | 	testCases := []struct { | ||||||
| 		description   string | 		description   string | ||||||
|  | 		shouldLoad    bool | ||||||
| 		shouldModify  bool | 		shouldModify  bool | ||||||
| 		shouldFlush   bool | 		shouldFlush   bool | ||||||
| 		shouldForward bool | 		shouldForward bool | ||||||
| 		args          []string | 		args          []string | ||||||
| 		modifyError   error | 		modifyError   error | ||||||
| 		writeError    error | 		writeError    error | ||||||
|  | 		modifer       oci.SpecModifier | ||||||
| 	}{ | 	}{ | ||||||
| 		{ | 		{ | ||||||
| 			description:   "no args forwards", | 			description:   "no args forwards", | ||||||
|  | 			shouldLoad:    false, | ||||||
| 			shouldModify:  false, | 			shouldModify:  false, | ||||||
| 			shouldFlush:   false, | 			shouldFlush:   false, | ||||||
| 			shouldForward: true, | 			shouldForward: true, | ||||||
|  | 			modifer:       &modiferMock{}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			description:   "create modifies", | 			description:   "create modifies", | ||||||
| 			args:          []string{"create"}, | 			args:          []string{"create"}, | ||||||
|  | 			shouldLoad:    true, | ||||||
| 			shouldModify:  true, | 			shouldModify:  true, | ||||||
| 			shouldFlush:   true, | 			shouldFlush:   true, | ||||||
| 			shouldForward: true, | 			shouldForward: true, | ||||||
|  | 			modifer:       &modiferMock{}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			description:   "modify error does not write or forward", | 			description:   "modify error does not write or forward", | ||||||
| 			args:          []string{"create"}, | 			args:          []string{"create"}, | ||||||
| 			modifyError:   fmt.Errorf("error modifying"), | 			modifyError:   fmt.Errorf("error modifying"), | ||||||
|  | 			shouldLoad:    true, | ||||||
| 			shouldModify:  true, | 			shouldModify:  true, | ||||||
| 			shouldFlush:   false, | 			shouldFlush:   false, | ||||||
| 			shouldForward: false, | 			shouldForward: false, | ||||||
|  | 			modifer:       &modiferMock{}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			description:   "write error does not forward", | 			description:   "write error does not forward", | ||||||
| 			args:          []string{"create"}, | 			args:          []string{"create"}, | ||||||
| 			writeError:    fmt.Errorf("error writing"), | 			writeError:    fmt.Errorf("error writing"), | ||||||
|  | 			shouldLoad:    true, | ||||||
| 			shouldModify:  true, | 			shouldModify:  true, | ||||||
| 			shouldFlush:   true, | 			shouldFlush:   true, | ||||||
| 			shouldForward: false, | 			shouldForward: false, | ||||||
|  | 			modifer:       &modiferMock{}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description:   "nil modifier forwards on create", | ||||||
|  | 			args:          []string{"create"}, | ||||||
|  | 			shouldLoad:    false, | ||||||
|  | 			shouldModify:  false, | ||||||
|  | 			shouldFlush:   false, | ||||||
|  | 			shouldForward: true, | ||||||
|  | 			modifer:       nil, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -87,7 +107,8 @@ func TestExec(t *testing.T) { | |||||||
| 				runtimeMock, | 				runtimeMock, | ||||||
| 				specMock, | 				specMock, | ||||||
| 				// TODO: We should test the interactions with the SpecModifier too
 | 				// TODO: We should test the interactions with the SpecModifier too
 | ||||||
| 				nil) | 				tc.modifer, | ||||||
|  | 			) | ||||||
| 
 | 
 | ||||||
| 			err := shim.Exec(tc.args) | 			err := shim.Exec(tc.args) | ||||||
| 			if tc.modifyError != nil || tc.writeError != nil { | 			if tc.modifyError != nil || tc.writeError != nil { | ||||||
| @ -96,6 +117,11 @@ func TestExec(t *testing.T) { | |||||||
| 				require.NoError(t, err) | 				require.NoError(t, err) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  | 			if tc.shouldLoad { | ||||||
|  | 				require.Equal(t, 1, len(specMock.LoadCalls())) | ||||||
|  | 			} else { | ||||||
|  | 				require.Equal(t, 0, len(specMock.LoadCalls())) | ||||||
|  | 			} | ||||||
| 			if tc.shouldModify { | 			if tc.shouldModify { | ||||||
| 				require.Equal(t, 1, len(specMock.ModifyCalls())) | 				require.Equal(t, 1, len(specMock.ModifyCalls())) | ||||||
| 			} else { | 			} else { | ||||||
| @ -114,3 +140,25 @@ func TestExec(t *testing.T) { | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestNilModiferReturnsRuntime(t *testing.T) { | ||||||
|  | 	logger, _ := testlog.NewNullLogger() | ||||||
|  | 
 | ||||||
|  | 	runtimeMock := &oci.RuntimeMock{} | ||||||
|  | 	specMock := &oci.SpecMock{} | ||||||
|  | 
 | ||||||
|  | 	shim := NewModifyingRuntimeWrapper( | ||||||
|  | 		logger, | ||||||
|  | 		runtimeMock, | ||||||
|  | 		specMock, | ||||||
|  | 		nil, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	require.Equal(t, runtimeMock, shim) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type modiferMock struct{} | ||||||
|  | 
 | ||||||
|  | func (m modiferMock) Modify(*specs.Spec) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user