mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-21 15:57:49 +00:00
Merge pull request #438 from elezar/refactor-driver-root
Create root.Driver instance at first usage
This commit is contained in:
commit
26e52b8013
@ -29,16 +29,12 @@ import (
|
||||
|
||||
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
|
||||
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
|
||||
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
|
||||
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver) (oci.SpecModifier, error) {
|
||||
if required, reason := requiresGraphicsModifier(image); !required {
|
||||
logger.Infof("No graphics modifier required: %v", reason)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
driver := root.New(
|
||||
root.WithLogger(logger),
|
||||
root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
|
||||
)
|
||||
nvidiaCTKPath := cfg.NVIDIACTKConfig.Path
|
||||
|
||||
mounts, err := discover.NewGraphicsMountsDiscoverer(
|
||||
|
@ -26,6 +26,7 @@ import (
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||
)
|
||||
|
||||
// Run is an entry point that allows for idiomatic handling of errors
|
||||
@ -76,8 +77,13 @@ func (r rt) Run(argv []string) (rerr error) {
|
||||
r.logger.Infof("Running with config:\n%+v", cfg)
|
||||
}
|
||||
|
||||
driver := root.New(
|
||||
root.WithLogger(r.logger),
|
||||
root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
|
||||
)
|
||||
|
||||
r.logger.Debugf("Command line arguments: %v", argv)
|
||||
runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv)
|
||||
runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv, driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create NVIDIA Container Runtime: %v", err)
|
||||
}
|
||||
|
@ -23,12 +23,13 @@ import (
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||
)
|
||||
|
||||
// newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger
|
||||
func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string) (oci.Runtime, error) {
|
||||
func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string, driver *root.Driver) (oci.Runtime, error) {
|
||||
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
|
||||
@ -44,7 +45,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
|
||||
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
|
||||
}
|
||||
|
||||
specModifier, err := newSpecModifier(logger, cfg, ociSpec)
|
||||
specModifier, err := newSpecModifier(logger, cfg, ociSpec, driver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
|
||||
}
|
||||
@ -61,7 +62,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
|
||||
}
|
||||
|
||||
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
|
||||
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
|
||||
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) {
|
||||
rawSpec, err := ociSpec.Load()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
|
||||
@ -82,7 +83,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
|
||||
return modeModifier, nil
|
||||
}
|
||||
|
||||
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image)
|
||||
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
|
||||
)
|
||||
|
||||
@ -63,6 +64,9 @@ func TestMain(m *testing.M) {
|
||||
|
||||
func TestFactoryMethod(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
driver := root.New(
|
||||
root.WithDriverRoot("/nvidia/driver/root"),
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
@ -143,6 +147,7 @@ func TestFactoryMethod(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
|
||||
bundleDir := t.TempDir()
|
||||
|
||||
specFile, err := os.Create(filepath.Join(bundleDir, "config.json"))
|
||||
@ -151,7 +156,7 @@ func TestFactoryMethod(t *testing.T) {
|
||||
|
||||
argv := []string{"--bundle", bundleDir, "create"}
|
||||
|
||||
_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv)
|
||||
_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv, driver)
|
||||
if tc.expectedError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user