Merge pull request #438 from elezar/refactor-driver-root

Create root.Driver instance at first usage
This commit is contained in:
Evan Lezar 2024-04-03 15:11:45 +02:00 committed by GitHub
commit 26e52b8013
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 19 additions and 11 deletions

View File

@ -29,16 +29,12 @@ import (
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver) (oci.SpecModifier, error) {
if required, reason := requiresGraphicsModifier(image); !required {
logger.Infof("No graphics modifier required: %v", reason)
return nil, nil
}
driver := root.New(
root.WithLogger(logger),
root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
)
nvidiaCTKPath := cfg.NVIDIACTKConfig.Path
mounts, err := discover.NewGraphicsMountsDiscoverer(

View File

@ -26,6 +26,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
)
// Run is an entry point that allows for idiomatic handling of errors
@ -76,8 +77,13 @@ func (r rt) Run(argv []string) (rerr error) {
r.logger.Infof("Running with config:\n%+v", cfg)
}
driver := root.New(
root.WithLogger(r.logger),
root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
)
r.logger.Debugf("Command line arguments: %v", argv)
runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv)
runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv, driver)
if err != nil {
return fmt.Errorf("failed to create NVIDIA Container Runtime: %v", err)
}

View File

@ -23,12 +23,13 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)
// newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger
func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string) (oci.Runtime, error) {
func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string, driver *root.Driver) (oci.Runtime, error) {
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes)
if err != nil {
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
@ -44,7 +45,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
}
specModifier, err := newSpecModifier(logger, cfg, ociSpec)
specModifier, err := newSpecModifier(logger, cfg, ociSpec, driver)
if err != nil {
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
}
@ -61,7 +62,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
}
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) {
rawSpec, err := ociSpec.Load()
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
@ -82,7 +83,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
return modeModifier, nil
}
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image)
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver)
if err != nil {
return nil, err
}

View File

@ -29,6 +29,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
)
@ -63,6 +64,9 @@ func TestMain(m *testing.M) {
func TestFactoryMethod(t *testing.T) {
logger, _ := testlog.NewNullLogger()
driver := root.New(
root.WithDriverRoot("/nvidia/driver/root"),
)
testCases := []struct {
description string
@ -143,6 +147,7 @@ func TestFactoryMethod(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
bundleDir := t.TempDir()
specFile, err := os.Create(filepath.Join(bundleDir, "config.json"))
@ -151,7 +156,7 @@ func TestFactoryMethod(t *testing.T) {
argv := []string{"--bundle", bundleDir, "create"}
_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv)
_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv, driver)
if tc.expectedError {
require.Error(t, err)
} else {