diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index d7188f9b..76e8dab7 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -20,12 +20,15 @@ import ( "fmt" "strings" + "tags.cncf.io/container-device-interface/pkg/parser" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" - "tags.cncf.io/container-device-interface/pkg/parser" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" ) // NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the @@ -42,6 +45,14 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe } logger.Debugf("Creating CDI modifier for devices: %v", devices) + automaticDevices := filterAutomaticDevices(devices) + if len(automaticDevices) != len(devices) && len(automaticDevices) > 0 { + return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices") + } + if len(automaticDevices) > 0 { + return newAutomaticCDISpecModifier(logger, cfg, automaticDevices) + } + return cdi.New( cdi.WithLogger(logger), cdi.WithDevices(devices...), @@ -133,3 +144,52 @@ func getAnnotationDevices(prefixes []string, annotations map[string]string) ([]s return annotationDevices, nil } + +// filterAutomaticDevices searches for "automatic" device names in the input slice. +// "Automatic" devices are a well-defined list of CDI device names which, when requested, +// trigger the generation of a CDI spec at runtime. This removes the need to generate a +// CDI spec on the system a-priori as well as keep it up-to-date. +func filterAutomaticDevices(devices []string) []string { + var automatic []string + for _, device := range devices { + if device == "runtime.nvidia.com/gpu=all" { + automatic = append(automatic, device) + } + } + return automatic +} + +func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) { + logger.Debugf("Generating in-memory CDI specs for devices %v", devices) + spec, err := generateAutomaticCDISpec(logger, cfg, devices) + if err != nil { + return nil, fmt.Errorf("failed to generate CDI spec: %w", err) + } + cdiModifier, err := cdi.New( + cdi.WithLogger(logger), + cdi.WithSpec(spec.Raw()), + ) + if err != nil { + return nil, fmt.Errorf("failed to construct CDI modifier: %w", err) + } + + return cdiModifier, nil +} + +// TODO: use the requested devices when generating the CDI spec once we add +// automatic CDI generation for more than just the 'runtime.nvidia.com/gpu=all' +// device +func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devices []string) (spec.Interface, error) { + cdilib, err := nvcdi.New( + nvcdi.WithLogger(logger), + nvcdi.WithNVIDIACTKPath(cfg.NVIDIACTKConfig.Path), + nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), + nvcdi.WithVendor("runtime.nvidia.com"), + nvcdi.WithClass("gpu"), + ) + if err != nil { + return nil, fmt.Errorf("failed to construct CDI library: %w", err) + } + + return cdilib.GetSpec() +}