/* # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. */ package main import ( "fmt" "strings" "testing" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/opencontainers/runtime-spec/specs-go" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" ) func TestAddNvidiaHook(t *testing.T) { logger, logHook := testlog.NewNullLogger() shim := nvidiaContainerRuntime{ logger: logger, } testCases := []struct { spec *specs.Spec errorPrefix string shouldNotAdd bool }{ { spec: &specs.Spec{}, }, { spec: &specs.Spec{ Hooks: &specs.Hooks{}, }, }, { spec: &specs.Spec{ Hooks: &specs.Hooks{ Prestart: []specs.Hook{{ Path: "some-hook", }}, }, }, }, { spec: &specs.Spec{ Hooks: &specs.Hooks{ Prestart: []specs.Hook{{ Path: "nvidia-container-runtime-hook", }}, }, }, shouldNotAdd: true, }, } for i, tc := range testCases { logHook.Reset() var numPrestartHooks int if tc.spec.Hooks != nil { numPrestartHooks = len(tc.spec.Hooks.Prestart) } err := shim.addNVIDIAHook(tc.spec) if tc.errorPrefix == "" { require.NoErrorf(t, err, "%d: %v", i, tc) } else { require.Truef(t, strings.HasPrefix(err.Error(), tc.errorPrefix), "%d: %v", i, tc) require.NotNilf(t, tc.spec.Hooks, "%d: %v", i, tc) require.Equalf(t, 1, nvidiaHookCount(tc.spec.Hooks), "%d: %v", i, tc) if tc.shouldNotAdd { require.Equal(t, numPrestartHooks+1, len(tc.spec.Hooks.Poststart), "%d: %v", i, tc) } else { require.Equal(t, numPrestartHooks+1, len(tc.spec.Hooks.Poststart), "%d: %v", i, tc) nvidiaHook := tc.spec.Hooks.Poststart[len(tc.spec.Hooks.Poststart)-1] // TODO: This assumes that the hook has been set up in the makefile expectedPath := "/usr/bin/nvidia-container-runtime-hook" require.Equalf(t, expectedPath, nvidiaHook.Path, "%d: %v", i, tc) require.Equalf(t, []string{expectedPath, "prestart"}, nvidiaHook.Args, "%d: %v", i, tc) require.Emptyf(t, nvidiaHook.Env, "%d: %v", i, tc) require.Nilf(t, nvidiaHook.Timeout, "%d: %v", i, tc) } } } } func TestNvidiaContainerRuntime(t *testing.T) { logger, hook := testlog.NewNullLogger() testCases := []struct { shim nvidiaContainerRuntime shouldModify bool args []string modifyError error writeError error }{ { shim: nvidiaContainerRuntime{}, shouldModify: false, }, { shim: nvidiaContainerRuntime{}, args: []string{"create"}, shouldModify: true, }, { shim: nvidiaContainerRuntime{}, args: []string{"--bundle=create"}, shouldModify: false, }, { shim: nvidiaContainerRuntime{}, args: []string{"--bundle", "create"}, shouldModify: false, }, { shim: nvidiaContainerRuntime{}, args: []string{"create"}, shouldModify: true, }, { shim: nvidiaContainerRuntime{}, args: []string{"create"}, modifyError: fmt.Errorf("error modifying"), shouldModify: true, }, { shim: nvidiaContainerRuntime{}, args: []string{"create"}, writeError: fmt.Errorf("error writing"), shouldModify: true, }, } for i, tc := range testCases { tc.shim.logger = logger hook.Reset() ociMock := &oci.SpecMock{ ModifyFunc: func(specModifier oci.SpecModifier) error { return tc.modifyError }, FlushFunc: func() error { return tc.writeError }, } require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc) tc.shim.ociSpec = ociMock tc.shim.runtime = &MockShim{} err := tc.shim.Exec(tc.args) if tc.modifyError != nil || tc.writeError != nil { require.Error(t, err, "%d: %v", i, tc) } else { require.NoError(t, err, "%d: %v", i, tc) } if tc.shouldModify { require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc) } else { require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc) } writeExpected := tc.shouldModify && tc.modifyError == nil if writeExpected { require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc) } else { require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc) } } } type MockShim struct { called bool args []string returnError error } func (m *MockShim) Exec(args []string) error { m.called = true m.args = args return m.returnError }