diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index eddb3950..a4e992c1 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -65,26 +65,12 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - image, err := image.NewCUDAImageFromSpec( - rawSpec, - image.WithLogger(logger), - image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts), - image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged), - image.WithAnnotationsPrefixes(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes), - ) + mode, image, err := initRuntimeModeAndImage(logger, cfg, ociSpec) if err != nil { return nil, err } - mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) - // We update the mode here so that we can continue passing just the config to other functions. - cfg.NVIDIAContainerRuntimeConfig.Mode = mode - modeModifier, err := newModeModifier(logger, mode, cfg, image) + modeModifier, err := newModeModifier(logger, mode, cfg, *image) if err != nil { return nil, err } @@ -98,13 +84,13 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp case "nvidia-hook-remover": modifiers = append(modifiers, modifier.NewNvidiaContainerRuntimeHookRemover(logger)) case "graphics": - graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver, hookCreator) + graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, *image, driver, hookCreator) if err != nil { return nil, err } modifiers = append(modifiers, graphicsModifier) case "feature-gated": - featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image, driver, hookCreator) + featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, *image, driver, hookCreator) if err != nil { return nil, err } @@ -128,6 +114,45 @@ func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, i return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode) } +// initRuntimeModeAndImage constructs an image from the specified OCI runtime +// specification and runtime config. +// The image is also used to determine the runtime mode to apply. +// If a non-CDI mode is detected we ensure that the image does not process +// annotation devices. +func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (string, *image.CUDA, error) { + rawSpec, err := ociSpec.Load() + if err != nil { + return "", nil, fmt.Errorf("failed to load OCI spec: %v", err) + } + + image, err := image.NewCUDAImageFromSpec( + rawSpec, + image.WithLogger(logger), + image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts), + image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged), + image.WithAnnotationsPrefixes(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes), + ) + if err != nil { + return "", nil, err + } + + mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) + // We update the mode here so that we can continue passing just the config to other functions. + cfg.NVIDIAContainerRuntimeConfig.Mode = mode + + if mode == "cdi" || len(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes) == 0 { + return mode, &image, nil + } + + // For non-cdi modes we explicitly set the annotation prefixes to nil and + // call this function again to force a reconstruction of the image. + // Note that since the mode is now explicitly set, we will effectively skip + // the mode resolution. + cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes = nil + + return initRuntimeModeAndImage(logger, cfg, ociSpec) +} + // supportedModifierTypes returns the modifiers supported for a specific runtime mode. func supportedModifierTypes(mode string) []string { switch mode {