diff --git a/pkg/container_config.go b/pkg/container_config.go index bc23661f..3c87c307 100644 --- a/pkg/container_config.go +++ b/pkg/container_config.go @@ -18,6 +18,8 @@ const ( envNVRequireCUDA = envNVRequirePrefix + "CUDA" envNVDisableRequire = "NVIDIA_DISABLE_REQUIRE" envNVVisibleDevices = "NVIDIA_VISIBLE_DEVICES" + envNVMigConfigDevices = "NVIDIA_MIG_CONFIG_DEVICES" + envNVMigMonitorDevices = "NVIDIA_MIG_MONITOR_DEVICES" envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES" ) @@ -32,6 +34,8 @@ const ( type nvidiaConfig struct { Devices string + MigConfigDevices string + MigMonitorDevices string DriverCapabilities string Requirements []string DisableRequire bool @@ -178,6 +182,26 @@ func getDevices(env map[string]string) *string { return nil } +func getMigConfigDevices(env map[string]string) *string { + gpuVars := []string{envNVMigConfigDevices} + for _, gpuVar := range gpuVars { + if devices, ok := env[gpuVar]; ok { + return &devices + } + } + return nil +} + +func getMigMonitorDevices(env map[string]string) *string { + gpuVars := []string{envNVMigMonitorDevices} + for _, gpuVar := range gpuVars { + if devices, ok := env[gpuVar]; ok { + return &devices + } + } + return nil +} + func getDriverCapabilities(env map[string]string) *string { if capabilities, ok := env[envNVDriverCapabilities]; ok { return &capabilities @@ -213,6 +237,22 @@ func getNvidiaConfigLegacy(env map[string]string, privileged bool) *nvidiaConfig devices = "" } + var migConfigDevices string + if d := getMigConfigDevices(env); d != nil { + migConfigDevices = *d + } + if !privileged && migConfigDevices != "" { + log.Panicln("cannot set MIG_CONFIG_DEVICES in non privileged container") + } + + var migMonitorDevices string + if d := getMigMonitorDevices(env); d != nil { + migMonitorDevices = *d + } + if !privileged && migMonitorDevices != "" { + log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") + } + var driverCapabilities string if c := getDriverCapabilities(env); c == nil { // Environment variable unset: default to "all". @@ -267,6 +307,22 @@ func getNvidiaConfig(env map[string]string, privileged bool) *nvidiaConfig { devices = "" } + var migConfigDevices string + if d := getMigConfigDevices(env); d != nil { + migConfigDevices = *d + } + if !privileged && migConfigDevices != "" { + log.Panicln("cannot set MIG_CONFIG_DEVICES in non privileged container") + } + + var migMonitorDevices string + if d := getMigMonitorDevices(env); d != nil { + migMonitorDevices = *d + } + if !privileged && migMonitorDevices != "" { + log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") + } + var driverCapabilities string if c := getDriverCapabilities(env); c == nil || len(*c) == 0 { // Environment variable unset or set but empty: use default capability. @@ -286,6 +342,8 @@ func getNvidiaConfig(env map[string]string, privileged bool) *nvidiaConfig { return &nvidiaConfig{ Devices: devices, + MigConfigDevices: migConfigDevices, + MigMonitorDevices: migMonitorDevices, DriverCapabilities: driverCapabilities, Requirements: requirements, DisableRequire: disableRequire, diff --git a/pkg/main.go b/pkg/main.go index 010ff359..13f8197c 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -126,6 +126,12 @@ func doPrestart() { if len(nvidia.Devices) > 0 { args = append(args, fmt.Sprintf("--device=%s", nvidia.Devices)) } + if len(nvidia.MigConfigDevices) > 0 { + args = append(args, fmt.Sprintf("--mig-config=%s", nvidia.MigConfigDevices)) + } + if len(nvidia.MigMonitorDevices) > 0 { + args = append(args, fmt.Sprintf("--mig-monitor=%s", nvidia.MigMonitorDevices)) + } for _, cap := range strings.Split(nvidia.DriverCapabilities, ",") { if len(cap) == 0 {