diff --git a/pkg/container_test.go b/pkg/container_test.go index a64749fe..98c40d2c 100644 --- a/pkg/container_test.go +++ b/pkg/container_test.go @@ -538,6 +538,86 @@ func TestGetDevicesFromMounts(t *testing.T) { } } +func TestDeviceListSourcePriority(t *testing.T) { + var tests = []struct { + description string + mountDevices []Mount + envvarDevices string + privileged bool + acceptUnprivileged bool + expectedDevices *string + expectedPanic bool + }{ + { + description: "Mount devices, unprivileged, no accept unprivileged", + mountDevices: []Mount{ + { + Source: "/dev/null", + Destination: filepath.Join(defaultDeviceListVolumeMount, "GPU0"), + }, + { + Source: "/dev/null", + Destination: filepath.Join(defaultDeviceListVolumeMount, "GPU1"), + }, + }, + envvarDevices: "GPU2,GPU3", + privileged: false, + acceptUnprivileged: false, + expectedDevices: &[]string{"GPU0,GPU1"}[0], + }, + { + description: "No mount devices, unprivileged, no accept unprivileged", + mountDevices: nil, + envvarDevices: "GPU0,GPU1", + privileged: false, + acceptUnprivileged: false, + expectedPanic: true, + }, + { + description: "No mount devices, privileged, no accept unprivileged", + mountDevices: nil, + envvarDevices: "GPU0,GPU1", + privileged: true, + acceptUnprivileged: false, + expectedDevices: &[]string{"GPU0,GPU1"}[0], + }, + { + description: "No mount devices, unprivileged, accept unprivileged", + mountDevices: nil, + envvarDevices: "GPU0,GPU1", + privileged: false, + acceptUnprivileged: true, + expectedDevices: &[]string{"GPU0,GPU1"}[0], + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + // Wrap the call to getDevices() in a closure. + var devices *string + getDevices := func() { + env := map[string]string{ + envNVVisibleDevices: tc.envvarDevices, + } + hookConfig := getDefaultHookConfig() + hookConfig.AcceptEnvvarUnprivileged = tc.acceptUnprivileged + devices = getDevices(&hookConfig, env, tc.mountDevices, tc.privileged, false) + } + + // For any tests that are expected to panic, make sure they do. + if tc.expectedPanic { + mustPanic(t, getDevices) + return + } + + // For all other tests, just grab the devices and check the results + getDevices() + if !reflect.DeepEqual(devices, tc.expectedDevices) { + t.Errorf("Unexpected devices (got: %v, wanted: %v)", *devices, *tc.expectedDevices) + } + }) + } +} + func elementsMatch(slice0, slice1 []string) bool { map0 := make(map[string]int) map1 := make(map[string]int)