From d92300506c747a53ea0a1d19daf1820d44d63b4f Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 5 Jun 2023 16:01:20 +0200 Subject: [PATCH] Construct CUDA image object once Signed-off-by: Evan Lezar --- internal/modifier/gds.go | 12 +----------- internal/modifier/graphics.go | 12 +----------- internal/modifier/mofed.go | 12 +----------- internal/runtime/runtime_factory.go | 17 ++++++++++++++--- 4 files changed, 17 insertions(+), 36 deletions(-) diff --git a/internal/modifier/gds.go b/internal/modifier/gds.go index 5334346c..9ef15992 100644 --- a/internal/modifier/gds.go +++ b/internal/modifier/gds.go @@ -32,17 +32,7 @@ const ( // NewGDSModifier creates the modifiers for GDS devices. // If the spec does not contain the NVIDIA_GDS=enabled environment variable no changes are made. -func NewGDSModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - image, err := image.NewCUDAImageFromSpec(rawSpec) - if err != nil { - return nil, err - } - +func NewGDSModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index e80de124..57776a72 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -28,17 +28,7 @@ import ( // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification. // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made. -func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - image, err := image.NewCUDAImageFromSpec(rawSpec) - if err != nil { - return nil, err - } - +func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { if required, reason := requiresGraphicsModifier(image); !required { logger.Infof("No graphics modifier required: %v", reason) return nil, nil diff --git a/internal/modifier/mofed.go b/internal/modifier/mofed.go index 92796506..5cc69169 100644 --- a/internal/modifier/mofed.go +++ b/internal/modifier/mofed.go @@ -32,17 +32,7 @@ const ( // NewMOFEDModifier creates the modifiers for MOFED devices. // If the spec does not contain the NVIDIA_MOFED=enabled environment variable no changes are made. -func NewMOFEDModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - image, err := image.NewCUDAImageFromSpec(rawSpec) - if err != nil { - return nil, err - } - +func NewMOFEDModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index dd1a847a..ed32585b 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier" @@ -61,6 +62,16 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) { + rawSpec, err := ociSpec.Load() + if err != nil { + return nil, fmt.Errorf("failed to load OCI spec: %v", err) + } + + image, err := image.NewCUDAImageFromSpec(rawSpec) + if err != nil { + return nil, err + } + mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode) modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, argv) if err != nil { @@ -71,17 +82,17 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp return modeModifier, nil } - graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, ociSpec) + graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image) if err != nil { return nil, err } - gdsModifier, err := modifier.NewGDSModifier(logger, cfg, ociSpec) + gdsModifier, err := modifier.NewGDSModifier(logger, cfg, image) if err != nil { return nil, err } - mofedModifier, err := modifier.NewMOFEDModifier(logger, cfg, ociSpec) + mofedModifier, err := modifier.NewMOFEDModifier(logger, cfg, image) if err != nil { return nil, err }