Create root.Driver instance at first usage

This allows for testing through injection of the driver root.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar
2024-04-03 14:59:30 +02:00
parent 413da20838
commit 011c658945
4 changed files with 19 additions and 11 deletions

View File

@@ -23,12 +23,13 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)
// newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger
func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string) (oci.Runtime, error) {
func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string, driver *root.Driver) (oci.Runtime, error) {
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes)
if err != nil {
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
@@ -44,7 +45,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
}
specModifier, err := newSpecModifier(logger, cfg, ociSpec)
specModifier, err := newSpecModifier(logger, cfg, ociSpec, driver)
if err != nil {
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
}
@@ -61,7 +62,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
}
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) {
rawSpec, err := ociSpec.Load()
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
@@ -82,7 +83,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
return modeModifier, nil
}
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image)
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver)
if err != nil {
return nil, err
}