diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index 42732bfd..d2c8bdda 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -9,9 +9,10 @@ import ( "path/filepath" "strings" - "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/mod/semver" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" ) const ( @@ -22,6 +23,7 @@ const ( envNVVisibleDevices = "NVIDIA_VISIBLE_DEVICES" envNVMigConfigDevices = "NVIDIA_MIG_CONFIG_DEVICES" envNVMigMonitorDevices = "NVIDIA_MIG_MONITOR_DEVICES" + envNVImexChannels = "NVIDIA_IMEX_CHANNELS" envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES" ) @@ -37,6 +39,7 @@ type nvidiaConfig struct { Devices string MigConfigDevices string MigMonitorDevices string + ImexChannels string DriverCapabilities string // Requirements defines the requirements DSL for the container to run. // This is empty if no specific requirements are needed, or if requirements are @@ -271,6 +274,13 @@ func getMigMonitorDevices(env map[string]string) *string { return nil } +func getImexChannels(env map[string]string) *string { + if chans, ok := env[envNVImexChannels]; ok { + return &chans + } + return nil +} + func (c *HookConfig) getDriverCapabilities(env map[string]string, legacyImage bool) image.DriverCapabilities { // We use the default driver capabilities by default. This is filtered to only include the // supported capabilities @@ -324,6 +334,11 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") } + var imexChannels string + if c := getImexChannels(image); c != nil { + imexChannels = *c + } + driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String() requirements, err := image.GetRequirements() @@ -335,6 +350,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p Devices: devices, MigConfigDevices: migConfigDevices, MigMonitorDevices: migMonitorDevices, + ImexChannels: imexChannels, DriverCapabilities: driverCapabilities, Requirements: requirements, } diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index 30aad846..c004d84e 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -126,6 +126,9 @@ func doPrestart() { if len(nvidia.MigMonitorDevices) > 0 { args = append(args, fmt.Sprintf("--mig-monitor=%s", nvidia.MigMonitorDevices)) } + if len(nvidia.ImexChannels) > 0 { + args = append(args, fmt.Sprintf("--imex-channel=%s", nvidia.ImexChannels)) + } for _, cap := range strings.Split(nvidia.DriverCapabilities, ",") { if len(cap) == 0 {