diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index 4298e634..d5bbc224 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -292,7 +292,11 @@ func (i CUDA) CDIDevicesFromMounts() []string { // ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image. func (i CUDA) ImexChannelsFromEnvVar() []string { - return i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List() + imexChannels := i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List() + if len(imexChannels) == 1 && imexChannels[0] == "all" { + return nil + } + return imexChannels } // ImexChannelsFromMounts returns the list of IMEX channels requested for the image. diff --git a/internal/config/image/cuda_image_test.go b/internal/config/image/cuda_image_test.go index 39bf30b4..3b77d333 100644 --- a/internal/config/image/cuda_image_test.go +++ b/internal/config/image/cuda_image_test.go @@ -203,6 +203,37 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) { } } +func TestImexChannelsFromEnvVar(t *testing.T) { + testCases := []struct { + description string + env []string + expected []string + }{ + { + description: "no imex channels specified", + }, + { + description: "imex channel specified", + env: []string{ + "NVIDIA_IMEX_CHANNELS=3,4", + }, + expected: []string{"3", "4"}, + }, + } + + for _, tc := range testCases { + for id, baseEnvvars := range map[string][]string{"": nil, "legacy": {"CUDA_VERSION=1.2.3"}} { + t.Run(tc.description+id, func(t *testing.T) { + i, err := NewCUDAImageFromEnv(append(baseEnvvars, tc.env...)) + require.NoError(t, err) + + channels := i.ImexChannelsFromEnvVar() + require.EqualValues(t, tc.expected, channels) + }) + } + } +} + func makeTestMounts(paths ...string) []specs.Mount { var mounts []specs.Mount for _, path := range paths {