diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index 023c1d7b..1f3f2c48 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -29,16 +29,12 @@ import ( // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification. // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made. -func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { +func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver) (oci.SpecModifier, error) { if required, reason := requiresGraphicsModifier(image); !required { logger.Infof("No graphics modifier required: %v", reason) return nil, nil } - driver := root.New( - root.WithLogger(logger), - root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), - ) nvidiaCTKPath := cfg.NVIDIACTKConfig.Path mounts, err := discover.NewGraphicsMountsDiscoverer( diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index c98f9658..4b00b772 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -26,6 +26,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/info" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" ) // Run is an entry point that allows for idiomatic handling of errors @@ -76,8 +77,13 @@ func (r rt) Run(argv []string) (rerr error) { r.logger.Infof("Running with config:\n%+v", cfg) } + driver := root.New( + root.WithLogger(r.logger), + root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), + ) + r.logger.Debugf("Command line arguments: %v", argv) - runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv) + runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv, driver) if err != nil { return fmt.Errorf("failed to create NVIDIA Container Runtime: %v", err) } diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index a8bdbbf6..5bd7983a 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -23,12 +23,13 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) // newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger -func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string) (oci.Runtime, error) { +func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string, driver *root.Driver) (oci.Runtime, error) { lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes) if err != nil { return nil, fmt.Errorf("error constructing low-level runtime: %v", err) @@ -44,7 +45,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv return nil, fmt.Errorf("error constructing OCI specification: %v", err) } - specModifier, err := newSpecModifier(logger, cfg, ociSpec) + specModifier, err := newSpecModifier(logger, cfg, ociSpec, driver) if err != nil { return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err) } @@ -61,7 +62,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv } // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. -func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { +func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) { rawSpec, err := ociSpec.Load() if err != nil { return nil, fmt.Errorf("failed to load OCI spec: %v", err) @@ -82,7 +83,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp return modeModifier, nil } - graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image) + graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver) if err != nil { return nil, err } diff --git a/internal/runtime/runtime_factory_test.go b/internal/runtime/runtime_factory_test.go index 33056fa3..dd052d8f 100644 --- a/internal/runtime/runtime_factory_test.go +++ b/internal/runtime/runtime_factory_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/require" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/test" ) @@ -63,6 +64,9 @@ func TestMain(m *testing.M) { func TestFactoryMethod(t *testing.T) { logger, _ := testlog.NewNullLogger() + driver := root.New( + root.WithDriverRoot("/nvidia/driver/root"), + ) testCases := []struct { description string @@ -143,6 +147,7 @@ func TestFactoryMethod(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { + bundleDir := t.TempDir() specFile, err := os.Create(filepath.Join(bundleDir, "config.json")) @@ -151,7 +156,7 @@ func TestFactoryMethod(t *testing.T) { argv := []string{"--bundle", bundleDir, "create"} - _, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv) + _, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv, driver) if tc.expectedError { require.Error(t, err) } else {