Construct container info once

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2025-06-12 14:38:11 +02:00
parent 426186c992
commit f17d424248
No known key found for this signature in database
2 changed files with 10 additions and 23 deletions

View File

@ -33,8 +33,8 @@ import (
// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the // NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is // CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
// used to select the devices to include. // used to select the devices to include.
func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
devices, err := getDevicesFromSpec(logger, ociSpec, cfg) devices, err := getDevicesFromImage(logger, cfg, image)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err) return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err)
} }
@ -64,23 +64,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
) )
} }
func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) { func getDevicesFromImage(logger logger.Interface, cfg *config.Config, container image.CUDA) ([]string, error) {
rawSpec, err := ociSpec.Load()
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
}
container, err := image.NewCUDAImageFromSpec(
rawSpec,
image.WithLogger(logger),
image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts),
image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged),
image.WithAnnotationsPrefixes(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes),
)
if err != nil {
return nil, err
}
annotationDevices, err := getAnnotationDevices(container) annotationDevices, err := getAnnotationDevices(container)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse container annotations: %v", err) return nil, fmt.Errorf("failed to parse container annotations: %v", err)
@ -113,7 +97,7 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
return nil, nil return nil, nil
} }
if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) { if cfg.AcceptEnvvarUnprivileged || container.IsPrivileged() {
return devices, nil return devices, nil
} }

View File

@ -73,6 +73,9 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
image, err := image.NewCUDAImageFromSpec( image, err := image.NewCUDAImageFromSpec(
rawSpec, rawSpec,
image.WithLogger(logger), image.WithLogger(logger),
image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts),
image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged),
image.WithAnnotationsPrefixes(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -83,7 +86,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
// We update the mode here so that we can continue passing just the config to other functions. // We update the mode here so that we can continue passing just the config to other functions.
cfg.NVIDIAContainerRuntimeConfig.Mode = mode cfg.NVIDIAContainerRuntimeConfig.Mode = mode
modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, image) modeModifier, err := newModeModifier(logger, mode, cfg, image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -113,14 +116,14 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
return modifiers, nil return modifiers, nil
} }
func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, ociSpec oci.Spec, image image.CUDA) (oci.SpecModifier, error) { func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
switch mode { switch mode {
case "legacy": case "legacy":
return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil
case "csv": case "csv":
return modifier.NewCSVModifier(logger, cfg, image) return modifier.NewCSVModifier(logger, cfg, image)
case "cdi": case "cdi":
return modifier.NewCDIModifier(logger, cfg, ociSpec) return modifier.NewCDIModifier(logger, cfg, image)
} }
return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode) return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode)