diff --git a/pkg/container_config.go b/pkg/container_config.go index 3d51698a..6517008e 100644 --- a/pkg/container_config.go +++ b/pkg/container_config.go @@ -197,7 +197,7 @@ func isLegacyCUDAImage(env map[string]string) bool { return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 } -func getDevices(env map[string]string, legacyImage bool) *string { +func getDevices(hookConfig *HookConfig, env map[string]string, mounts []Mount, privileged bool, legacyImage bool) *string { // Build a list of envvars to consider. envVars := []string{envNVVisibleDevices} if envSwarmGPU != nil { @@ -295,11 +295,11 @@ func getRequirements(env map[string]string, legacyImage bool) []string { return requirements } -func getNvidiaConfig(env map[string]string, mounts []Mount, privileged bool) *nvidiaConfig { +func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mount, privileged bool) *nvidiaConfig { legacyImage := isLegacyCUDAImage(env) var devices string - if d := getDevices(env, legacyImage); d != nil { + if d := getDevices(hookConfig, env, mounts, privileged, legacyImage); d != nil { devices = *d } else { // 'nil' devices means this is not a GPU container. @@ -363,6 +363,6 @@ func getContainerConfig(hook HookConfig) (config containerConfig) { Pid: h.Pid, Rootfs: s.Root.Path, Env: env, - Nvidia: getNvidiaConfig(env, s.Mounts, privileged), + Nvidia: getNvidiaConfig(&hook, env, s.Mounts, privileged), } } diff --git a/pkg/container_test.go b/pkg/container_test.go index 9c197a30..3d11648f 100644 --- a/pkg/container_test.go +++ b/pkg/container_test.go @@ -407,7 +407,8 @@ func TestGetNvidiaConfigEnvvar(t *testing.T) { // Wrap the call to getNvidiaConfig() in a closure. var config *nvidiaConfig getConfig := func() { - config = getNvidiaConfig(tc.env, nil, tc.privileged) + hookConfig := getDefaultHookConfig() + config = getNvidiaConfig(&hookConfig, tc.env, nil, tc.privileged) } // For any tests that are expected to panic, make sure they do.