Push HookConfig and privileged flags down to getDevices() call

Signed-off-by: Kevin Klues <kklues@nvidia.com>
This commit is contained in:
Kevin Klues 2020-07-23 14:13:33 +00:00
parent 2ae7cb07cf
commit aec9a28bc3
2 changed files with 6 additions and 5 deletions

View File

@ -197,7 +197,7 @@ func isLegacyCUDAImage(env map[string]string) bool {
return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0
} }
func getDevices(env map[string]string, legacyImage bool) *string { func getDevices(hookConfig *HookConfig, env map[string]string, mounts []Mount, privileged bool, legacyImage bool) *string {
// Build a list of envvars to consider. // Build a list of envvars to consider.
envVars := []string{envNVVisibleDevices} envVars := []string{envNVVisibleDevices}
if envSwarmGPU != nil { if envSwarmGPU != nil {
@ -295,11 +295,11 @@ func getRequirements(env map[string]string, legacyImage bool) []string {
return requirements return requirements
} }
func getNvidiaConfig(env map[string]string, mounts []Mount, privileged bool) *nvidiaConfig { func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mount, privileged bool) *nvidiaConfig {
legacyImage := isLegacyCUDAImage(env) legacyImage := isLegacyCUDAImage(env)
var devices string var devices string
if d := getDevices(env, legacyImage); d != nil { if d := getDevices(hookConfig, env, mounts, privileged, legacyImage); d != nil {
devices = *d devices = *d
} else { } else {
// 'nil' devices means this is not a GPU container. // 'nil' devices means this is not a GPU container.
@ -363,6 +363,6 @@ func getContainerConfig(hook HookConfig) (config containerConfig) {
Pid: h.Pid, Pid: h.Pid,
Rootfs: s.Root.Path, Rootfs: s.Root.Path,
Env: env, Env: env,
Nvidia: getNvidiaConfig(env, s.Mounts, privileged), Nvidia: getNvidiaConfig(&hook, env, s.Mounts, privileged),
} }
} }

View File

@ -407,7 +407,8 @@ func TestGetNvidiaConfigEnvvar(t *testing.T) {
// Wrap the call to getNvidiaConfig() in a closure. // Wrap the call to getNvidiaConfig() in a closure.
var config *nvidiaConfig var config *nvidiaConfig
getConfig := func() { getConfig := func() {
config = getNvidiaConfig(tc.env, nil, tc.privileged) hookConfig := getDefaultHookConfig()
config = getNvidiaConfig(&hookConfig, tc.env, nil, tc.privileged)
} }
// For any tests that are expected to panic, make sure they do. // For any tests that are expected to panic, make sure they do.