diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index 42732bfd..f3562247 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -174,7 +174,7 @@ func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) *stri // if specified. var hasSwarmEnvvar bool for _, envvar := range swarmResourceEnvvars { - if _, exists := image[envvar]; exists { + if image.HasEnvvar(envvar) { hasSwarmEnvvar = true break } @@ -257,28 +257,31 @@ func getDevices(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privil return nil } -func getMigConfigDevices(env map[string]string) *string { - if devices, ok := env[envNVMigConfigDevices]; ok { - return &devices - } - return nil +func getMigConfigDevices(image image.CUDA) *string { + return getMigDevices(image, envNVMigConfigDevices) } -func getMigMonitorDevices(env map[string]string) *string { - if devices, ok := env[envNVMigMonitorDevices]; ok { - return &devices - } - return nil +func getMigMonitorDevices(image image.CUDA) *string { + return getMigDevices(image, envNVMigMonitorDevices) } -func (c *HookConfig) getDriverCapabilities(env map[string]string, legacyImage bool) image.DriverCapabilities { +func getMigDevices(image image.CUDA, envvar string) *string { + if !image.HasEnvvar(envvar) { + return nil + } + devices := image.Getenv(envvar) + return &devices +} + +func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities { // We use the default driver capabilities by default. This is filtered to only include the // supported capabilities supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities) capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities) - capsEnv, capsEnvSpecified := env[envNVDriverCapabilities] + capsEnvSpecified := cudaImage.HasEnvvar(envNVDriverCapabilities) + capsEnv := cudaImage.Getenv(envNVDriverCapabilities) if !capsEnvSpecified && legacyImage { // Environment variable unset with legacy image: set all capabilities. diff --git a/cmd/nvidia-container-runtime-hook/container_config_test.go b/cmd/nvidia-container-runtime-hook/container_config_test.go index 2e5ec98b..2e2bd632 100644 --- a/cmd/nvidia-container-runtime-hook/container_config_test.go +++ b/cmd/nvidia-container-runtime-hook/container_config_test.go @@ -465,6 +465,9 @@ func TestGetNvidiaConfig(t *testing.T) { } for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { + image, _ := image.New( + image.WithEnvMap(tc.env), + ) // Wrap the call to getNvidiaConfig() in a closure. var config *nvidiaConfig getConfig := func() { @@ -473,7 +476,7 @@ func TestGetNvidiaConfig(t *testing.T) { defaultConfig, _ := getDefaultHookConfig() hookConfig = &defaultConfig } - config = getNvidiaConfig(hookConfig, tc.env, nil, tc.privileged) + config = getNvidiaConfig(hookConfig, image, nil, tc.privileged) } // For any tests that are expected to panic, make sure they do. @@ -678,13 +681,17 @@ func TestDeviceListSourcePriority(t *testing.T) { // Wrap the call to getDevices() in a closure. var devices *string getDevices := func() { - env := map[string]string{ - envNVVisibleDevices: tc.envvarDevices, - } + image, _ := image.New( + image.WithEnvMap( + map[string]string{ + envNVVisibleDevices: tc.envvarDevices, + }, + ), + ) hookConfig, _ := getDefaultHookConfig() hookConfig.AcceptEnvvarUnprivileged = tc.acceptUnprivileged hookConfig.AcceptDeviceListAsVolumeMounts = tc.acceptMounts - devices = getDevices(&hookConfig, env, tc.mountDevices, tc.privileged) + devices = getDevices(&hookConfig, image, tc.mountDevices, tc.privileged) } // For all other tests, just grab the devices and check the results @@ -905,7 +912,10 @@ func TestGetDevicesFromEnvvar(t *testing.T) { for i, tc := range tests { t.Run(tc.description, func(t *testing.T) { - devices := getDevicesFromEnvvar(image.CUDA(tc.env), tc.swarmResourceEnvvars) + image, _ := image.New( + image.WithEnvMap(tc.env), + ) + devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars) if tc.expectedDevices == nil { require.Nil(t, devices, "%d: %v", i, tc) return @@ -1021,8 +1031,11 @@ func TestGetDriverCapabilities(t *testing.T) { SupportedDriverCapabilities: tc.supportedCapabilities, } + image, _ := image.New( + image.WithEnvMap(tc.env), + ) getDriverCapabilities := func() { - capabilities = c.getDriverCapabilities(tc.env, tc.legacyImage).String() + capabilities = c.getDriverCapabilities(image, tc.legacyImage).String() } if tc.expectedPanic { diff --git a/internal/config/image/builder.go b/internal/config/image/builder.go index b11f6c74..da9025f0 100644 --- a/internal/config/image/builder.go +++ b/internal/config/image/builder.go @@ -19,10 +19,13 @@ package image import ( "fmt" "strings" + + "github.com/opencontainers/runtime-spec/specs-go" ) type builder struct { - env []string + env map[string]string + mounts []specs.Mount disableRequire bool } @@ -30,7 +33,12 @@ type builder struct { func New(opt ...Option) (CUDA, error) { b := &builder{} for _, o := range opt { - o(b) + if err := o(b); err != nil { + return CUDA{}, err + } + } + if b.env == nil { + b.env = make(map[string]string) } return b.build() @@ -38,36 +46,57 @@ func New(opt ...Option) (CUDA, error) { // build creates a CUDA image from the builder. func (b builder) build() (CUDA, error) { - c := make(CUDA) - - for _, e := range b.env { - parts := strings.SplitN(e, "=", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid environment variable: %v", e) - } - c[parts[0]] = parts[1] - } - if b.disableRequire { - c[envNVDisableRequire] = "true" + b.env[envNVDisableRequire] = "true" } + c := CUDA{ + env: b.env, + mounts: b.mounts, + } return c, nil } // Option is a functional option for creating a CUDA image. -type Option func(*builder) +type Option func(*builder) error // WithDisableRequire sets the disable require option. func WithDisableRequire(disableRequire bool) Option { - return func(b *builder) { + return func(b *builder) error { b.disableRequire = disableRequire + return nil } } // WithEnv sets the environment variables to use when creating the CUDA image. +// Note that this also overwrites the values set with WithEnvMap. func WithEnv(env []string) Option { - return func(b *builder) { - b.env = env + return func(b *builder) error { + envmap := make(map[string]string) + for _, e := range env { + parts := strings.SplitN(e, "=", 2) + if len(parts) != 2 { + return fmt.Errorf("invalid environment variable: %v", e) + } + envmap[parts[0]] = parts[1] + } + return WithEnvMap(envmap)(b) + } +} + +// WithEnvMap sets the environment variable map to use when creating the CUDA image. +// Note that this also overwrites the values set with WithEnv. +func WithEnvMap(env map[string]string) Option { + return func(b *builder) error { + b.env = env + return nil + } +} + +// WithMounts sets the mounts associated with the CUDA image. +func WithMounts(mounts []specs.Mount) Option { + return func(b *builder) error { + b.mounts = mounts + return nil } } diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index ee264001..2356fb31 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -18,9 +18,11 @@ package image import ( "fmt" + "path/filepath" "strconv" "strings" + "github.com/container-orchestrated-devices/container-device-interface/pkg/parser" "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/mod/semver" ) @@ -37,7 +39,10 @@ const ( // CUDA represents a CUDA image that can be used for GPU computing. This wraps // a map of environment variable to values that can be used to perform lookups // such as requirements. -type CUDA map[string]string +type CUDA struct { + env map[string]string + mounts []specs.Mount +} // NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec. // The process environment is read (if present) to construc the CUDA Image. @@ -47,7 +52,10 @@ func NewCUDAImageFromSpec(spec *specs.Spec) (CUDA, error) { env = spec.Process.Env } - return New(WithEnv(env)) + return New( + WithEnv(env), + WithMounts(spec.Mounts), + ) } // NewCUDAImageFromEnv creates a CUDA image from the input environment. The environment @@ -56,12 +64,24 @@ func NewCUDAImageFromEnv(env []string) (CUDA, error) { return New(WithEnv(env)) } +// Getenv returns the value of the specified environment variable. +// If the environment variable is not specified, an empty string is returned. +func (i CUDA) Getenv(key string) string { + return i.env[key] +} + +// HasEnvvar checks whether the specified envvar is defined in the image. +func (i CUDA) HasEnvvar(key string) bool { + _, exists := i.env[key] + return exists +} + // IsLegacy returns whether the associated CUDA image is a "legacy" image. An // image is considered legacy if it has a CUDA_VERSION environment variable defined // and no NVIDIA_REQUIRE_CUDA environment variable defined. func (i CUDA) IsLegacy() bool { - legacyCudaVersion := i[envCUDAVersion] - cudaRequire := i[envNVRequireCUDA] + legacyCudaVersion := i.env[envCUDAVersion] + cudaRequire := i.env[envNVRequireCUDA] return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 } @@ -74,7 +94,7 @@ func (i CUDA) GetRequirements() ([]string, error) { // All variables with the "NVIDIA_REQUIRE_" prefix are passed to nvidia-container-cli var requirements []string - for name, value := range i { + for name, value := range i.env { if strings.HasPrefix(name, envNVRequirePrefix) && !strings.HasPrefix(name, envNVRequireJetpack) { requirements = append(requirements, value) } @@ -93,7 +113,7 @@ func (i CUDA) GetRequirements() ([]string, error) { // HasDisableRequire checks for the value of the NVIDIA_DISABLE_REQUIRE. If set // to a valid (true) boolean value this can be used to disable the requirement checks func (i CUDA) HasDisableRequire() bool { - if disable, exists := i[envNVDisableRequire]; exists { + if disable, exists := i.env[envNVDisableRequire]; exists { // i.logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", disable) d, _ := strconv.ParseBool(disable) return d @@ -104,12 +124,12 @@ func (i CUDA) HasDisableRequire() bool { // DevicesFromEnvvars returns the devices requested by the image through environment variables func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices { - // We concantenate all the devices from the specified envvars. + // We concantenate all the devices from the specified env. var isSet bool var devices []string requested := make(map[string]bool) for _, envVar := range envVars { - if devs, ok := i[envVar]; ok { + if devs, ok := i.env[envVar]; ok { isSet = true for _, d := range strings.Split(devs, ",") { trimmed := strings.TrimSpace(d) @@ -137,7 +157,7 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices { // GetDriverCapabilities returns the requested driver capabilities. func (i CUDA) GetDriverCapabilities() DriverCapabilities { - env := i[envNVDriverCapabilities] + env := i.env[envNVDriverCapabilities] capabilities := make(DriverCapabilities) for _, c := range strings.Split(env, ",") { @@ -148,7 +168,7 @@ func (i CUDA) GetDriverCapabilities() DriverCapabilities { } func (i CUDA) legacyVersion() (string, error) { - cudaVersion := i[envCUDAVersion] + cudaVersion := i.env[envCUDAVersion] majorMinor, err := parseMajorMinorVersion(cudaVersion) if err != nil { return "", fmt.Errorf("invalid CUDA version %v: %v", cudaVersion, err) @@ -178,3 +198,79 @@ func parseMajorMinorVersion(version string) (string, error) { } return majorMinor, nil } + +// OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices/ +func (i CUDA) OnlyFullyQualifiedCDIDevices() bool { + var hasCDIdevice bool + for _, device := range i.DevicesFromEnvvars("NVIDIA_VISIBLE_DEVICES").List() { + if !parser.IsQualifiedName(device) { + return false + } + hasCDIdevice = true + } + + for _, device := range i.DevicesFromMounts() { + if !strings.HasPrefix(device, "cdi/") { + return false + } + hasCDIdevice = true + } + return hasCDIdevice +} + +const ( + deviceListAsVolumeMountsRoot = "/var/run/nvidia-container-devices" +) + +// DevicesFromMounts returns a list of device specified as mounts. +// TODO: This should be merged with getDevicesFromMounts used in the NVIDIA Container Runtime +func (i CUDA) DevicesFromMounts() []string { + root := filepath.Clean(deviceListAsVolumeMountsRoot) + seen := make(map[string]bool) + var devices []string + for _, m := range i.mounts { + source := filepath.Clean(m.Source) + // Only consider mounts who's host volume is /dev/null + if source != "/dev/null" { + continue + } + + destination := filepath.Clean(m.Destination) + if seen[destination] { + continue + } + seen[destination] = true + + // Only consider container mount points that begin with 'root' + if !strings.HasPrefix(destination, root) { + continue + } + + // Grab the full path beyond 'root' and add it to the list of devices + device := strings.Trim(strings.TrimPrefix(destination, root), "/") + if len(device) == 0 { + continue + } + devices = append(devices, device) + } + return devices +} + +// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image. +func (i CUDA) CDIDevicesFromMounts() []string { + var devices []string + for _, mountDevice := range i.DevicesFromMounts() { + if !strings.HasPrefix(mountDevice, "cdi/") { + continue + } + parts := strings.SplitN(strings.TrimPrefix(mountDevice, "cdi/"), "/", 3) + if len(parts) != 3 { + continue + } + vendor := parts[0] + class := parts[1] + device := parts[2] + devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device)) + } + return devices +} diff --git a/internal/config/image/cuda_image_test.go b/internal/config/image/cuda_image_test.go index 371a0ccc..406ad443 100644 --- a/internal/config/image/cuda_image_test.go +++ b/internal/config/image/cuda_image_test.go @@ -126,7 +126,6 @@ func TestGetRequirements(t *testing.T) { requirements, err := image.GetRequirements() require.NoError(t, err) require.ElementsMatch(t, tc.requirements, requirements) - }) } diff --git a/internal/info/auto.go b/internal/info/auto.go index 760d33d9..c089b6c2 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -19,7 +19,6 @@ package info import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - cdi "github.com/container-orchestrated-devices/container-device-interface/pkg/parser" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" @@ -69,7 +68,7 @@ func (r resolver) resolveMode(mode string, image image.CUDA) (rmode string) { r.logger.Infof("Auto-detected mode as '%v'", rmode) }() - if onlyFullyQualifiedCDIDevices(image) { + if image.OnlyFullyQualifiedCDIDevices() { return "cdi" } @@ -88,14 +87,3 @@ func (r resolver) resolveMode(mode string, image image.CUDA) (rmode string) { return "legacy" } - -func onlyFullyQualifiedCDIDevices(image image.CUDA) bool { - var hasCDIdevice bool - for _, device := range image.DevicesFromEnvvars("NVIDIA_VISIBLE_DEVICES").List() { - if !cdi.IsQualifiedName(device) { - return false - } - hasCDIdevice = true - } - return hasCDIdevice -} diff --git a/internal/info/auto_test.go b/internal/info/auto_test.go index f91682f6..fb845d78 100644 --- a/internal/info/auto_test.go +++ b/internal/info/auto_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" + "github.com/opencontainers/runtime-spec/specs-go" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" ) @@ -32,7 +33,8 @@ func TestResolveAutoMode(t *testing.T) { mode string expectedMode string info map[string]bool - image image.CUDA + envmap map[string]string + mounts []string }{ { description: "non-auto resolves to input", @@ -119,7 +121,7 @@ func TestResolveAutoMode(t *testing.T) { description: "cdi devices resolves to cdi", mode: "auto", expectedMode: "cdi", - image: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=all", }, }, @@ -127,14 +129,14 @@ func TestResolveAutoMode(t *testing.T) { description: "multiple cdi devices resolves to cdi", mode: "auto", expectedMode: "cdi", - image: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,nvidia.com/gpu=1", }, }, { description: "at least one non-cdi device resolves to legacy", mode: "auto", - image: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,0", }, info: map[string]bool{ @@ -147,7 +149,7 @@ func TestResolveAutoMode(t *testing.T) { { description: "at least one non-cdi device resolves to csv", mode: "auto", - image: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,0", }, info: map[string]bool{ @@ -157,6 +159,44 @@ func TestResolveAutoMode(t *testing.T) { }, expectedMode: "csv", }, + { + description: "cdi mount devices resolves to CDI", + mode: "auto", + mounts: []string{ + "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0", + }, + expectedMode: "cdi", + }, + { + description: "cdi mount and non-CDI devices resolves to legacy", + mode: "auto", + mounts: []string{ + "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0", + "/var/run/nvidia-container-devices/all", + }, + info: map[string]bool{ + "nvml": true, + "tegra": false, + "nvgpu": false, + }, + expectedMode: "legacy", + }, + { + description: "cdi mount and non-CDI envvar resolves to legacy", + mode: "auto", + envmap: map[string]string{ + "NVIDIA_VISIBLE_DEVICES": "0", + }, + mounts: []string{ + "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0", + }, + info: map[string]bool{ + "nvml": true, + "tegra": false, + "nvgpu": false, + }, + expectedMode: "legacy", + }, } for _, tc := range testCases { @@ -177,7 +217,20 @@ func TestResolveAutoMode(t *testing.T) { logger: logger, info: info, } - mode := r.resolveMode(tc.mode, tc.image) + + var mounts []specs.Mount + for _, d := range tc.mounts { + mount := specs.Mount{ + Source: "/dev/null", + Destination: d, + } + mounts = append(mounts, mount) + } + image, _ := image.New( + image.WithEnvMap(tc.envmap), + image.WithMounts(mounts), + ) + mode := r.resolveMode(tc.mode, image) require.EqualValues(t, tc.expectedMode, mode) }) } diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 7cef3672..89a66b9a 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -67,6 +67,13 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C if err != nil { return nil, err } + if cfg.AcceptDeviceListAsVolumeMounts { + mountDevices := container.CDIDevicesFromMounts() + if len(mountDevices) > 0 { + return mountDevices, nil + } + } + envDevices := container.DevicesFromEnvvars(visibleDevicesEnvvar) var devices []string diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index 53579268..54304428 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -55,7 +55,7 @@ func NewCSVModifier(logger logger.Interface, cfg *config.Config, image image.CUD return nil, fmt.Errorf("failed to get list of CSV files: %v", err) } - if nvidiaRequireJetpack := image[nvidiaRequireJetpackEnvvar]; nvidiaRequireJetpack != "csv-mounts=all" { + if image.Getenv(nvidiaRequireJetpackEnvvar) != "csv-mounts=all" { csvFiles = csv.BaseFilesOnly(csvFiles) } diff --git a/internal/modifier/csv_test.go b/internal/modifier/csv_test.go index 3cd6ce56..32fd097c 100644 --- a/internal/modifier/csv_test.go +++ b/internal/modifier/csv_test.go @@ -32,30 +32,33 @@ func TestNewCSVModifier(t *testing.T) { testCases := []struct { description string cfg *config.Config - image image.CUDA + envmap map[string]string expectedError error expectedNil bool }{ { description: "visible devices not set returns nil", - image: image.CUDA{}, + envmap: map[string]string{}, expectedNil: true, }, { description: "visible devices empty returns nil", - image: image.CUDA{"NVIDIA_VISIBLE_DEVICES": ""}, + envmap: map[string]string{"NVIDIA_VISIBLE_DEVICES": ""}, expectedNil: true, }, { description: "visible devices 'void' returns nil", - image: image.CUDA{"NVIDIA_VISIBLE_DEVICES": "void"}, + envmap: map[string]string{"NVIDIA_VISIBLE_DEVICES": "void"}, expectedNil: true, }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - m, err := NewCSVModifier(logger, tc.cfg, tc.image) + image, _ := image.New( + image.WithEnvMap(tc.envmap), + ) + m, err := NewCSVModifier(logger, tc.cfg, image) if tc.expectedError != nil { require.Error(t, err) } else { diff --git a/internal/modifier/gds.go b/internal/modifier/gds.go index 510051ab..dd03a731 100644 --- a/internal/modifier/gds.go +++ b/internal/modifier/gds.go @@ -38,7 +38,7 @@ func NewGDSModifier(logger logger.Interface, cfg *config.Config, image image.CUD return nil, nil } - if gds := image[nvidiaGDSEnvvar]; gds != "enabled" { + if image.Getenv(nvidiaGDSEnvvar) != "enabled" { return nil, nil } diff --git a/internal/modifier/graphics_test.go b/internal/modifier/graphics_test.go index e062763d..163f3628 100644 --- a/internal/modifier/graphics_test.go +++ b/internal/modifier/graphics_test.go @@ -26,7 +26,7 @@ import ( func TestGraphicsModifier(t *testing.T) { testCases := []struct { description string - cudaImage image.CUDA + envmap map[string]string expectedRequired bool }{ { @@ -34,20 +34,20 @@ func TestGraphicsModifier(t *testing.T) { }, { description: "devices with no capabilities does not create modifier", - cudaImage: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "all", }, }, { description: "devices with no non-graphics does not create modifier", - cudaImage: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "compute", }, }, { description: "devices with all capabilities creates modifier", - cudaImage: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "all", }, @@ -55,7 +55,7 @@ func TestGraphicsModifier(t *testing.T) { }, { description: "devices with graphics capability creates modifier", - cudaImage: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "graphics", }, @@ -63,7 +63,7 @@ func TestGraphicsModifier(t *testing.T) { }, { description: "devices with compute,graphics capability creates modifier", - cudaImage: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "compute,graphics", }, @@ -71,7 +71,7 @@ func TestGraphicsModifier(t *testing.T) { }, { description: "devices with display capability creates modifier", - cudaImage: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "display", }, @@ -79,7 +79,7 @@ func TestGraphicsModifier(t *testing.T) { }, { description: "devices with display,graphics capability creates modifier", - cudaImage: image.CUDA{ + envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_DRIVER_CAPABILITIES": "display,graphics", }, @@ -89,7 +89,10 @@ func TestGraphicsModifier(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - required, _ := requiresGraphicsModifier(tc.cudaImage) + image, _ := image.New( + image.WithEnvMap(tc.envmap), + ) + required, _ := requiresGraphicsModifier(image) require.EqualValues(t, tc.expectedRequired, required) }) } diff --git a/internal/modifier/mofed.go b/internal/modifier/mofed.go index 505c2db0..0a81a3a6 100644 --- a/internal/modifier/mofed.go +++ b/internal/modifier/mofed.go @@ -38,7 +38,7 @@ func NewMOFEDModifier(logger logger.Interface, cfg *config.Config, image image.C return nil, nil } - if mofed := image[nvidiaMOFEDEnvvar]; mofed != "enabled" { + if image.Getenv(nvidiaMOFEDEnvvar) != "enabled" { return nil, nil }