diff --git a/pkg/container_test.go b/pkg/container_test.go index c4748b6b..a64749fe 100644 --- a/pkg/container_test.go +++ b/pkg/container_test.go @@ -1,6 +1,7 @@ package main import ( + "path/filepath" "reflect" "testing" ) @@ -450,6 +451,93 @@ func TestGetNvidiaConfig(t *testing.T) { } } +func TestGetDevicesFromMounts(t *testing.T) { + var tests = []struct { + description string + root string + mounts []Mount + expectedDevices *string + }{ + { + description: "No mounts", + root: defaultDeviceListVolumeMount, + mounts: nil, + expectedDevices: nil, + }, + { + description: "Host path is not /dev/null", + root: defaultDeviceListVolumeMount, + mounts: []Mount{ + { + Source: "/not/dev/null", + Destination: filepath.Join(defaultDeviceListVolumeMount, "GPU0"), + }, + }, + expectedDevices: nil, + }, + { + description: "Container path is not prefixed by 'root'", + root: defaultDeviceListVolumeMount, + mounts: []Mount{ + { + Source: "/dev/null", + Destination: filepath.Join("/other/prefix", "GPU0"), + }, + }, + expectedDevices: nil, + }, + { + description: "Container path is only 'root'", + root: defaultDeviceListVolumeMount, + mounts: []Mount{ + { + Source: "/dev/null", + Destination: defaultDeviceListVolumeMount, + }, + }, + expectedDevices: nil, + }, + { + description: "Discover 2 devices", + root: defaultDeviceListVolumeMount, + mounts: []Mount{ + { + Source: "/dev/null", + Destination: filepath.Join(defaultDeviceListVolumeMount, "GPU0"), + }, + { + Source: "/dev/null", + Destination: filepath.Join(defaultDeviceListVolumeMount, "GPU1"), + }, + }, + expectedDevices: &[]string{"GPU0,GPU1"}[0], + }, + { + description: "Discover 2 devices with slashes in the name", + root: defaultDeviceListVolumeMount, + mounts: []Mount{ + { + Source: "/dev/null", + Destination: filepath.Join(defaultDeviceListVolumeMount, "GPU0-MIG0/0/1"), + }, + { + Source: "/dev/null", + Destination: filepath.Join(defaultDeviceListVolumeMount, "GPU1-MIG0/0/1"), + }, + }, + expectedDevices: &[]string{"GPU0-MIG0/0/1,GPU1-MIG0/0/1"}[0], + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + devices := getDevicesFromMounts(tc.root, tc.mounts) + if !reflect.DeepEqual(devices, tc.expectedDevices) { + t.Errorf("Unexpected devices (got: %v, wanted: %v)", *devices, *tc.expectedDevices) + } + }) + } +} + func elementsMatch(slice0, slice1 []string) bool { map0 := make(map[string]int) map1 := make(map[string]int)