diff --git a/pkg/container_config.go b/pkg/container_config.go index 172a9a46..b7648ac7 100644 --- a/pkg/container_config.go +++ b/pkg/container_config.go @@ -192,6 +192,12 @@ func isPrivileged(s *Spec) bool { return false } +func isLegacyCUDAImage(env map[string]string) bool { + legacyCudaVersion := env[envCUDAVersion] + cudaRequire := env[envNVRequireCUDA] + return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 +} + func getDevices(env map[string]string) *string { gpuVars := []string{envNVVisibleDevices} if envSwarmGPU != nil { @@ -313,10 +319,7 @@ func getNvidiaConfigLegacy(env map[string]string, privileged bool) *nvidiaConfig } func getNvidiaConfig(env map[string]string, privileged bool) *nvidiaConfig { - legacyCudaVersion := env[envCUDAVersion] - cudaRequire := env[envNVRequireCUDA] - if len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 { - // Legacy CUDA image detected. + if isLegacyCUDAImage(env) { return getNvidiaConfigLegacy(env, privileged) }