From 296d4560b0b96766be4aa97d8185672b1c8c66d5 Mon Sep 17 00:00:00 2001 From: Kevin Klues Date: Wed, 17 Jan 2024 22:38:10 +0000 Subject: [PATCH] Add support for an NVIDIA_IMEX_CHANNELS envvar Signed-off-by: Kevin Klues --- .../container_config.go | 16 ++++++++++++++++ cmd/nvidia-container-runtime-hook/main.go | 3 +++ 2 files changed, 19 insertions(+) diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index cdd1471f..fa39bf2f 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -23,6 +23,7 @@ const ( envNVVisibleDevices = "NVIDIA_VISIBLE_DEVICES" envNVMigConfigDevices = "NVIDIA_MIG_CONFIG_DEVICES" envNVMigMonitorDevices = "NVIDIA_MIG_MONITOR_DEVICES" + envNVImexChannels = "NVIDIA_IMEX_CHANNELS" envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES" ) @@ -38,6 +39,7 @@ type nvidiaConfig struct { Devices string MigConfigDevices string MigMonitorDevices string + ImexChannels string DriverCapabilities string // Requirements defines the requirements DSL for the container to run. // This is empty if no specific requirements are needed, or if requirements are @@ -274,6 +276,14 @@ func getMigDevices(image image.CUDA, envvar string) *string { return &devices } +func getImexChannels(image image.CUDA) *string { + if !image.HasEnvvar(envNVImexChannels) { + return nil + } + chans := image.Getenv(envNVImexChannels) + return &chans +} + func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities { // We use the default driver capabilities by default. This is filtered to only include the // supported capabilities @@ -328,6 +338,11 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") } + var imexChannels string + if c := getImexChannels(image); c != nil { + imexChannels = *c + } + driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String() requirements, err := image.GetRequirements() @@ -339,6 +354,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p Devices: devices, MigConfigDevices: migConfigDevices, MigMonitorDevices: migMonitorDevices, + ImexChannels: imexChannels, DriverCapabilities: driverCapabilities, Requirements: requirements, } diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index ac807074..faaf0b51 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -126,6 +126,9 @@ func doPrestart() { if len(nvidia.MigMonitorDevices) > 0 { args = append(args, fmt.Sprintf("--mig-monitor=%s", nvidia.MigMonitorDevices)) } + if len(nvidia.ImexChannels) > 0 { + args = append(args, fmt.Sprintf("--imex-channel=%s", nvidia.ImexChannels)) + } for _, cap := range strings.Split(nvidia.DriverCapabilities, ",") { if len(cap) == 0 {