diff --git a/cmd/nvidia-container-toolkit/container_config.go b/cmd/nvidia-container-toolkit/container_config.go index dae4cc7b..b260e70f 100644 --- a/cmd/nvidia-container-toolkit/container_config.go +++ b/cmd/nvidia-container-toolkit/container_config.go @@ -23,6 +23,7 @@ const ( envNVVisibleDevices = "NVIDIA_VISIBLE_DEVICES" envNVMigConfigDevices = "NVIDIA_MIG_CONFIG_DEVICES" envNVMigMonitorDevices = "NVIDIA_MIG_MONITOR_DEVICES" + envNVFabricDevices = "NVIDIA_FABRIC_DEVICES" envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES" ) @@ -43,6 +44,7 @@ type nvidiaConfig struct { Devices string MigConfigDevices string MigMonitorDevices string + FabricDevices string DriverCapabilities string Requirements []string DisableRequire bool @@ -316,6 +318,13 @@ func getMigMonitorDevices(env map[string]string) *string { return nil } +func getFabricDevices(env map[string]string) *string { + if devices, ok := env[envNVFabricDevices]; ok { + return &devices + } + return nil +} + func getDriverCapabilities(env map[string]string, legacyImage bool) *string { // Grab a reference to the capabilities from the envvar // if it actually exists in the environment. @@ -394,6 +403,11 @@ func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mou driverCapabilities = *c } + var nvFabricDevices string + if d := getFabricDevices(env); d != nil { + nvFabricDevices = *d + } + requirements := getRequirements(env, legacyImage) // Don't fail on invalid values. @@ -403,6 +417,7 @@ func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mou Devices: devices, MigConfigDevices: migConfigDevices, MigMonitorDevices: migMonitorDevices, + FabricDevices: nvFabricDevices, DriverCapabilities: driverCapabilities, Requirements: requirements, DisableRequire: disableRequire, diff --git a/cmd/nvidia-container-toolkit/container_test.go b/cmd/nvidia-container-toolkit/container_test.go index 3cda6280..ff9e4905 100644 --- a/cmd/nvidia-container-toolkit/container_test.go +++ b/cmd/nvidia-container-toolkit/container_test.go @@ -403,6 +403,30 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedPanic: true, }, + { + description: "fabric devices selected", + env: map[string]string{ + envNVVisibleDevices: "all", + envNVFabricDevices: "all", + }, + expectedConfig: &nvidiaConfig{ + Devices: "all", + FabricDevices: "all", + DriverCapabilities: defaultDriverCapabilities, + }, + }, + { + description: "fabric devices selected empty", + env: map[string]string{ + envNVVisibleDevices: "all", + envNVFabricDevices: "", + }, + expectedConfig: &nvidiaConfig{ + Devices: "all", + FabricDevices: "", + DriverCapabilities: defaultDriverCapabilities, + }, + }, } for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { diff --git a/cmd/nvidia-container-toolkit/main.go b/cmd/nvidia-container-toolkit/main.go index 13f8197c..e520039e 100644 --- a/cmd/nvidia-container-toolkit/main.go +++ b/cmd/nvidia-container-toolkit/main.go @@ -132,6 +132,9 @@ func doPrestart() { if len(nvidia.MigMonitorDevices) > 0 { args = append(args, fmt.Sprintf("--mig-monitor=%s", nvidia.MigMonitorDevices)) } + if len(nvidia.FabricDevices) > 0 { + args = append(args, fmt.Sprintf("--fabric-device=%s", nvidia.FabricDevices)) + } for _, cap := range strings.Split(nvidia.DriverCapabilities, ",") { if len(cap) == 0 {