diff --git a/pkg/container_config.go b/pkg/container_config.go index 6517008e..b7f54e0e 100644 --- a/pkg/container_config.go +++ b/pkg/container_config.go @@ -197,7 +197,7 @@ func isLegacyCUDAImage(env map[string]string) bool { return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 } -func getDevices(hookConfig *HookConfig, env map[string]string, mounts []Mount, privileged bool, legacyImage bool) *string { +func getDevicesFromEnvvar(env map[string]string, legacyImage bool) *string { // Build a list of envvars to consider. envVars := []string{envNVVisibleDevices} if envSwarmGPU != nil { @@ -235,6 +235,10 @@ func getDevices(hookConfig *HookConfig, env map[string]string, mounts []Mount, p return devices } +func getDevices(hookConfig *HookConfig, env map[string]string, mounts []Mount, privileged bool, legacyImage bool) *string { + return getDevicesFromEnvvar(env, legacyImage) +} + func getMigConfigDevices(env map[string]string) *string { if devices, ok := env[envNVMigConfigDevices]; ok { return &devices diff --git a/pkg/container_test.go b/pkg/container_test.go index 3d11648f..c4748b6b 100644 --- a/pkg/container_test.go +++ b/pkg/container_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestGetNvidiaConfigEnvvar(t *testing.T) { +func TestGetNvidiaConfig(t *testing.T) { var tests = []struct { description string env map[string]string