mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Return unmodified runtime if specModifier is nil
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
		
							parent
							
								
									99baea9d51
								
							
						
					
					
						commit
						67602b28f9
					
				| @ -48,10 +48,6 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [ | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err) | ||||
| 	} | ||||
| 	if specModifier == nil { | ||||
| 		logger.Infof("Using low-level runtime with no modification") | ||||
| 		return lowLevelRuntime, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// Create the wrapping runtime with the specified modifier
 | ||||
| 	r := runtime.NewModifyingRuntimeWrapper( | ||||
|  | ||||
| @ -33,8 +33,13 @@ type modifyingRuntimeWrapper struct { | ||||
| var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil) | ||||
| 
 | ||||
| // NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
 | ||||
| // before invoking the wrapped runtime.
 | ||||
| // before invoking the wrapped runtime. If the modifier is nil, the input runtime is returned.
 | ||||
| func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime { | ||||
| 	if modifier == nil { | ||||
| 		logger.Infof("Using low-level runtime with no modification") | ||||
| 		return runtime | ||||
| 	} | ||||
| 
 | ||||
| 	rt := modifyingRuntimeWrapper{ | ||||
| 		logger:   logger, | ||||
| 		runtime:  runtime, | ||||
|  | ||||
| @ -21,6 +21,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" | ||||
| 	"github.com/opencontainers/runtime-spec/specs-go" | ||||
| 	testlog "github.com/sirupsen/logrus/hooks/test" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| @ -30,41 +31,60 @@ func TestExec(t *testing.T) { | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		description   string | ||||
| 		shouldLoad    bool | ||||
| 		shouldModify  bool | ||||
| 		shouldFlush   bool | ||||
| 		shouldForward bool | ||||
| 		args          []string | ||||
| 		modifyError   error | ||||
| 		writeError    error | ||||
| 		modifer       oci.SpecModifier | ||||
| 	}{ | ||||
| 		{ | ||||
| 			description:   "no args forwards", | ||||
| 			shouldLoad:    false, | ||||
| 			shouldModify:  false, | ||||
| 			shouldFlush:   false, | ||||
| 			shouldForward: true, | ||||
| 			modifer:       &modiferMock{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description:   "create modifies", | ||||
| 			args:          []string{"create"}, | ||||
| 			shouldLoad:    true, | ||||
| 			shouldModify:  true, | ||||
| 			shouldFlush:   true, | ||||
| 			shouldForward: true, | ||||
| 			modifer:       &modiferMock{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description:   "modify error does not write or forward", | ||||
| 			args:          []string{"create"}, | ||||
| 			modifyError:   fmt.Errorf("error modifying"), | ||||
| 			shouldLoad:    true, | ||||
| 			shouldModify:  true, | ||||
| 			shouldFlush:   false, | ||||
| 			shouldForward: false, | ||||
| 			modifer:       &modiferMock{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description:   "write error does not forward", | ||||
| 			args:          []string{"create"}, | ||||
| 			writeError:    fmt.Errorf("error writing"), | ||||
| 			shouldLoad:    true, | ||||
| 			shouldModify:  true, | ||||
| 			shouldFlush:   true, | ||||
| 			shouldForward: false, | ||||
| 			modifer:       &modiferMock{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description:   "nil modifier forwards on create", | ||||
| 			args:          []string{"create"}, | ||||
| 			shouldLoad:    false, | ||||
| 			shouldModify:  false, | ||||
| 			shouldFlush:   false, | ||||
| 			shouldForward: true, | ||||
| 			modifer:       nil, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| @ -87,7 +107,8 @@ func TestExec(t *testing.T) { | ||||
| 				runtimeMock, | ||||
| 				specMock, | ||||
| 				// TODO: We should test the interactions with the SpecModifier too
 | ||||
| 				nil) | ||||
| 				tc.modifer, | ||||
| 			) | ||||
| 
 | ||||
| 			err := shim.Exec(tc.args) | ||||
| 			if tc.modifyError != nil || tc.writeError != nil { | ||||
| @ -96,6 +117,11 @@ func TestExec(t *testing.T) { | ||||
| 				require.NoError(t, err) | ||||
| 			} | ||||
| 
 | ||||
| 			if tc.shouldLoad { | ||||
| 				require.Equal(t, 1, len(specMock.LoadCalls())) | ||||
| 			} else { | ||||
| 				require.Equal(t, 0, len(specMock.LoadCalls())) | ||||
| 			} | ||||
| 			if tc.shouldModify { | ||||
| 				require.Equal(t, 1, len(specMock.ModifyCalls())) | ||||
| 			} else { | ||||
| @ -114,3 +140,25 @@ func TestExec(t *testing.T) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNilModiferReturnsRuntime(t *testing.T) { | ||||
| 	logger, _ := testlog.NewNullLogger() | ||||
| 
 | ||||
| 	runtimeMock := &oci.RuntimeMock{} | ||||
| 	specMock := &oci.SpecMock{} | ||||
| 
 | ||||
| 	shim := NewModifyingRuntimeWrapper( | ||||
| 		logger, | ||||
| 		runtimeMock, | ||||
| 		specMock, | ||||
| 		nil, | ||||
| 	) | ||||
| 
 | ||||
| 	require.Equal(t, runtimeMock, shim) | ||||
| } | ||||
| 
 | ||||
| type modiferMock struct{} | ||||
| 
 | ||||
| func (m modiferMock) Modify(*specs.Spec) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user