From f17d42424874aa016f68d0726b360ee3d6106407 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 12 Jun 2025 14:38:11 +0200 Subject: [PATCH] Construct container info once Signed-off-by: Evan Lezar --- internal/modifier/cdi.go | 24 ++++-------------------- internal/runtime/runtime_factory.go | 9 ++++++--- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index c1ab1818..6139a3a8 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -33,8 +33,8 @@ import ( // NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the // CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is // used to select the devices to include. -func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - devices, err := getDevicesFromSpec(logger, ociSpec, cfg) +func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { + devices, err := getDevicesFromImage(logger, cfg, image) if err != nil { return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err) } @@ -64,23 +64,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe ) } -func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - container, err := image.NewCUDAImageFromSpec( - rawSpec, - image.WithLogger(logger), - image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts), - image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged), - image.WithAnnotationsPrefixes(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes), - ) - if err != nil { - return nil, err - } - +func getDevicesFromImage(logger logger.Interface, cfg *config.Config, container image.CUDA) ([]string, error) { annotationDevices, err := getAnnotationDevices(container) if err != nil { return nil, fmt.Errorf("failed to parse container annotations: %v", err) @@ -113,7 +97,7 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C return nil, nil } - if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) { + if cfg.AcceptEnvvarUnprivileged || container.IsPrivileged() { return devices, nil } diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index a7b454a9..63d127f7 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -73,6 +73,9 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp image, err := image.NewCUDAImageFromSpec( rawSpec, image.WithLogger(logger), + image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts), + image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged), + image.WithAnnotationsPrefixes(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes), ) if err != nil { return nil, err @@ -83,7 +86,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) // We update the mode here so that we can continue passing just the config to other functions. cfg.NVIDIAContainerRuntimeConfig.Mode = mode - modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, image) + modeModifier, err := newModeModifier(logger, mode, cfg, image) if err != nil { return nil, err } @@ -113,14 +116,14 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp return modifiers, nil } -func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, ociSpec oci.Spec, image image.CUDA) (oci.SpecModifier, error) { +func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { switch mode { case "legacy": return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil case "csv": return modifier.NewCSVModifier(logger, cfg, image) case "cdi": - return modifier.NewCDIModifier(logger, cfg, ociSpec) + return modifier.NewCDIModifier(logger, cfg, image) } return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode)