From 6d220ed9a2837fd26ded363daa4cd2f9200a9b77 Mon Sep 17 00:00:00 2001
From: Evan Lezar
Date: Mon, 6 Mar 2023 13:40:21 +0200
Subject: [PATCH] Rework selection of devices in CDI mode

The following changes are made:
* The default-cdi-kind config option is used to convert an envvar entry to a fully-qualified device name
* If annotation devices exist, these are used instead of the envvar devices.
* The `all` device is no longer treated as a special case and MUST exist in the CDI spec.

Signed-off-by: Evan Lezar
---
 internal/modifier/cdi.go | 59 ++++++++++++++++++----------------------
 1 file changed, 26 insertions(+), 33 deletions(-)

diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go
index cffe2967..eb15b4bf 100644
--- a/internal/modifier/cdi.go
+++ b/internal/modifier/cdi.go
@@ -18,7 +18,6 @@ package modifier
 
 import (
 	"fmt"
-	"strings"
 
 	"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
 	"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
@@ -38,7 +37,7 @@ type cdiModifier struct {
 // CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES enviroment variable is
 // used to select the devices to include.
 func NewCDIModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
-	devices, err := getDevicesFromSpec(ociSpec)
+	devices, err := getDevicesFromSpec(logger, ociSpec, cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err)
 	}
@@ -46,6 +45,7 @@ func NewCDIModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec)
 		logger.Debugf("No devices requested; no modification required.")
 		return nil, nil
 	}
+	logger.Debugf("Creating CDI modifier for devices: %v", devices)
 
 	specDirs := cdi.DefaultSpecDirs
 	if len(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs) > 0 {
@@ -61,34 +61,40 @@ func NewCDIModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec)
 	return m, nil
 }
 
-func getDevicesFromSpec(ociSpec oci.Spec) ([]string, error) {
+// getDevicesFromSpec returns the CDI device names requested by the container.
+// Devices from CDI annotations take precedence; otherwise each visible-devices
+// envvar entry is qualified with defaultKind if required and duplicates are dropped.
+func getDevicesFromSpec(logger *logrus.Logger, ociSpec oci.Spec, defaultKind string) ([]string, error) {
 	rawSpec, err := ociSpec.Load()
 	if err != nil {
 		return nil, fmt.Errorf("failed to load OCI spec: %v", err)
 	}
 
+	_, annotationDevices, err := cdi.ParseAnnotations(rawSpec.Annotations)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse container annotations: %v", err)
+	}
+	if len(annotationDevices) > 0 {
+		return annotationDevices, nil
+	}
+
 	image, err := image.NewCUDAImageFromSpec(rawSpec)
 	if err != nil {
 		return nil, err
 	}
-
 	envDevices := image.DevicesFromEnvvars(visibleDevicesEnvvar)
 
-	_, annotationDevices, err := cdi.ParseAnnotations(rawSpec.Annotations)
-	if err != nil {
-		return nil, fmt.Errorf("failed to parse container annotations: %v", err)
-	}
-
-	uniqueDevices := make(map[string]struct{})
-	for _, name := range append(envDevices.List(), annotationDevices...) {
-		if !cdi.IsQualifiedName(name) {
-			name = cdi.QualifiedName("nvidia.com", "gpu", name)
-		}
-		uniqueDevices[name] = struct{}{}
-	}
-
 	var devices []string
-	for name := range uniqueDevices {
+	seen := make(map[string]bool)
+	for _, name := range envDevices.List() {
+		if !cdi.IsQualifiedName(name) {
+			name = fmt.Sprintf("%s=%s", defaultKind, name)
+		}
+		if seen[name] {
+			logger.Debugf("Ignoring duplicate device %q", name)
+			continue
+		}
+		seen[name] = true
 		devices = append(devices, name)
 	}
 
@@ -105,21 +111,8 @@ func (m cdiModifier) Modify(spec *specs.Spec) error {
 		m.logger.Debugf("The following error was triggered when refreshing the CDI registry: %v", err)
 	}
 
-	devices := m.devices
-	for _, d := range devices {
-		if d == "nvidia.com/gpu=all" {
-			devices = []string{}
-			for _, candidate := range registry.DeviceDB().ListDevices() {
-				if strings.HasPrefix(candidate, "nvidia.com/gpu=") {
-					devices = append(devices, candidate)
-				}
-			}
-			break
-		}
-	}
-
-	m.logger.Debugf("Injecting devices using CDI: %v", devices)
-	_, err := registry.InjectDevices(spec, devices...)
+	m.logger.Debugf("Injecting devices using CDI: %v", m.devices)
+	_, err := registry.InjectDevices(spec, m.devices...)
 	if err != nil {
 		return fmt.Errorf("failed to inject CDI devices: %v", err)
 	}