mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-22 00:08:11 +00:00
[no-relnote] Use string slice for devices in hook
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
b077e2648d
commit
1991b3ef2a
@ -7,7 +7,6 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
"golang.org/x/mod/semver"
|
||||
@ -36,7 +35,7 @@ const (
|
||||
)
|
||||
|
||||
type nvidiaConfig struct {
|
||||
Devices string
|
||||
Devices []string
|
||||
MigConfigDevices string
|
||||
MigMonitorDevices string
|
||||
ImexChannels string
|
||||
@ -172,34 +171,19 @@ func isPrivileged(s *Spec) bool {
|
||||
return image.IsPrivileged(&fullSpec)
|
||||
}
|
||||
|
||||
func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) *string {
|
||||
func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) []string {
|
||||
// We check if the image has at least one of the Swarm resource envvars defined and use this
|
||||
// if specified.
|
||||
var hasSwarmEnvvar bool
|
||||
for _, envvar := range swarmResourceEnvvars {
|
||||
if image.HasEnvvar(envvar) {
|
||||
hasSwarmEnvvar = true
|
||||
break
|
||||
return image.DevicesFromEnvvars(swarmResourceEnvvars...).List()
|
||||
}
|
||||
}
|
||||
|
||||
var devices []string
|
||||
if hasSwarmEnvvar {
|
||||
devices = image.DevicesFromEnvvars(swarmResourceEnvvars...).List()
|
||||
} else {
|
||||
devices = image.DevicesFromEnvvars(envNVVisibleDevices).List()
|
||||
}
|
||||
|
||||
if len(devices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
devicesString := strings.Join(devices, ",")
|
||||
|
||||
return &devicesString
|
||||
return image.DevicesFromEnvvars(envNVVisibleDevices).List()
|
||||
}
|
||||
|
||||
func getDevicesFromMounts(mounts []Mount) *string {
|
||||
func getDevicesFromMounts(mounts []Mount) []string {
|
||||
var devices []string
|
||||
for _, m := range mounts {
|
||||
root := filepath.Clean(deviceListAsVolumeMountsRoot)
|
||||
@ -232,22 +216,21 @@ func getDevicesFromMounts(mounts []Mount) *string {
|
||||
return nil
|
||||
}
|
||||
|
||||
ret := strings.Join(devices, ",")
|
||||
return &ret
|
||||
return devices
|
||||
}
|
||||
|
||||
func getDevices(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *string {
|
||||
func getDevices(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) []string {
|
||||
// If enabled, try and get the device list from volume mounts first
|
||||
if hookConfig.AcceptDeviceListAsVolumeMounts {
|
||||
devices := getDevicesFromMounts(mounts)
|
||||
if devices != nil {
|
||||
if len(devices) > 0 {
|
||||
return devices
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to reading from the environment variable if privileges are correct
|
||||
devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars())
|
||||
if devices == nil {
|
||||
if len(devices) == 0 {
|
||||
return nil
|
||||
}
|
||||
if privileged || hookConfig.AcceptEnvvarUnprivileged {
|
||||
@ -314,11 +297,9 @@ func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage boo
|
||||
func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *nvidiaConfig {
|
||||
legacyImage := image.IsLegacy()
|
||||
|
||||
var devices string
|
||||
if d := getDevices(hookConfig, image, mounts, privileged); d != nil {
|
||||
devices = *d
|
||||
} else {
|
||||
// 'nil' devices means this is not a GPU container.
|
||||
devices := getDevices(hookConfig, image, mounts, privileged)
|
||||
if len(devices) == 0 {
|
||||
// empty devices means this is not a GPU container.
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@ -38,7 +37,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
Devices: []string{"all"},
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -51,7 +50,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
Devices: []string{"all"},
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -82,7 +81,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "",
|
||||
Devices: []string{""},
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -95,7 +94,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -109,7 +108,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -123,7 +122,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -137,7 +136,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -153,7 +152,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
|
||||
},
|
||||
@ -170,7 +169,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{},
|
||||
},
|
||||
@ -200,7 +199,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
Devices: []string{"all"},
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -231,7 +230,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "",
|
||||
Devices: []string{""},
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -244,7 +243,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -258,7 +257,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -272,7 +271,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -286,7 +285,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
@ -302,7 +301,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
|
||||
},
|
||||
@ -319,7 +318,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
Devices: []string{"gpu0", "gpu1"},
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{},
|
||||
},
|
||||
@ -332,7 +331,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
Devices: []string{"all"},
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{},
|
||||
},
|
||||
@ -346,7 +345,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: true,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
Devices: []string{"all"},
|
||||
MigConfigDevices: "mig0,mig1",
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
@ -371,7 +370,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: true,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
Devices: []string{"all"},
|
||||
MigMonitorDevices: "mig0,mig1",
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
@ -398,7 +397,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
SupportedDriverCapabilities: "video,display",
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
Devices: []string{"all"},
|
||||
DriverCapabilities: "display,video",
|
||||
},
|
||||
},
|
||||
@ -413,7 +412,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
SupportedDriverCapabilities: "video,display,compute,utility",
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
Devices: []string{"all"},
|
||||
DriverCapabilities: "display,video",
|
||||
},
|
||||
},
|
||||
@ -427,7 +426,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
SupportedDriverCapabilities: "video,display,utility,compute",
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
Devices: []string{"all"},
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
},
|
||||
},
|
||||
@ -443,7 +442,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
SupportedDriverCapabilities: "video,display,utility,compute",
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "GPU1,GPU2",
|
||||
Devices: []string{"GPU1", "GPU2"},
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
},
|
||||
},
|
||||
@ -459,7 +458,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
SupportedDriverCapabilities: "video,display,utility,compute",
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "GPU1,GPU2",
|
||||
Devices: []string{"GPU1", "GPU2"},
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
},
|
||||
},
|
||||
@ -511,7 +510,7 @@ func TestGetDevicesFromMounts(t *testing.T) {
|
||||
var tests = []struct {
|
||||
description string
|
||||
mounts []Mount
|
||||
expectedDevices *string
|
||||
expectedDevices []string
|
||||
}{
|
||||
{
|
||||
description: "No mounts",
|
||||
@ -560,7 +559,7 @@ func TestGetDevicesFromMounts(t *testing.T) {
|
||||
Destination: filepath.Join(deviceListAsVolumeMountsRoot, "GPU1"),
|
||||
},
|
||||
},
|
||||
expectedDevices: &[]string{"GPU0,GPU1"}[0],
|
||||
expectedDevices: []string{"GPU0", "GPU1"},
|
||||
},
|
||||
{
|
||||
description: "Discover 2 devices with slashes in the name",
|
||||
@ -574,7 +573,7 @@ func TestGetDevicesFromMounts(t *testing.T) {
|
||||
Destination: filepath.Join(deviceListAsVolumeMountsRoot, "GPU1-MIG0/0/1"),
|
||||
},
|
||||
},
|
||||
expectedDevices: &[]string{"GPU0-MIG0/0/1,GPU1-MIG0/0/1"}[0],
|
||||
expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
@ -593,7 +592,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
|
||||
privileged bool
|
||||
acceptUnprivileged bool
|
||||
acceptMounts bool
|
||||
expectedDevices *string
|
||||
expectedDevices []string
|
||||
}{
|
||||
{
|
||||
description: "Mount devices, unprivileged, no accept unprivileged",
|
||||
@ -611,7 +610,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
|
||||
privileged: false,
|
||||
acceptUnprivileged: false,
|
||||
acceptMounts: true,
|
||||
expectedDevices: &[]string{"GPU0,GPU1"}[0],
|
||||
expectedDevices: []string{"GPU0", "GPU1"},
|
||||
},
|
||||
{
|
||||
description: "No mount devices, unprivileged, no accept unprivileged",
|
||||
@ -629,7 +628,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
|
||||
privileged: true,
|
||||
acceptUnprivileged: false,
|
||||
acceptMounts: true,
|
||||
expectedDevices: &[]string{"GPU0,GPU1"}[0],
|
||||
expectedDevices: []string{"GPU0", "GPU1"},
|
||||
},
|
||||
{
|
||||
description: "No mount devices, unprivileged, accept unprivileged",
|
||||
@ -638,7 +637,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
|
||||
privileged: false,
|
||||
acceptUnprivileged: true,
|
||||
acceptMounts: true,
|
||||
expectedDevices: &[]string{"GPU0,GPU1"}[0],
|
||||
expectedDevices: []string{"GPU0", "GPU1"},
|
||||
},
|
||||
{
|
||||
description: "Mount devices, unprivileged, accept unprivileged, no accept mounts",
|
||||
@ -656,7 +655,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
|
||||
privileged: false,
|
||||
acceptUnprivileged: true,
|
||||
acceptMounts: false,
|
||||
expectedDevices: &[]string{"GPU2,GPU3"}[0],
|
||||
expectedDevices: []string{"GPU2", "GPU3"},
|
||||
},
|
||||
{
|
||||
description: "Mount devices, unprivileged, no accept unprivileged, no accept mounts",
|
||||
@ -680,7 +679,7 @@ func TestDeviceListSourcePriority(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
// Wrap the call to getDevices() in a closure.
|
||||
var devices *string
|
||||
var devices []string
|
||||
getDevices := func() {
|
||||
image, _ := image.New(
|
||||
image.WithEnvMap(
|
||||
@ -704,8 +703,6 @@ func TestDeviceListSourcePriority(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
all := "all"
|
||||
empty := ""
|
||||
envDockerResourceGPUs := "DOCKER_RESOURCE_GPUS"
|
||||
gpuID := "GPU-12345"
|
||||
anotherGPUID := "GPU-67890"
|
||||
@ -715,7 +712,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
description string
|
||||
swarmResourceEnvvars []string
|
||||
env map[string]string
|
||||
expectedDevices *string
|
||||
expectedDevices []string
|
||||
}{
|
||||
{
|
||||
description: "empty env returns nil for non-legacy image",
|
||||
@ -737,14 +734,14 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
env: map[string]string{
|
||||
envNVVisibleDevices: "none",
|
||||
},
|
||||
expectedDevices: &empty,
|
||||
expectedDevices: []string{""},
|
||||
},
|
||||
{
|
||||
description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image",
|
||||
env: map[string]string{
|
||||
envNVVisibleDevices: gpuID,
|
||||
},
|
||||
expectedDevices: &gpuID,
|
||||
expectedDevices: []string{gpuID},
|
||||
},
|
||||
{
|
||||
description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image",
|
||||
@ -752,14 +749,14 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
envNVVisibleDevices: gpuID,
|
||||
envCUDAVersion: "legacy",
|
||||
},
|
||||
expectedDevices: &gpuID,
|
||||
expectedDevices: []string{gpuID},
|
||||
},
|
||||
{
|
||||
description: "empty env returns all for legacy image",
|
||||
env: map[string]string{
|
||||
envCUDAVersion: "legacy",
|
||||
},
|
||||
expectedDevices: &all,
|
||||
expectedDevices: []string{"all"},
|
||||
},
|
||||
// Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is ignored when
|
||||
// not enabled
|
||||
@ -789,7 +786,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
envNVVisibleDevices: "none",
|
||||
envDockerResourceGPUs: anotherGPUID,
|
||||
},
|
||||
expectedDevices: &empty,
|
||||
expectedDevices: []string{""},
|
||||
},
|
||||
{
|
||||
description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image",
|
||||
@ -797,7 +794,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
envNVVisibleDevices: gpuID,
|
||||
envDockerResourceGPUs: anotherGPUID,
|
||||
},
|
||||
expectedDevices: &gpuID,
|
||||
expectedDevices: []string{gpuID},
|
||||
},
|
||||
{
|
||||
description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image",
|
||||
@ -806,7 +803,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
envDockerResourceGPUs: anotherGPUID,
|
||||
envCUDAVersion: "legacy",
|
||||
},
|
||||
expectedDevices: &gpuID,
|
||||
expectedDevices: []string{gpuID},
|
||||
},
|
||||
{
|
||||
description: "empty env returns all for legacy image",
|
||||
@ -814,7 +811,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
envDockerResourceGPUs: anotherGPUID,
|
||||
envCUDAVersion: "legacy",
|
||||
},
|
||||
expectedDevices: &all,
|
||||
expectedDevices: []string{"all"},
|
||||
},
|
||||
// Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is selected when
|
||||
// enabled
|
||||
@ -842,7 +839,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
env: map[string]string{
|
||||
envDockerResourceGPUs: "none",
|
||||
},
|
||||
expectedDevices: &empty,
|
||||
expectedDevices: []string{""},
|
||||
},
|
||||
{
|
||||
description: "DOCKER_RESOURCE_GPUS set returns value for non-legacy image",
|
||||
@ -850,7 +847,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
env: map[string]string{
|
||||
envDockerResourceGPUs: gpuID,
|
||||
},
|
||||
expectedDevices: &gpuID,
|
||||
expectedDevices: []string{gpuID},
|
||||
},
|
||||
{
|
||||
description: "DOCKER_RESOURCE_GPUS set returns value for legacy image",
|
||||
@ -859,7 +856,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
envDockerResourceGPUs: gpuID,
|
||||
envCUDAVersion: "legacy",
|
||||
},
|
||||
expectedDevices: &gpuID,
|
||||
expectedDevices: []string{gpuID},
|
||||
},
|
||||
{
|
||||
description: "DOCKER_RESOURCE_GPUS is selected if present",
|
||||
@ -867,7 +864,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
env: map[string]string{
|
||||
envDockerResourceGPUs: anotherGPUID,
|
||||
},
|
||||
expectedDevices: &anotherGPUID,
|
||||
expectedDevices: []string{anotherGPUID},
|
||||
},
|
||||
{
|
||||
description: "DOCKER_RESOURCE_GPUS overrides NVIDIA_VISIBLE_DEVICES if present",
|
||||
@ -876,7 +873,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
envNVVisibleDevices: gpuID,
|
||||
envDockerResourceGPUs: anotherGPUID,
|
||||
},
|
||||
expectedDevices: &anotherGPUID,
|
||||
expectedDevices: []string{anotherGPUID},
|
||||
},
|
||||
{
|
||||
description: "DOCKER_RESOURCE_GPUS_ADDITIONAL overrides NVIDIA_VISIBLE_DEVICES if present",
|
||||
@ -885,7 +882,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
envNVVisibleDevices: gpuID,
|
||||
"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
|
||||
},
|
||||
expectedDevices: &anotherGPUID,
|
||||
expectedDevices: []string{anotherGPUID},
|
||||
},
|
||||
{
|
||||
description: "All available swarm resource envvars are selected and override NVIDIA_VISIBLE_DEVICES if present",
|
||||
@ -895,10 +892,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
"DOCKER_RESOURCE_GPUS": thirdGPUID,
|
||||
"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
|
||||
},
|
||||
expectedDevices: func() *string {
|
||||
result := fmt.Sprintf("%s,%s", thirdGPUID, anotherGPUID)
|
||||
return &result
|
||||
}(),
|
||||
expectedDevices: []string{thirdGPUID, anotherGPUID},
|
||||
},
|
||||
{
|
||||
description: "DOCKER_RESOURCE_GPUS_ADDITIONAL or DOCKER_RESOURCE_GPUS override NVIDIA_VISIBLE_DEVICES if present",
|
||||
@ -907,23 +901,17 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
envNVVisibleDevices: gpuID,
|
||||
"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
|
||||
},
|
||||
expectedDevices: &anotherGPUID,
|
||||
expectedDevices: []string{anotherGPUID},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range tests {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
image, _ := image.New(
|
||||
image.WithEnvMap(tc.env),
|
||||
)
|
||||
devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars)
|
||||
if tc.expectedDevices == nil {
|
||||
require.Nil(t, devices, "%d: %v", i, tc)
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, devices, "%d: %v", i, tc)
|
||||
require.Equal(t, *tc.expectedDevices, *devices, "%d: %v", i, tc)
|
||||
require.EqualValues(t, tc.expectedDevices, devices)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -120,8 +120,8 @@ func doPrestart() {
|
||||
if cli.NoCgroups {
|
||||
args = append(args, "--no-cgroups")
|
||||
}
|
||||
if len(nvidia.Devices) > 0 {
|
||||
args = append(args, fmt.Sprintf("--device=%s", nvidia.Devices))
|
||||
if devicesString := strings.Join(nvidia.Devices, ","); len(devicesString) > 0 {
|
||||
args = append(args, fmt.Sprintf("--device=%s", devicesString))
|
||||
}
|
||||
if len(nvidia.MigConfigDevices) > 0 {
|
||||
args = append(args, fmt.Sprintf("--mig-config=%s", nvidia.MigConfigDevices))
|
||||
|
Loading…
Reference in New Issue
Block a user