diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index e982c695..ac34fb93 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -21,7 +21,7 @@ type nvidiaConfig struct { Devices []string MigConfigDevices string MigMonitorDevices string - ImexChannels string + ImexChannels []string DriverCapabilities string // Requirements defines the requirements DSL for the container to run. // This is empty if no specific requirements are needed, or if requirements are @@ -197,12 +197,24 @@ func getMigDevices(image image.CUDA, envvar string) *string { return &devices } -func getImexChannels(i image.CUDA) *string { - if !i.HasEnvvar(image.EnvVarNvidiaImexChannels) { +func getImexChannels(hookConfig *HookConfig, image image.CUDA, privileged bool) []string { + // If enabled, try and get the device list from volume mounts first + if hookConfig.AcceptDeviceListAsVolumeMounts { + devices := image.ImexChannelsFromMounts() + if len(devices) > 0 { + return devices + } + } + devices := image.ImexChannelsFromEnvVar() + if len(devices) == 0 { return nil } - chans := i.Getenv(image.EnvVarNvidiaImexChannels) - return &chans + + if privileged || hookConfig.AcceptEnvvarUnprivileged { + return devices + } + + return nil } func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities { @@ -257,10 +269,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool) log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") } - var imexChannels string - if c := getImexChannels(image); c != nil { - imexChannels = *c - } + imexChannels := getImexChannels(hookConfig, image, privileged) driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String() diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index ad3e208d..a0cfe3ec 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -129,8 +129,8 @@ func doPrestart() { if len(nvidia.MigMonitorDevices) > 0 { args = append(args, fmt.Sprintf("--mig-monitor=%s", nvidia.MigMonitorDevices)) } - if len(nvidia.ImexChannels) > 0 { - args = append(args, fmt.Sprintf("--imex-channel=%s", nvidia.ImexChannels)) + if imexString := strings.Join(nvidia.ImexChannels, ","); len(imexString) > 0 { + args = append(args, fmt.Sprintf("--imex-channel=%s", imexString)) } for _, cap := range strings.Split(nvidia.DriverCapabilities, ",") { diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index f6c0979b..4298e634 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -30,7 +30,8 @@ import ( const ( DeviceListAsVolumeMountsRoot = "/var/run/nvidia-container-devices" - volumeMountDevicePrefixCDI = "cdi/" + volumeMountDevicePrefixCDI = "cdi/" + volumeMountDevicePrefixImex = "imex/" ) // CUDA represents a CUDA image that can be used for GPU computing. This wraps @@ -225,7 +226,10 @@ func (i CUDA) VisibleDevicesFromEnvVar() []string { func (i CUDA) VisibleDevicesFromMounts() []string { var devices []string for _, device := range i.DevicesFromMounts() { - if strings.HasPrefix(device, volumeMountDevicePrefixCDI) { + switch { + case strings.HasPrefix(device, volumeMountDevicePrefixCDI): + continue + case strings.HasPrefix(device, volumeMountDevicePrefixImex): continue } devices = append(devices, device) @@ -286,6 +290,19 @@ func (i CUDA) CDIDevicesFromMounts() []string { return devices } -func (i CUDA) IsEnabled(envvar string) bool { - return i.Getenv(envvar) == "enabled" +// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image. +func (i CUDA) ImexChannelsFromEnvVar() []string { + return i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List() +} + +// ImexChannelsFromMounts returns the list of IMEX channels requested for the image. +func (i CUDA) ImexChannelsFromMounts() []string { + var channels []string + for _, mountDevice := range i.DevicesFromMounts() { + if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) { + continue + } + channels = append(channels, strings.TrimPrefix(mountDevice, volumeMountDevicePrefixImex)) + } + return channels } diff --git a/internal/config/image/cuda_image_test.go b/internal/config/image/cuda_image_test.go index 9d9d5044..39bf30b4 100644 --- a/internal/config/image/cuda_image_test.go +++ b/internal/config/image/cuda_image_test.go @@ -189,6 +189,11 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) { mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"), expectedDevices: []string{"GPU0", "GPU1"}, }, + { + description: "imex devices are ignored", + mounts: makeTestMounts("GPU0", "imex/0", "GPU1"), + expectedDevices: []string{"GPU0", "GPU1"}, + }, } for _, tc := range tests { t.Run(tc.description, func(t *testing.T) {