Merge branch 'add-dev-dri' into 'main'

Inject DRM device nodes into containers when Graphics or Display capabilities are requested

See merge request nvidia/container-toolkit/container-toolkit!235
This commit is contained in:
Evan Lezar 2022-11-03 09:31:03 +00:00
commit f7021d84b5
16 changed files with 708 additions and 36 deletions

View File

@ -167,7 +167,7 @@ func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) *stri
// Build a list of envvars to consider. Note that the Swarm Resource envvars have a higher precedence. // Build a list of envvars to consider. Note that the Swarm Resource envvars have a higher precedence.
envVars := append(swarmResourceEnvvars, envNVVisibleDevices) envVars := append(swarmResourceEnvvars, envNVVisibleDevices)
devices := image.DevicesFromEnvvars(envVars...) devices := image.DevicesFromEnvvars(envVars...).List()
if len(devices) == 0 { if len(devices) == 0 {
return nil return nil
} }

View File

@ -74,7 +74,7 @@ func (m command) build() *cli.Command {
}, },
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "link", Name: "link",
Usage: "Specify a specific link to create. The link is specified as source:target", Usage: "Specify a specific link to create. The link is specified as target::link",
Destination: &cfg.links, Destination: &cfg.links,
}, },
&cli.StringFlag{ &cli.StringFlag{
@ -145,7 +145,7 @@ func (m command) run(c *cli.Context, cfg *config) error {
links := cfg.links.Value() links := cfg.links.Value()
for _, l := range links { for _, l := range links {
parts := strings.Split(l, ":") parts := strings.Split(l, "::")
if len(parts) != 2 { if len(parts) != 2 {
m.logger.Warnf("Invalid link specification %v", l) m.logger.Warnf("Invalid link specification %v", l)
continue continue

View File

@ -0,0 +1,54 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package image
// DriverCapability is a single value that may appear in the
// NVIDIA_DRIVER_CAPABILITIES environment variable.
type DriverCapability string

// The driver capabilities that are recognised.
const (
	DriverCapabilityAll      = DriverCapability("all")
	DriverCapabilityCompat32 = DriverCapability("compat32")
	DriverCapabilityCompute  = DriverCapability("compute")
	DriverCapabilityDisplay  = DriverCapability("display")
	DriverCapabilityGraphics = DriverCapability("graphics")
	DriverCapabilityNgx      = DriverCapability("ngx")
	DriverCapabilityUtility  = DriverCapability("utility")
	DriverCapabilityVideo    = DriverCapability("video")
)
// DriverCapabilities is the set of driver capabilities requested through
// NVIDIA_DRIVER_CAPABILITIES. A capability is a member of the set when it
// maps to true.
type DriverCapabilities map[DriverCapability]bool

// Has reports whether the specified capability is selected. Selecting
// DriverCapabilityAll implicitly selects every capability.
func (c DriverCapabilities) Has(capability DriverCapability) bool {
	return c[DriverCapabilityAll] || c[capability]
}

// Any reports whether at least one of the specified capabilities is selected.
// It returns false when called without arguments.
func (c DriverCapabilities) Any(capabilities ...DriverCapability) bool {
	for i := range capabilities {
		if c.Has(capabilities[i]) {
			return true
		}
	}
	return false
}

View File

@ -26,11 +26,12 @@ import (
) )
const ( const (
envCUDAVersion = "CUDA_VERSION" envCUDAVersion = "CUDA_VERSION"
envNVRequirePrefix = "NVIDIA_REQUIRE_" envNVRequirePrefix = "NVIDIA_REQUIRE_"
envNVRequireCUDA = envNVRequirePrefix + "CUDA" envNVRequireCUDA = envNVRequirePrefix + "CUDA"
envNVRequireJetpack = envNVRequirePrefix + "JETPACK" envNVRequireJetpack = envNVRequirePrefix + "JETPACK"
envNVDisableRequire = "NVIDIA_DISABLE_REQUIRE" envNVDisableRequire = "NVIDIA_DISABLE_REQUIRE"
envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES"
) )
// CUDA represents a CUDA image that can be used for GPU computing. This wraps // CUDA represents a CUDA image that can be used for GPU computing. This wraps
@ -113,7 +114,7 @@ func (i CUDA) HasDisableRequire() bool {
} }
// DevicesFromEnvvars returns the devices requested by the image through environment variables // DevicesFromEnvvars returns the devices requested by the image through environment variables
func (i CUDA) DevicesFromEnvvars(envVars ...string) []string { func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
// Grab a reference to devices from the first envvar // Grab a reference to devices from the first envvar
// in the list that actually exists in the environment. // in the list that actually exists in the environment.
var devices *string var devices *string
@ -126,20 +127,28 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) []string {
// Environment variable unset with legacy image: default to "all". // Environment variable unset with legacy image: default to "all".
if devices == nil && i.IsLegacy() { if devices == nil && i.IsLegacy() {
return []string{"all"} return newVisibleDevices("all")
} }
// Environment variable unset or empty or "void": return nil // Environment variable unset or empty or "void": return nil
if devices == nil || len(*devices) == 0 || *devices == "void" { if devices == nil || len(*devices) == 0 || *devices == "void" {
return nil return newVisibleDevices("void")
} }
// Environment variable set to "none": reset to "". // Environment variable set to "none": reset to "".
if *devices == "none" { return newVisibleDevices(*devices)
return []string{""} }
// GetDriverCapabilities returns the requested driver capabilities.
func (i CUDA) GetDriverCapabilities() DriverCapabilities {
env := i[envNVDriverCapabilities]
capabilites := make(DriverCapabilities)
for _, c := range strings.Split(env, ",") {
capabilites[DriverCapability(c)] = true
} }
return strings.Split(*devices, ",") return capabilites
} }
func (i CUDA) legacyVersion() (string, error) { func (i CUDA) legacyVersion() (string, error) {

View File

@ -0,0 +1,125 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package image
import (
"strings"
)
// VisibleDevices describes the devices selected for a container image through
// NVIDIA_VISIBLE_DEVICES or other environment variables.
type VisibleDevices interface {
	// List returns the selection as a slice of strings.
	List() []string
	// Has reports whether the device with the given identifier is selected.
	Has(id string) bool
}

// Compile-time checks that each selection type satisfies VisibleDevices.
var (
	_ VisibleDevices = (*all)(nil)
	_ VisibleDevices = (*none)(nil)
	_ VisibleDevices = (*void)(nil)
	_ VisibleDevices = (*devices)(nil)
)
// newVisibleDevices maps the value of a visible-devices envvar onto the
// matching VisibleDevices implementation: "all", "none", and ""/"void" are
// special values; anything else is treated as a comma-separated list of IDs.
func newVisibleDevices(envvar string) VisibleDevices {
	switch envvar {
	case "all":
		return all{}
	case "none":
		return none{}
	case "", "void":
		return void{}
	default:
		return newDevices(envvar)
	}
}
// all is the selection that matches every device.
type all struct{}

// List returns the single-element list ["all"].
func (all) List() []string {
	list := make([]string, 0, 1)
	list = append(list, "all")
	return list
}

// Has is true for every ID apart from the empty ID.
func (all) Has(id string) bool {
	return len(id) > 0
}
// none is the selection that matches no device.
type none struct{}

// List returns a list holding only the empty string.
func (none) List() []string {
	list := make([]string, 1)
	return list
}

// Has is false regardless of the ID supplied.
func (none) Has(string) bool {
	return false
}
// void is the selection used when no devices were requested at all. It embeds
// none and therefore inherits its Has method; only List differs.
type void struct{ none }

// List returns a nil slice for the void devices.
func (void) List() []string {
	var nothing []string
	return nothing
}
// devices is an explicit selection of device IDs that preserves the order in
// which the IDs were first requested.
type devices struct {
	// len is the number of entries in lookup.
	len int
	// lookup maps each requested ID to its position in the request order.
	lookup map[string]int
}

// newDevices builds a devices selection from one or more arguments, each of
// which is either a single ID or a comma-separated list of IDs.
//
// An ID that occurs more than once is only recorded at its first position.
// Without this, a repeated ID would overwrite its earlier index while still
// consuming a slot, and List would return empty placeholder entries.
func newDevices(idOrCommaSeparated ...string) devices {
	lookup := make(map[string]int)

	for _, commaSeparated := range idOrCommaSeparated {
		for _, id := range strings.Split(commaSeparated, ",") {
			if _, seen := lookup[id]; seen {
				continue
			}
			// The next free position equals the number of IDs recorded so far.
			next := len(lookup)
			lookup[id] = next
		}
	}

	return devices{
		len:    len(lookup),
		lookup: lookup,
	}
}

// List returns the requested device IDs in the order they were first requested.
func (d devices) List() []string {
	list := make([]string, d.len)
	for id, i := range d.lookup {
		list[i] = id
	}
	return list
}

// Has reports whether the specified ID is in the set of requested devices.
// The empty ID is never considered to be requested.
func (d devices) Has(id string) bool {
	if id == "" {
		return false
	}
	_, exist := d.lookup[id]
	return exist
}

View File

@ -0,0 +1,62 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package discover
import "github.com/sirupsen/logrus"
// Filter defines an interface for filtering discovered entities.
type Filter interface {
// DeviceIsSelected reports whether the specified device should be kept.
DeviceIsSelected(device Device) bool
}
// filtered represents a filtered discoverer. It embeds the wrapped Discover so
// that methods which are not overridden on filtered are served by the wrapped
// discoverer directly.
type filtered struct {
Discover
// logger receives debug messages from the filtered discoverer.
logger *logrus.Logger
// filter selects the entities to keep; a nil filter keeps everything.
filter Filter
}
// newFilteredDisoverer wraps the applyTo discoverer so that the entities it
// returns are passed through the specified filter.
func newFilteredDisoverer(logger *logrus.Logger, applyTo Discover, filter Filter) Discover {
	var d filtered
	d.Discover = applyTo
	d.logger = logger
	d.filter = filter
	return d
}
// Devices returns the devices of the wrapped discoverer that are selected by
// the filter. Errors from the wrapped discoverer are returned unchanged, and a
// nil filter leaves the device list untouched.
func (d filtered) Devices() ([]Device, error) {
	devices, err := d.Discover.Devices()
	if err != nil {
		return nil, err
	}

	if d.filter == nil {
		return devices, nil
	}

	var selected []Device
	for _, device := range devices {
		// Only rejected devices are reported as skipped; selected devices
		// must not produce the "skipping" message.
		if !d.filter.DeviceIsSelected(device) {
			d.logger.Debugf("skipping device %v", device)
			continue
		}
		selected = append(selected, device)
	}

	return selected, nil
}

View File

@ -18,13 +18,21 @@ package discover
import ( import (
"fmt" "fmt"
"os"
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/drm"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// NewGraphicsDiscoverer returns the discoverer for graphics tools such as Vulkan. // NewGraphicsDiscoverer returns the discoverer for graphics tools such as Vulkan.
func NewGraphicsDiscoverer(logger *logrus.Logger, root string) (Discover, error) { func NewGraphicsDiscoverer(logger *logrus.Logger, devices image.VisibleDevices, cfg *Config) (Discover, error) {
root := cfg.Root
locator, err := lookup.NewLibraryLocator(logger, root) locator, err := lookup.NewLibraryLocator(logger, root)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct library locator: %v", err) return nil, fmt.Errorf("failed to construct library locator: %v", err)
@ -54,10 +62,193 @@ func NewGraphicsDiscoverer(logger *logrus.Logger, root string) (Discover, error)
}, },
) )
drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, root)
if err != nil {
return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err)
}
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, cfg)
discover := Merge( discover := Merge(
Merge(drmDeviceNodes, drmByPathSymlinks),
libraries, libraries,
jsonMounts, jsonMounts,
) )
return discover, nil return discover, nil
} }
// drmDevicesByPath is a discoverer that contributes a hook for creating the
// /dev/dri/by-path symlinks of the DRM devices returned by devicesFrom.
type drmDevicesByPath struct {
// None is embedded; presumably it supplies empty defaults for the Discover
// methods that are not overridden here — TODO confirm against its definition.
None
logger *logrus.Logger
// lookup is used to locate the NVIDIA Container Toolkit CLI executable.
lookup lookup.Locator
// nvidiaCTKExecutablePath is the configured path (or name) of the CLI.
nvidiaCTKExecutablePath string
// root is the root under which the by-path symlinks are searched for.
root string
// devicesFrom supplies the DRM device nodes for which links are required.
devicesFrom Discover
}
// newCreateDRMByPathSymlinks returns a discoverer whose hook creates the
// by-path symlinks for the DRM devices reported by the devices discoverer.
func newCreateDRMByPathSymlinks(logger *logrus.Logger, devices Discover, cfg *Config) Discover {
	return &drmDevicesByPath{
		logger:                  logger,
		lookup:                  lookup.NewExecutableLocator(logger, cfg.Root),
		nvidiaCTKExecutablePath: cfg.NVIDIAContainerToolkitCLIExecutablePath,
		root:                    cfg.Root,
		devicesFrom:             devices,
	}
}
// Hooks returns a single createContainer hook that invokes the NVIDIA
// Container Toolkit CLI ("hook create-symlinks") with one --link argument per
// by-path symlink of the discovered DRM devices. No hook is returned when no
// devices are discovered.
func (d drmDevicesByPath) Hooks() ([]Hook, error) {
	devices, err := d.devicesFrom.Devices()
	if err != nil {
		return nil, fmt.Errorf("failed to discover devices for by-path symlinks: %v", err)
	}
	if len(devices) == 0 {
		return nil, nil
	}

	// Fall back to the default CLI path unless the executable can be located.
	cliPath := nvidiaCTKDefaultFilePath
	candidates, err := d.lookup.Locate(d.nvidiaCTKExecutablePath)
	switch {
	case err != nil:
		d.logger.Warnf("Failed to locate %v: %v", d.nvidiaCTKExecutablePath, err)
	case len(candidates) == 0:
		d.logger.Warnf("%v not found", d.nvidiaCTKExecutablePath)
	default:
		d.logger.Debugf("Found %v candidates: %v", d.nvidiaCTKExecutablePath, candidates)
		cliPath = candidates[0]
	}
	d.logger.Debugf("Using NVIDIA Container Toolkit CLI path %v", cliPath)

	hookArgs := []string{cliPath, "hook", "create-symlinks"}
	links, err := d.getSpecificLinkArgs(devices)
	if err != nil {
		return nil, fmt.Errorf("failed to determine specific links: %v", err)
	}
	for _, link := range links {
		hookArgs = append(hookArgs, "--link", link)
	}

	hooks := []Hook{
		{
			Lifecycle: cdi.CreateContainerHook,
			Path:      cliPath,
			Args:      hookArgs,
		},
	}
	return hooks, nil
}
// getSpecificLinkArgs returns the specific links that need to be created for
// the given devices. Each entry has the form "target::link", where link is a
// located /dev/dri/by-path symlink and target is the value it points to.
// Only symlinks whose target base name matches the base name of one of the
// devices' host paths are included.
func (d drmDevicesByPath) getSpecificLinkArgs(devices []Device) ([]string, error) {
	selectedNames := make(map[string]bool)
	for _, device := range devices {
		selectedNames[filepath.Base(device.HostPath)] = true
	}

	byPathLocator := lookup.NewFileLocator(d.logger, d.root)
	symlinks, err := byPathLocator.Locate("/dev/dri/by-path/pci-*-*")
	if err != nil {
		return nil, fmt.Errorf("failed to locate devices by path: %v", err)
	}

	var links []string
	for _, symlink := range symlinks {
		target, err := os.Readlink(symlink)
		if err != nil {
			d.logger.Warningf("Failed to evaluate symlink %v; ignoring", symlink)
			continue
		}

		if !selectedNames[filepath.Base(target)] {
			continue
		}
		d.logger.Debugf("adding device symlink %v -> %v", symlink, target)
		links = append(links, fmt.Sprintf("%v::%v", target, symlink))
	}

	return links, nil
}
// newDRMDeviceDiscoverer returns a discoverer for the DRM device nodes
// (/dev/dri/card* and /dev/dri/renderD*) that are associated with the
// requested devices.
func newDRMDeviceDiscoverer(logger *logrus.Logger, devices image.VisibleDevices, root string) (Discover, error) {
	charDevices := lookup.NewCharDeviceLocator(logger, root)
	drmNodePatterns := []string{
		"/dev/dri/card*",
		"/dev/dri/renderD*",
	}
	allDRMNodes := NewDeviceDiscoverer(logger, charDevices, root, drmNodePatterns)

	selection, err := newDRMDeviceFilter(logger, devices, root)
	if err != nil {
		return nil, fmt.Errorf("failed to construct DRM device filter: %v", err)
	}

	// Every discovered DRM device node is passed through the filter so that
	// only the nodes of the requested devices remain.
	return newFilteredDisoverer(logger, allDRMNodes, selection), nil
}
// newDRMDeviceFilter builds a filter that selects the DRM device nodes that
// belong to the visible devices. A GPU is considered visible when its device
// minor, UUID, or bus location (as read from its information file) is in the
// requested set.
func newDRMDeviceFilter(logger *logrus.Logger, devices image.VisibleDevices, root string) (Filter, error) {
	infoFiles, err := proc.GetInformationFilePaths(root)
	if err != nil {
		return nil, fmt.Errorf("failed to read GPU information: %v", err)
	}

	// First pass: collect the bus IDs of all requested GPUs.
	var busIDs []string
	for _, infoFile := range infoFiles {
		gpu, err := proc.ParseGPUInformationFile(infoFile)
		if err != nil {
			return nil, fmt.Errorf("failed to parse %v: %v", infoFile, err)
		}

		minor := gpu[proc.GPUInfoDeviceMinor]
		uuid := gpu[proc.GPUInfoGPUUUID]
		busID := gpu[proc.GPUInfoBusLocation]

		requested := devices.Has(minor) || devices.Has(uuid) || devices.Has(busID)
		if !requested {
			continue
		}
		busIDs = append(busIDs, busID)
	}

	// Second pass: resolve each bus ID to its DRM device nodes.
	selected := make(selectDeviceByPath)
	for _, busID := range busIDs {
		nodes, err := drm.GetDeviceNodesByBusID(busID)
		if err != nil {
			return nil, fmt.Errorf("failed to determine DRM devices for %v: %v", busID, err)
		}
		for _, node := range nodes {
			selected[filepath.Join(node)] = true
		}
	}

	return selected, nil
}
// selectDeviceByPath is a filter that selects devices by path: a device is
// selected when its path maps to true.
type selectDeviceByPath map[string]bool

// Compile-time check that selectDeviceByPath satisfies Filter.
var _ Filter = (*selectDeviceByPath)(nil)

// DeviceIsSelected reports whether the path of the device has been selected.
func (s selectDeviceByPath) DeviceIsSelected(device Device) bool {
	selected, found := s[device.Path]
	return found && selected
}

// MountIsSelected accepts every mount.
func (selectDeviceByPath) MountIsSelected(Mount) bool {
	return true
}

// HookIsSelected accepts every hook.
func (selectDeviceByPath) HookIsSelected(Hook) bool {
	return true
}

View File

@ -117,7 +117,7 @@ func (d symlinks) getSpecificLinkArgs() ([]string, error) {
} }
linkPath := filepath.Join(filepath.Dir(m.Path), link) linkPath := filepath.Join(filepath.Dir(m.Path), link)
links = append(links, "--link", fmt.Sprintf("%v:%v", target, linkPath)) links = append(links, "--link", fmt.Sprintf("%v::%v", target, linkPath))
linkProcessed[link] = true linkProcessed[link] = true
} }

View File

@ -0,0 +1,39 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package drm
import (
"fmt"
"path/filepath"
)
// GetDeviceNodesByBusID returns the /dev/dri device node paths for the PCI
// device with the specified bus ID. One path is produced for every entry
// found under /sys/bus/pci/devices/<busID>/drm, named after that entry.
func GetDeviceNodesByBusID(busID string) ([]string, error) {
	sysfsDRMDir := filepath.Join("/sys/bus/pci/devices", busID, "drm")
	pattern := fmt.Sprintf("%s/*", sysfsDRMDir)

	entries, err := filepath.Glob(pattern)
	if err != nil {
		return nil, err
	}

	var nodes []string
	for i := range entries {
		name := filepath.Base(entries[i])
		nodes = append(nodes, filepath.Join("/dev/dri", name))
	}

	return nodes, nil
}

View File

@ -0,0 +1,89 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package proc
import (
"bufio"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// GPUInfoField is the name of a field in a GPU's information file.
type GPUInfoField string

// The fields of interest from the GPU information file.
const (
	GPUInfoModel       GPUInfoField = "Model"
	GPUInfoGPUUUID     GPUInfoField = "GPU UUID"
	GPUInfoBusLocation GPUInfoField = "Bus Location"
	GPUInfoDeviceMinor GPUInfoField = "Device Minor"
)
// GPUInfo stores the information for a GPU as determined from its associated
// information file, keyed by field name.
type GPUInfo map[GPUInfoField]string
// GetInformationFilePaths returns the paths of the NVIDIA GPU information
// files (proc/driver/nvidia/gpus/*/information) found below the given root.
func GetInformationFilePaths(root string) ([]string, error) {
	pattern := filepath.Join(root, "/proc/driver/nvidia/gpus/*/information")
	return filepath.Glob(pattern)
}
// ParseGPUInformationFile opens the GPU information file at the specified
// path and parses its contents into a GPUInfo. An error is returned when the
// file cannot be opened.
func ParseGPUInformationFile(path string) (GPUInfo, error) {
	f, openErr := os.Open(path)
	if openErr != nil {
		return nil, fmt.Errorf("failed to open %v: %v", path, openErr)
	}
	defer f.Close()

	info := gpuInfoFrom(f)
	return info, nil
}
// gpuInfoFrom reads "field: value" lines from the specified reader into a
// GPUInfo. Each line is split at the first ':'; the text before it is used
// as the field name as-is and the text after it is trimmed of surrounding
// whitespace. Lines without a ':' are ignored.
//
// An information file has the following structure:
//
//	$ cat /proc/driver/nvidia/gpus/0000\:06\:00.0/information
//	Model: Tesla V100-SXM2-16GB
//	IRQ: 408
//	GPU UUID: GPU-edfee158-11c1-52b8-0517-92f30e7fac88
//	Video BIOS: 88.00.41.00.01
//	Bus Type: PCIe
//	DMA Size: 47 bits
//	DMA Mask: 0x7fffffffffff
//	Bus Location: 0000:06:00.0
//	Device Minor: 0
//	GPU Excluded: No
func gpuInfoFrom(reader io.Reader) GPUInfo {
	result := GPUInfo{}

	for lines := bufio.NewScanner(reader); lines.Scan(); {
		fields := strings.SplitN(lines.Text(), ":", 2)
		if len(fields) < 2 {
			continue
		}
		result[GPUInfoField(fields[0])] = strings.TrimSpace(fields[1])
	}

	return result
}

View File

@ -80,7 +80,7 @@ func getDevicesFromSpec(ociSpec oci.Spec) ([]string, error) {
} }
uniqueDevices := make(map[string]struct{}) uniqueDevices := make(map[string]struct{})
for _, name := range append(envDevices, annotationDevices...) { for _, name := range append(envDevices.List(), annotationDevices...) {
if !cdi.IsQualifiedName(name) { if !cdi.IsQualifiedName(name) {
name = cdi.QualifiedName("nvidia.com", "gpu", name) name = cdi.QualifiedName("nvidia.com", "gpu", name)
} }

View File

@ -55,7 +55,7 @@ func NewCSVModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec)
return nil, err return nil, err
} }
if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices) == 0 { if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 {
logger.Infof("No modification required; no devices requested") logger.Infof("No modification required; no devices requested")
return nil, nil return nil, nil
} }

View File

@ -43,7 +43,7 @@ func NewGDSModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec)
return nil, err return nil, err
} }
if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices) == 0 { if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 {
logger.Infof("No modification required; no devices requested") logger.Infof("No modification required; no devices requested")
return nil, nil return nil, nil
} }

View File

@ -18,7 +18,6 @@ package modifier
import ( import (
"fmt" "fmt"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
@ -40,28 +39,36 @@ func NewGraphicsModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.
return nil, err return nil, err
} }
if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices) == 0 { if required, reason := requiresGraphicsModifier(image); !required {
logger.Infof("No modification required; no devices requested") logger.Infof("No graphics modifier required: %v", reason)
return nil, nil return nil, nil
} }
var hasGraphics bool config := &discover.Config{
for _, c := range strings.Split(image["NVIDIA_DRIVER_CAPABILITIES"], ",") { Root: cfg.NVIDIAContainerCLIConfig.Root,
if c == "graphics" || c == "all" { NVIDIAContainerToolkitCLIExecutablePath: cfg.NVIDIACTKConfig.Path,
hasGraphics = true
break
}
} }
d, err := discover.NewGraphicsDiscoverer(
if !hasGraphics { logger,
logger.Debugf("Capability %q not selected", "graphics") image.DevicesFromEnvvars(visibleDevicesEnvvar),
return nil, nil config,
} )
d, err := discover.NewGraphicsDiscoverer(logger, cfg.NVIDIAContainerCLIConfig.Root)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct discoverer: %v", err) return nil, fmt.Errorf("failed to construct discoverer: %v", err)
} }
return NewModifierFromDiscoverer(logger, d) return NewModifierFromDiscoverer(logger, d)
} }
// requiresGraphicsModifier reports whether a graphics modifier is required for
// the specified image. When it is not, the second return value gives the
// reason; otherwise the reason is empty.
func requiresGraphicsModifier(cudaImage image.CUDA) (bool, string) {
	requested := cudaImage.DevicesFromEnvvars(visibleDevicesEnvvar).List()
	if len(requested) == 0 {
		return false, "no devices requested"
	}

	capabilities := cudaImage.GetDriverCapabilities()
	if capabilities.Any(image.DriverCapabilityGraphics, image.DriverCapabilityDisplay) {
		return true, ""
	}

	return false, "no required capabilities requested"
}

View File

@ -0,0 +1,96 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package modifier
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/stretchr/testify/require"
)
// TestGraphicsModifier checks, for a range of images, whether
// requiresGraphicsModifier reports that a graphics modifier is required.
func TestGraphicsModifier(t *testing.T) {
	// allDevicesWith builds an image that requests all devices together with
	// the given NVIDIA_DRIVER_CAPABILITIES value.
	allDevicesWith := func(capabilities string) image.CUDA {
		return image.CUDA{
			"NVIDIA_VISIBLE_DEVICES":     "all",
			"NVIDIA_DRIVER_CAPABILITIES": capabilities,
		}
	}

	type testCase struct {
		description      string
		cudaImage        image.CUDA
		expectedRequired bool
	}

	testCases := []testCase{
		{
			description:      "empty image does not create modifier",
			expectedRequired: false,
		},
		{
			description:      "devices with no capabilities does not create modifier",
			cudaImage:        image.CUDA{"NVIDIA_VISIBLE_DEVICES": "all"},
			expectedRequired: false,
		},
		{
			description:      "devices with no non-graphics does not create modifier",
			cudaImage:        allDevicesWith("compute"),
			expectedRequired: false,
		},
		{
			description:      "devices with all capabilities creates modifier",
			cudaImage:        allDevicesWith("all"),
			expectedRequired: true,
		},
		{
			description:      "devices with graphics capability creates modifier",
			cudaImage:        allDevicesWith("graphics"),
			expectedRequired: true,
		},
		{
			description:      "devices with compute,graphics capability creates modifier",
			cudaImage:        allDevicesWith("compute,graphics"),
			expectedRequired: true,
		},
		{
			description:      "devices with display capability creates modifier",
			cudaImage:        allDevicesWith("display"),
			expectedRequired: true,
		},
		{
			description:      "devices with display,graphics capability creates modifier",
			cudaImage:        allDevicesWith("display,graphics"),
			expectedRequired: true,
		},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.description, func(t *testing.T) {
			got, _ := requiresGraphicsModifier(tc.cudaImage)
			require.EqualValues(t, tc.expectedRequired, got)
		})
	}
}

View File

@ -43,7 +43,7 @@ func NewMOFEDModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spe
return nil, err return nil, err
} }
if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices) == 0 { if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 {
logger.Infof("No modification required; no devices requested") logger.Infof("No modification required; no devices requested")
return nil, nil return nil, nil
} }