From e188090a92bf47ebff4ccbcf1b04adf8c46a244d Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 14 Nov 2024 13:26:27 -0700 Subject: [PATCH] Fix NVIDIA_IMEX_CHANNELS handling on legacy images For legacy images (images with a CUDA_VERSION set but no CUDA_REQUIRES set), the default behaviour for device envvars is to treat non-existence as all. This change ensures that the NVIDIA_IMEX_CHANNELS envvar is not treated in the same way, instead returning no devices if the envvar is not set. Signed-off-by: Evan Lezar --- internal/config/image/cuda_image.go | 6 ++++- internal/config/image/cuda_image_test.go | 31 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index 4298e634..d5bbc224 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -292,7 +292,11 @@ func (i CUDA) CDIDevicesFromMounts() []string { // ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image. func (i CUDA) ImexChannelsFromEnvVar() []string { - return i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List() + imexChannels := i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List() + if len(imexChannels) == 1 && imexChannels[0] == "all" { + return nil + } + return imexChannels } // ImexChannelsFromMounts returns the list of IMEX channels requested for the image. diff --git a/internal/config/image/cuda_image_test.go b/internal/config/image/cuda_image_test.go index 39bf30b4..3b77d333 100644 --- a/internal/config/image/cuda_image_test.go +++ b/internal/config/image/cuda_image_test.go @@ -203,6 +203,37 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) { } } +func TestImexChannelsFromEnvVar(t *testing.T) { + testCases := []struct { + description string + env []string + expected []string + }{ + { + description: "no imex channels specified", + }, + { + description: "imex channel specified", + env: []string{ + "NVIDIA_IMEX_CHANNELS=3,4", + }, + expected: []string{"3", "4"}, + }, + } + + for _, tc := range testCases { + for id, baseEnvvars := range map[string][]string{"": nil, "legacy": {"CUDA_VERSION=1.2.3"}} { + t.Run(tc.description+id, func(t *testing.T) { + i, err := NewCUDAImageFromEnv(append(baseEnvvars, tc.env...)) + require.NoError(t, err) + + channels := i.ImexChannelsFromEnvVar() + require.EqualValues(t, tc.expected, channels) + }) + } + } +} + func makeTestMounts(paths ...string) []specs.Mount { var mounts []specs.Mount for _, path := range paths {