From aca0c7bc5a3f76da4e751963ccb52b4edd3a0ca4 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 26 Oct 2022 12:37:23 +0200 Subject: [PATCH] Add Devices abstraction to CUDA image This change adds a Devices abstraction to the CUDA image utilities. This allows for checking whether a devices is selected, for example. Signed-off-by: Evan Lezar --- .../container_config.go | 2 +- internal/config/image/cuda_image.go | 12 +- internal/config/image/devices.go | 125 ++++++++++++++++++ internal/modifier/cdi.go | 2 +- internal/modifier/csv.go | 2 +- internal/modifier/gds.go | 2 +- internal/modifier/graphics.go | 2 +- internal/modifier/mofed.go | 2 +- 8 files changed, 135 insertions(+), 14 deletions(-) create mode 100644 internal/config/image/devices.go diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index 2168c5e2..1956ba54 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -167,7 +167,7 @@ func getDevicesFromEnvvar(image image.CUDA, swarmResourceEnvvars []string) *stri // Build a list of envvars to consider. Note that the Swarm Resource envvars have a higher precedence. envVars := append(swarmResourceEnvvars, envNVVisibleDevices) - devices := image.DevicesFromEnvvars(envVars...) + devices := image.DevicesFromEnvvars(envVars...).List() if len(devices) == 0 { return nil } diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index 097f5d9e..ebd3fdd2 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -114,7 +114,7 @@ func (i CUDA) HasDisableRequire() bool { } // DevicesFromEnvvars returns the devices requested by the image through environment variables -func (i CUDA) DevicesFromEnvvars(envVars ...string) []string { +func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices { // Grab a reference to devices from the first envvar // in the list that actually exists in the environment. var devices *string @@ -127,20 +127,16 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) []string { // Environment variable unset with legacy image: default to "all". if devices == nil && i.IsLegacy() { - return []string{"all"} + return newVisibleDevices("all") } // Environment variable unset or empty or "void": return nil if devices == nil || len(*devices) == 0 || *devices == "void" { - return nil + return newVisibleDevices("void") } // Environment variable set to "none": reset to "". - if *devices == "none" { - return []string{""} - } - - return strings.Split(*devices, ",") + return newVisibleDevices(*devices) } // GetDriverCapabilities returns the requested driver capabilities. diff --git a/internal/config/image/devices.go b/internal/config/image/devices.go new file mode 100644 index 00000000..6f3d00b6 --- /dev/null +++ b/internal/config/image/devices.go @@ -0,0 +1,125 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package image + +import ( + "strings" +) + +// VisibleDevices represents the devices selected in a container image +// through the NVIDIA_VISIBLE_DEVICES or other environment variables +type VisibleDevices interface { + List() []string + Has(string) bool +} + +var _ VisibleDevices = (*all)(nil) +var _ VisibleDevices = (*none)(nil) +var _ VisibleDevices = (*void)(nil) +var _ VisibleDevices = (*devices)(nil) + +// newVisibleDevices creates a VisibleDevices based on the value of the specified envvar. +func newVisibleDevices(envvar string) VisibleDevices { + if envvar == "all" { + return all{} + } + if envvar == "none" { + return none{} + } + if envvar == "" || envvar == "void" { + return void{} + } + + return newDevices(envvar) +} + +type all struct{} + +// List returns ["all"] for all devices +func (a all) List() []string { + return []string{"all"} +} + +// Has for all devices is true for any id except the empty ID +func (a all) Has(id string) bool { + return id != "" +} + +type none struct{} + +// List returns [""] for the none devices +func (n none) List() []string { + return []string{""} +} + +// Has for none devices is false for any id +func (n none) Has(id string) bool { + return false +} + +type void struct { + none +} + +// List returns nil for the void devices +func (v void) List() []string { + return nil +} + +type devices struct { + len int + lookup map[string]int +} + +func newDevices(idOrCommaSeparated ...string) devices { + lookup := make(map[string]int) + + i := 0 + for _, commaSeparated := range idOrCommaSeparated { + for _, id := range strings.Split(commaSeparated, ",") { + lookup[id] = i + i++ + } + } + + d := devices{ + len: i, + lookup: lookup, + } + return d +} + +// List returns the list of requested devices +func (d devices) List() []string { + list := make([]string, d.len) + + for id, i := range d.lookup { + list[i] = id + } + + return list +} + +// Has checks whether the specified ID is in the set of requested devices +func (d devices) Has(id string) bool { + if id == "" { + return false + } + + _, exist := d.lookup[id] + return exist +} diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 96d0c665..cffe2967 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -80,7 +80,7 @@ func getDevicesFromSpec(ociSpec oci.Spec) ([]string, error) { } uniqueDevices := make(map[string]struct{}) - for _, name := range append(envDevices, annotationDevices...) { + for _, name := range append(envDevices.List(), annotationDevices...) { if !cdi.IsQualifiedName(name) { name = cdi.QualifiedName("nvidia.com", "gpu", name) } diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index d2a2f7ad..57f8deea 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -55,7 +55,7 @@ func NewCSVModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec) return nil, err } - if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices) == 0 { + if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil } diff --git a/internal/modifier/gds.go b/internal/modifier/gds.go index a55d2bef..78c0e38a 100644 --- a/internal/modifier/gds.go +++ b/internal/modifier/gds.go @@ -43,7 +43,7 @@ func NewGDSModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec) return nil, err } - if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices) == 0 { + if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil } diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index 722c2ae0..f094d78f 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -40,7 +40,7 @@ func NewGraphicsModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci. return nil, err } - if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices) == 0 { + if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil } diff --git a/internal/modifier/mofed.go b/internal/modifier/mofed.go index 62235e7a..abdf8baf 100644 --- a/internal/modifier/mofed.go +++ b/internal/modifier/mofed.go @@ -43,7 +43,7 @@ func NewMOFEDModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spe return nil, err } - if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices) == 0 { + if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil }