Add support for NVIDIA_FABRIC_DEVICES

This change adds support for the NVIDIA_FABRIC_DEVICES environment variable.
When this variable is set to a non-empty value, that value is passed to the
NVIDIA Container CLI using the --fabric-device command line flag, which allows
NVSwitch and NVLink devices to be mounted into the container.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2021-07-22 17:44:59 +02:00
parent faf0df66c7
commit f828efcf64
3 changed files with 42 additions and 0 deletions

View File

@ -23,6 +23,7 @@ const (
envNVVisibleDevices = "NVIDIA_VISIBLE_DEVICES" envNVVisibleDevices = "NVIDIA_VISIBLE_DEVICES"
envNVMigConfigDevices = "NVIDIA_MIG_CONFIG_DEVICES" envNVMigConfigDevices = "NVIDIA_MIG_CONFIG_DEVICES"
envNVMigMonitorDevices = "NVIDIA_MIG_MONITOR_DEVICES" envNVMigMonitorDevices = "NVIDIA_MIG_MONITOR_DEVICES"
envNVFabricDevices = "NVIDIA_FABRIC_DEVICES"
envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES" envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES"
) )
@ -43,6 +44,7 @@ type nvidiaConfig struct {
Devices string Devices string
MigConfigDevices string MigConfigDevices string
MigMonitorDevices string MigMonitorDevices string
FabricDevices string
DriverCapabilities string DriverCapabilities string
Requirements []string Requirements []string
DisableRequire bool DisableRequire bool
@ -316,6 +318,13 @@ func getMigMonitorDevices(env map[string]string) *string {
return nil return nil
} }
// getFabricDevices looks up the envNVFabricDevices key in env.
//
// It returns a pointer to a copy of the value when the key is present,
// including when the value is the empty string, and nil when the key is
// absent from the map.
func getFabricDevices(env map[string]string) *string {
	value, found := env[envNVFabricDevices]
	if !found {
		return nil
	}
	return &value
}
func getDriverCapabilities(env map[string]string, legacyImage bool) *string { func getDriverCapabilities(env map[string]string, legacyImage bool) *string {
// Grab a reference to the capabilities from the envvar // Grab a reference to the capabilities from the envvar
// if it actually exists in the environment. // if it actually exists in the environment.
@ -394,6 +403,11 @@ func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mou
driverCapabilities = *c driverCapabilities = *c
} }
var nvFabricDevices string
if d := getFabricDevices(env); d != nil {
nvFabricDevices = *d
}
requirements := getRequirements(env, legacyImage) requirements := getRequirements(env, legacyImage)
// Don't fail on invalid values. // Don't fail on invalid values.
@ -403,6 +417,7 @@ func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mou
Devices: devices, Devices: devices,
MigConfigDevices: migConfigDevices, MigConfigDevices: migConfigDevices,
MigMonitorDevices: migMonitorDevices, MigMonitorDevices: migMonitorDevices,
FabricDevices: nvFabricDevices,
DriverCapabilities: driverCapabilities, DriverCapabilities: driverCapabilities,
Requirements: requirements, Requirements: requirements,
DisableRequire: disableRequire, DisableRequire: disableRequire,

View File

@ -403,6 +403,30 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false, privileged: false,
expectedPanic: true, expectedPanic: true,
}, },
{
description: "fabric devices selected",
env: map[string]string{
envNVVisibleDevices: "all",
envNVFabricDevices: "all",
},
expectedConfig: &nvidiaConfig{
Devices: "all",
FabricDevices: "all",
DriverCapabilities: defaultDriverCapabilities,
},
},
{
description: "fabric devices selected empty",
env: map[string]string{
envNVVisibleDevices: "all",
envNVFabricDevices: "",
},
expectedConfig: &nvidiaConfig{
Devices: "all",
FabricDevices: "",
DriverCapabilities: defaultDriverCapabilities,
},
},
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {

View File

@ -132,6 +132,9 @@ func doPrestart() {
if len(nvidia.MigMonitorDevices) > 0 { if len(nvidia.MigMonitorDevices) > 0 {
args = append(args, fmt.Sprintf("--mig-monitor=%s", nvidia.MigMonitorDevices)) args = append(args, fmt.Sprintf("--mig-monitor=%s", nvidia.MigMonitorDevices))
} }
if len(nvidia.FabricDevices) > 0 {
args = append(args, fmt.Sprintf("--fabric-device=%s", nvidia.FabricDevices))
}
for _, cap := range strings.Split(nvidia.DriverCapabilities, ",") { for _, cap := range strings.Split(nvidia.DriverCapabilities, ",") {
if len(cap) == 0 { if len(cap) == 0 {