diff --git a/nvidia-container-toolkit/container_config.go b/nvidia-container-toolkit/container_config.go index ef46c72a..26f23341 100644 --- a/nvidia-container-toolkit/container_config.go +++ b/nvidia-container-toolkit/container_config.go @@ -13,21 +13,24 @@ import ( var envSwarmGPU *string const ( + envCUDAVersion = "CUDA_VERSION" envNVRequirePrefix = "NVIDIA_REQUIRE_" - envLegacyCUDAVersion = "CUDA_VERSION" envNVRequireCUDA = envNVRequirePrefix + "CUDA" - envNVGPU = "NVIDIA_VISIBLE_DEVICES" - envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES" - defaultCapability = "utility" - allCapabilities = "compute,compat32,graphics,utility,video,display" envNVDisableRequire = "NVIDIA_DISABLE_REQUIRE" + envNVVisibleDevices = "NVIDIA_VISIBLE_DEVICES" + envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES" +) + +const ( + allDriverCapabilities = "compute,compat32,graphics,utility,video,display" + defaultDriverCapabilities = "utility" ) type nvidiaConfig struct { - Devices string - Capabilities string - Requirements []string - DisableRequire bool + Devices string + DriverCapabilities string + Requirements []string + DisableRequire bool } type containerConfig struct { @@ -111,7 +114,7 @@ func loadSpec(path string) (spec *Spec) { } func getDevices(env map[string]string) *string { - gpuVars := []string{envNVGPU} + gpuVars := []string{envNVVisibleDevices} if envSwarmGPU != nil { // The Swarm resource has higher precedence. gpuVars = append([]string{*envSwarmGPU}, gpuVars...) @@ -125,7 +128,7 @@ func getDevices(env map[string]string) *string { return nil } -func getCapabilities(env map[string]string) *string { +func getDriverCapabilities(env map[string]string) *string { if capabilities, ok := env[envNVDriverCapabilities]; ok { return &capabilities } @@ -160,24 +163,24 @@ func getNvidiaConfigLegacy(env map[string]string) *nvidiaConfig { devices = "" } - var capabilities string - if c := getCapabilities(env); c == nil { + var driverCapabilities string + if c := getDriverCapabilities(env); c == nil { // Environment variable unset: default to "all". - capabilities = allCapabilities + driverCapabilities = allDriverCapabilities } else if len(*c) == 0 { // Environment variable empty: use default capability. - capabilities = defaultCapability + driverCapabilities = defaultDriverCapabilities } else { // Environment variable non-empty. - capabilities = *c + driverCapabilities = *c } - if capabilities == "all" { - capabilities = allCapabilities + if driverCapabilities == "all" { + driverCapabilities = allDriverCapabilities } requirements := getRequirements(env) - vmaj, vmin, _ := parseCudaVersion(env[envLegacyCUDAVersion]) + vmaj, vmin, _ := parseCudaVersion(env[envCUDAVersion]) cudaRequire := fmt.Sprintf("cuda>=%d.%d", vmaj, vmin) requirements = append(requirements, cudaRequire) @@ -185,15 +188,15 @@ func getNvidiaConfigLegacy(env map[string]string) *nvidiaConfig { disableRequire, _ := strconv.ParseBool(env[envNVDisableRequire]) return &nvidiaConfig{ - Devices: devices, - Capabilities: capabilities, - Requirements: requirements, - DisableRequire: disableRequire, + Devices: devices, + DriverCapabilities: driverCapabilities, + Requirements: requirements, + DisableRequire: disableRequire, } } func getNvidiaConfig(env map[string]string) *nvidiaConfig { - legacyCudaVersion := env[envLegacyCUDAVersion] + legacyCudaVersion := env[envCUDAVersion] cudaRequire := env[envNVRequireCUDA] if len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 { // Legacy CUDA image detected. @@ -212,16 +215,16 @@ func getNvidiaConfig(env map[string]string) *nvidiaConfig { devices = "" } - var capabilities string - if c := getCapabilities(env); c == nil || len(*c) == 0 { + var driverCapabilities string + if c := getDriverCapabilities(env); c == nil || len(*c) == 0 { // Environment variable unset or set but empty: use default capability. - capabilities = defaultCapability + driverCapabilities = defaultDriverCapabilities } else { // Environment variable set and non-empty. - capabilities = *c + driverCapabilities = *c } - if capabilities == "all" { - capabilities = allCapabilities + if driverCapabilities == "all" { + driverCapabilities = allDriverCapabilities } requirements := getRequirements(env) @@ -230,10 +233,10 @@ func getNvidiaConfig(env map[string]string) *nvidiaConfig { disableRequire, _ := strconv.ParseBool(env[envNVDisableRequire]) return &nvidiaConfig{ - Devices: devices, - Capabilities: capabilities, - Requirements: requirements, - DisableRequire: disableRequire, + Devices: devices, + DriverCapabilities: driverCapabilities, + Requirements: requirements, + DisableRequire: disableRequire, } } diff --git a/nvidia-container-toolkit/main.go b/nvidia-container-toolkit/main.go index 1753688f..52f4c0ef 100644 --- a/nvidia-container-toolkit/main.go +++ b/nvidia-container-toolkit/main.go @@ -124,7 +124,7 @@ func doPrestart() { args = append(args, fmt.Sprintf("--device=%s", nvidia.Devices)) } - for _, cap := range strings.Split(nvidia.Capabilities, ",") { + for _, cap := range strings.Split(nvidia.DriverCapabilities, ",") { if len(cap) == 0 { break }