nvidia-container-toolkit/pkg/runtime/runtime_modifier_test.go

161 lines
3.4 KiB
Go
Raw Normal View History

/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"fmt"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/modify"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestRuntimeModifier checks that the modifying runtime wrapper loads,
// modifies, and flushes the OCI spec only for argument lists that should be
// modified (here, a "create" invocation). It also checks that the wrapped
// runtime is executed exactly once in every case.
func TestRuntimeModifier(t *testing.T) {
	// The hook returned alongside the null logger is not needed by this test.
	logger, _ := testlog.NewNullLogger()

	testCases := []struct {
		description  string
		args         []string
		shouldModify bool
	}{
		{
			description: "no args does not modify spec",
		},
		{
			description:  "create modifies spec",
			args:         []string{"create"},
			shouldModify: true,
		},
	}

	for _, tc := range testCases {
		// Run each case as a named subtest so a failure identifies the case.
		t.Run(tc.description, func(t *testing.T) {
			// Fresh mocks per case so recorded call counts do not carry over.
			runtimeMock := &oci.RuntimeMock{}
			specMock := &oci.SpecMock{}
			modifierMock := &modify.ModifierMock{}

			r := NewModifyingRuntimeWrapperWithLogger(
				logger,
				runtimeMock,
				specMock,
				modifierMock,
			)

			err := r.Exec(tc.args)
			require.NoError(t, err)

			// Load, Modify, and Flush are expected together: all once or none.
			expectedCalls := 0
			if tc.shouldModify {
				expectedCalls = 1
			}

			require.Len(t, specMock.LoadCalls(), expectedCalls)
			require.Len(t, modifierMock.ModifyCalls(), expectedCalls)
			require.Len(t, specMock.FlushCalls(), expectedCalls)
			// The wrapped runtime is always invoked, modified or not.
			require.Len(t, runtimeMock.ExecCalls(), 1)
		})
	}
}
// TestRuntimeModiferWithLoadError asserts the following for a "create"
// invocation whose spec Load fails:
//   - Exec returns an error.
//   - Load is attempted once.
//   - The spec is neither modified nor flushed.
//   - The wrapped runtime is never executed.
func TestRuntimeModiferWithLoadError(t *testing.T) {
	// The hook returned alongside the null logger is not needed by this test.
	nullLogger, _ := testlog.NewNullLogger()

	wrapped := &oci.RuntimeMock{}
	spec := &oci.SpecMock{LoadFunc: specErrorFunc}
	modifier := &modify.ModifierMock{}

	wrapper := NewModifyingRuntimeWrapperWithLogger(nullLogger, wrapped, spec, modifier)

	execErr := wrapper.Exec([]string{"create"})
	require.Error(t, execErr)

	require.Len(t, spec.LoadCalls(), 1)
	require.Len(t, modifier.ModifyCalls(), 0)
	require.Len(t, spec.FlushCalls(), 0)
	require.Len(t, wrapped.ExecCalls(), 0)
}
// TestRuntimeModiferWithFlushError asserts the following for a "create"
// invocation whose spec Flush fails:
//   - Exec returns an error.
//   - Load, Modify, and Flush are each called once.
//   - The wrapped runtime is never executed.
func TestRuntimeModiferWithFlushError(t *testing.T) {
	// The hook returned alongside the null logger is not needed by this test.
	nullLogger, _ := testlog.NewNullLogger()

	wrapped := &oci.RuntimeMock{}
	spec := &oci.SpecMock{FlushFunc: specErrorFunc}
	modifier := &modify.ModifierMock{}

	wrapper := NewModifyingRuntimeWrapperWithLogger(nullLogger, wrapped, spec, modifier)

	execErr := wrapper.Exec([]string{"create"})
	require.Error(t, execErr)

	require.Len(t, spec.LoadCalls(), 1)
	require.Len(t, modifier.ModifyCalls(), 1)
	require.Len(t, spec.FlushCalls(), 1)
	require.Len(t, wrapped.ExecCalls(), 0)
}
// TestRuntimeModiferWithModifyError asserts the following for a "create"
// invocation whose spec modification fails:
//   - Exec returns an error.
//   - Load and Modify are each called once.
//   - The spec is not flushed.
//   - The wrapped runtime is never executed.
func TestRuntimeModiferWithModifyError(t *testing.T) {
	// The hook returned alongside the null logger is not needed by this test.
	nullLogger, _ := testlog.NewNullLogger()

	wrapped := &oci.RuntimeMock{}
	spec := &oci.SpecMock{}
	modifier := &modify.ModifierMock{ModifyFunc: modifierErrorFunc}

	wrapper := NewModifyingRuntimeWrapperWithLogger(nullLogger, wrapped, spec, modifier)

	execErr := wrapper.Exec([]string{"create"})
	require.Error(t, execErr)

	require.Len(t, spec.LoadCalls(), 1)
	require.Len(t, modifier.ModifyCalls(), 1)
	require.Len(t, spec.FlushCalls(), 0)
	require.Len(t, wrapped.ExecCalls(), 0)
}
// specErrorFunc is plugged into SpecMock as LoadFunc or FlushFunc to simulate
// a failing spec operation; it always returns an error with message "error".
func specErrorFunc() error {
	failure := fmt.Errorf("error")
	return failure
}
// modifierErrorFunc is plugged into ModifierMock as ModifyFunc to simulate a
// failing modification. It ignores the spec it receives and always returns an
// error with message "error".
func modifierErrorFunc(_ oci.Spec) error {
	failure := fmt.Errorf("error")
	return failure
}