diff --git a/pkg/container_config.go b/pkg/container_config.go
index ebae74f5..15bbc0f4 100644
--- a/pkg/container_config.go
+++ b/pkg/container_config.go
@@ -198,19 +198,48 @@ func isLegacyCUDAImage(env map[string]string) bool {
 	return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0
 }
 
-func getDevices(env map[string]string) *string {
-	gpuVars := []string{envNVVisibleDevices}
+// getDevices resolves the device list requested through the environment.
+// It returns nil when the container is not a GPU container (variable unset
+// on a non-legacy image, empty, or "void"), a pointer to "" for "none",
+// "all" for a legacy image with no variable set, and the raw value otherwise.
+func getDevices(env map[string]string, legacyImage bool) *string {
+	// Build a list of envvars to consider.
+	envVars := []string{envNVVisibleDevices}
 	if envSwarmGPU != nil {
-		// The Swarm resource has higher precedence.
-		gpuVars = append([]string{*envSwarmGPU}, gpuVars...)
+		// The Swarm envvar has higher precedence.
+		envVars = append([]string{*envSwarmGPU}, envVars...)
 	}
 
-	for _, gpuVar := range gpuVars {
-		if devices, ok := env[gpuVar]; ok {
-			return &devices
+	// Grab a reference to devices from the first envvar
+	// in the list that actually exists in the environment.
+	var devices *string
+	for _, envVar := range envVars {
+		if devs, ok := env[envVar]; ok {
+			devices = &devs
+			// Stop at the first match so the Swarm envvar keeps precedence.
+			break
 		}
 	}
-	return nil
+
+	// Environment variable unset with legacy image: default to "all".
+	if devices == nil && legacyImage {
+		all := "all"
+		return &all
+	}
+
+	// Environment variable unset or empty or "void": return nil
+	if devices == nil || len(*devices) == 0 || *devices == "void" {
+		return nil
+	}
+
+	// Environment variable set to "none": reset to "".
+	if *devices == "none" {
+		empty := ""
+		return &empty
+	}
+
+	// Any other value.
+	return devices
 }
 
 func getMigConfigDevices(env map[string]string) *string {
@@ -227,14 +256,44 @@ func getMigMonitorDevices(env map[string]string) *string {
 	return nil
 }
 
-func getDriverCapabilities(env map[string]string) *string {
-	if capabilities, ok := env[envNVDriverCapabilities]; ok {
-		return &capabilities
+// getDriverCapabilities resolves the driver capabilities requested through
+// envNVDriverCapabilities. Legacy images with the variable unset get
+// allDriverCapabilities, an unset or empty variable yields
+// defaultDriverCapabilities, "all" is replaced by allDriverCapabilities,
+// and any other value is returned as is.
+func getDriverCapabilities(env map[string]string, legacyImage bool) *string {
+	// Grab a reference to the capabilities from the envvar
+	// if it actually exists in the environment.
+	var capabilities *string
+	if caps, ok := env[envNVDriverCapabilities]; ok {
+		capabilities = &caps
 	}
-	return nil
+
+	// Environment variable unset with legacy image: set all capabilities.
+	if capabilities == nil && legacyImage {
+		allCaps := allDriverCapabilities
+		return &allCaps
+	}
+
+	// Environment variable unset or set but empty: set default capabilities.
+	if capabilities == nil || len(*capabilities) == 0 {
+		defaultCaps := defaultDriverCapabilities
+		return &defaultCaps
+	}
+
+	// Environment variable set to "all": set all capabilities.
+	if *capabilities == "all" {
+		allCaps := allDriverCapabilities
+		return &allCaps
+	}
+
+	// Any other value
+	return capabilities
 }
 
-func getRequirements(env map[string]string) []string {
+// getRequirements collects the requirement values found in env. For legacy
+// images it also appends a "cuda>=MAJOR.MINOR" requirement built from envCUDAVersion.
+func getRequirements(env map[string]string, legacyImage bool) []string {
 	// All variables with the "NVIDIA_REQUIRE_" prefix are passed to nvidia-container-cli
 	var requirements []string
 	for name, value := range env {
@@ -242,24 +301,25 @@ func getRequirements(env map[string]string) []string {
 			requirements = append(requirements, value)
 		}
 	}
+	if legacyImage {
+		vmaj, vmin, _ := parseCudaVersion(env[envCUDAVersion])
+		cudaRequire := fmt.Sprintf("cuda>=%d.%d", vmaj, vmin)
+		requirements = append(requirements, cudaRequire)
+	}
 	return requirements
 }
 
 func getNvidiaConfig(env map[string]string, privileged bool) *nvidiaConfig {
+	legacyImage := isLegacyCUDAImage(env)
+
 	var devices string
-	d := getDevices(env)
-	if d == nil || len(*d) == 0 || *d == "void" {
-		// Environment variable unset or empty or "void": not a GPU container.
+	if d := getDevices(env, legacyImage); d != nil {
+		devices = *d
+	} else {
+		// 'nil' devices means this is not a GPU container.
 		return nil
 	}
 
-	// Environment variable non-empty and not "void".
-	devices = *d
-
-	if devices == "none" {
-		devices = ""
-	}
-
 	var migConfigDevices string
 	if d := getMigConfigDevices(env); d != nil {
 		migConfigDevices = *d
@@ -277,18 +337,11 @@ func getNvidiaConfig(env map[string]string, privileged bool) *nvidiaConfig {
 	}
 
 	var driverCapabilities string
-	if c := getDriverCapabilities(env); c == nil || len(*c) == 0 {
-		// Environment variable unset or set but empty: use default capability.
-		driverCapabilities = defaultDriverCapabilities
-	} else {
-		// Environment variable set and non-empty.
+	if c := getDriverCapabilities(env, legacyImage); c != nil {
 		driverCapabilities = *c
 	}
-	if driverCapabilities == "all" {
-		driverCapabilities = allDriverCapabilities
-	}
 
-	requirements := getRequirements(env)
+	requirements := getRequirements(env, legacyImage)
 
 	// Don't fail on invalid values.
 	disableRequire, _ := strconv.ParseBool(env[envNVDisableRequire])