Use image.CUDA instead of raw spec for CDI modifier

This minor refactor aligns the construction of the CDI modifier
with other modifiers. It is update to accept an CUDA image / container
and this is used to extract device information instead of re-reading the
raw OCI spec.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2024-02-05 23:33:08 +01:00
parent 60b9e41273
commit ba87830cc0
4 changed files with 39 additions and 37 deletions

View File

@ -24,9 +24,7 @@ import (
)
type builder struct {
annotations map[string]string
env map[string]string
mounts []specs.Mount
CUDA
disableRequire bool
}
@ -102,6 +100,14 @@ func WithEnvMap(env map[string]string) Option {
}
}
// WithIsPrivileged sets whether a container is privileged or not.
func WithIsPrivileged(isPrivileged bool) Option {
return func(b *builder) error {
b.isPrivileged = isPrivileged
return nil
}
}
// WithMounts sets the mounts associated with the CUDA image.
func WithMounts(mounts []specs.Mount) Option {
return func(b *builder) error {

View File

@ -40,9 +40,10 @@ const (
// a map of environment variable to values that can be used to perform lookups
// such as requirements.
type CUDA struct {
annotations map[string]string
env map[string]string
mounts []specs.Mount
annotations map[string]string
env map[string]string
mounts []specs.Mount
isPrivileged bool
}
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
@ -57,6 +58,7 @@ func NewCUDAImageFromSpec(spec *specs.Spec) (CUDA, error) {
WithEnv(env),
WithAnnotations(spec.Annotations),
WithMounts(spec.Mounts),
WithIsPrivileged(IsPrivileged(spec)),
)
}
@ -66,6 +68,11 @@ func NewCUDAImageFromEnv(env []string) (CUDA, error) {
return New(WithEnv(env))
}
// IsPrivileged indicates whether the container was started with elevated privileged.
func (i CUDA) IsPrivileged() bool {
return i.isPrivileged
}
// Getenv returns the value of the specified environment variable.
// If the environment variable is not specified, an empty string is returned.
func (i CUDA) Getenv(key string) string {

View File

@ -34,8 +34,8 @@ import (
// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
// used to select the devices to include.
func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
devices, err := getDevicesFromSpec(logger, ociSpec, cfg)
func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
devices, err := getDevicesFromContainer(logger, cfg, image)
if err != nil {
return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err)
}
@ -65,17 +65,8 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
)
}
func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) {
rawSpec, err := ociSpec.Load()
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
}
container, err := image.NewCUDAImageFromSpec(rawSpec)
if err != nil {
return nil, err
}
annotationDevices, err := container.CDIDevicesFromAnnotations(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes...)
func getDevicesFromContainer(logger logger.Interface, cfg *config.Config, image image.CUDA) ([]string, error) {
annotationDevices, err := image.CDIDevicesFromAnnotations(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes...)
if err != nil {
return nil, fmt.Errorf("failed to parse container annotations: %v", err)
}
@ -83,13 +74,13 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
return annotationDevices, nil
}
if cfg.AcceptDeviceListAsVolumeMounts {
mountDevices := container.CDIDevicesFromMounts()
mountDevices := image.CDIDevicesFromMounts()
if len(mountDevices) > 0 {
return mountDevices, nil
}
}
envDevices := container.DevicesFromEnvvars(visibleDevicesEnvvar)
envDevices := image.DevicesFromEnvvars(visibleDevicesEnvvar)
var devices []string
seen := make(map[string]bool)
@ -108,7 +99,7 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
return nil, nil
}
if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged(rawSpec) {
if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged() {
return devices, nil
}

View File

@ -43,8 +43,16 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
if err != nil {
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
}
rawSpec, err := ociSpec.Load()
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
}
image, err := image.NewCUDAImageFromSpec(rawSpec)
if err != nil {
return nil, err
}
specModifier, err := newSpecModifier(logger, cfg, ociSpec)
specModifier, err := newSpecModifier(logger, cfg, image)
if err != nil {
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
}
@ -61,19 +69,9 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
}
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
rawSpec, err := ociSpec.Load()
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
}
image, err := image.NewCUDAImageFromSpec(rawSpec)
if err != nil {
return nil, err
}
func newSpecModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, image)
modeModifier, err := newModeModifier(logger, mode, cfg, image)
if err != nil {
return nil, err
}
@ -100,14 +98,14 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
return modifiers, nil
}
func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, ociSpec oci.Spec, image image.CUDA) (oci.SpecModifier, error) {
func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
switch mode {
case "legacy":
return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil
case "csv":
return modifier.NewCSVModifier(logger, cfg, image)
case "cdi":
return modifier.NewCDIModifier(logger, cfg, ociSpec)
return modifier.NewCDIModifier(logger, cfg, image)
}
return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode)