Merge pull request #438 from elezar/refactor-driver-root

Create root.Driver instance at first usage
This commit is contained in:
Evan Lezar 2024-04-03 15:11:45 +02:00 committed by GitHub
commit 26e52b8013
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 19 additions and 11 deletions

View File

@ -29,16 +29,12 @@ import (
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification. // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made. // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver) (oci.SpecModifier, error) {
if required, reason := requiresGraphicsModifier(image); !required { if required, reason := requiresGraphicsModifier(image); !required {
logger.Infof("No graphics modifier required: %v", reason) logger.Infof("No graphics modifier required: %v", reason)
return nil, nil return nil, nil
} }
driver := root.New(
root.WithLogger(logger),
root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
)
nvidiaCTKPath := cfg.NVIDIACTKConfig.Path nvidiaCTKPath := cfg.NVIDIACTKConfig.Path
mounts, err := discover.NewGraphicsMountsDiscoverer( mounts, err := discover.NewGraphicsMountsDiscoverer(

View File

@ -26,6 +26,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
) )
// Run is an entry point that allows for idiomatic handling of errors // Run is an entry point that allows for idiomatic handling of errors
@ -76,8 +77,13 @@ func (r rt) Run(argv []string) (rerr error) {
r.logger.Infof("Running with config:\n%+v", cfg) r.logger.Infof("Running with config:\n%+v", cfg)
} }
driver := root.New(
root.WithLogger(r.logger),
root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
)
r.logger.Debugf("Command line arguments: %v", argv) r.logger.Debugf("Command line arguments: %v", argv)
runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv) runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv, driver)
if err != nil { if err != nil {
return fmt.Errorf("failed to create NVIDIA Container Runtime: %v", err) return fmt.Errorf("failed to create NVIDIA Container Runtime: %v", err)
} }

View File

@ -23,12 +23,13 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier" "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
) )
// newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger // newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger
func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string) (oci.Runtime, error) { func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string, driver *root.Driver) (oci.Runtime, error) {
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes) lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes)
if err != nil { if err != nil {
return nil, fmt.Errorf("error constructing low-level runtime: %v", err) return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
@ -44,7 +45,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
return nil, fmt.Errorf("error constructing OCI specification: %v", err) return nil, fmt.Errorf("error constructing OCI specification: %v", err)
} }
specModifier, err := newSpecModifier(logger, cfg, ociSpec) specModifier, err := newSpecModifier(logger, cfg, ociSpec, driver)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err) return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
} }
@ -61,7 +62,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
} }
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) {
rawSpec, err := ociSpec.Load() rawSpec, err := ociSpec.Load()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err) return nil, fmt.Errorf("failed to load OCI spec: %v", err)
@ -82,7 +83,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
return modeModifier, nil return modeModifier, nil
} }
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image) graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -29,6 +29,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test" "github.com/NVIDIA/nvidia-container-toolkit/internal/test"
) )
@ -63,6 +64,9 @@ func TestMain(m *testing.M) {
func TestFactoryMethod(t *testing.T) { func TestFactoryMethod(t *testing.T) {
logger, _ := testlog.NewNullLogger() logger, _ := testlog.NewNullLogger()
driver := root.New(
root.WithDriverRoot("/nvidia/driver/root"),
)
testCases := []struct { testCases := []struct {
description string description string
@ -143,6 +147,7 @@ func TestFactoryMethod(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
bundleDir := t.TempDir() bundleDir := t.TempDir()
specFile, err := os.Create(filepath.Join(bundleDir, "config.json")) specFile, err := os.Create(filepath.Join(bundleDir, "config.json"))
@ -151,7 +156,7 @@ func TestFactoryMethod(t *testing.T) {
argv := []string{"--bundle", bundleDir, "create"} argv := []string{"--bundle", bundleDir, "create"}
_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv) _, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv, driver)
if tc.expectedError { if tc.expectedError {
require.Error(t, err) require.Error(t, err)
} else { } else {