mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-24 21:14:00 +00:00
Create root.Driver instance at first usage
This allows for testing through injection of the driver root. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
413da20838
commit
011c658945
@ -29,16 +29,12 @@ import (
|
|||||||
|
|
||||||
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
|
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
|
||||||
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
|
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
|
||||||
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
|
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver) (oci.SpecModifier, error) {
|
||||||
if required, reason := requiresGraphicsModifier(image); !required {
|
if required, reason := requiresGraphicsModifier(image); !required {
|
||||||
logger.Infof("No graphics modifier required: %v", reason)
|
logger.Infof("No graphics modifier required: %v", reason)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
driver := root.New(
|
|
||||||
root.WithLogger(logger),
|
|
||||||
root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
|
|
||||||
)
|
|
||||||
nvidiaCTKPath := cfg.NVIDIACTKConfig.Path
|
nvidiaCTKPath := cfg.NVIDIACTKConfig.Path
|
||||||
|
|
||||||
mounts, err := discover.NewGraphicsMountsDiscoverer(
|
mounts, err := discover.NewGraphicsMountsDiscoverer(
|
||||||
|
@ -26,6 +26,7 @@ import (
|
|||||||
|
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
|
||||||
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Run is an entry point that allows for idiomatic handling of errors
|
// Run is an entry point that allows for idiomatic handling of errors
|
||||||
@ -76,8 +77,13 @@ func (r rt) Run(argv []string) (rerr error) {
|
|||||||
r.logger.Infof("Running with config:\n%+v", cfg)
|
r.logger.Infof("Running with config:\n%+v", cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
driver := root.New(
|
||||||
|
root.WithLogger(r.logger),
|
||||||
|
root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
|
||||||
|
)
|
||||||
|
|
||||||
r.logger.Debugf("Command line arguments: %v", argv)
|
r.logger.Debugf("Command line arguments: %v", argv)
|
||||||
runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv)
|
runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv, driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create NVIDIA Container Runtime: %v", err)
|
return fmt.Errorf("failed to create NVIDIA Container Runtime: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -23,12 +23,13 @@ import (
|
|||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||||
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier"
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||||
)
|
)
|
||||||
|
|
||||||
// newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger
|
// newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger
|
||||||
func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string) (oci.Runtime, error) {
|
func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string, driver *root.Driver) (oci.Runtime, error) {
|
||||||
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes)
|
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
|
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
|
||||||
@ -44,7 +45,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
|
|||||||
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
|
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
specModifier, err := newSpecModifier(logger, cfg, ociSpec)
|
specModifier, err := newSpecModifier(logger, cfg, ociSpec, driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
|
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
|
||||||
}
|
}
|
||||||
@ -61,7 +62,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
|
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
|
||||||
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
|
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) {
|
||||||
rawSpec, err := ociSpec.Load()
|
rawSpec, err := ociSpec.Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
|
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
|
||||||
@ -82,7 +83,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
|
|||||||
return modeModifier, nil
|
return modeModifier, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image)
|
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -29,6 +29,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||||
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -63,6 +64,9 @@ func TestMain(m *testing.M) {
|
|||||||
|
|
||||||
func TestFactoryMethod(t *testing.T) {
|
func TestFactoryMethod(t *testing.T) {
|
||||||
logger, _ := testlog.NewNullLogger()
|
logger, _ := testlog.NewNullLogger()
|
||||||
|
driver := root.New(
|
||||||
|
root.WithDriverRoot("/nvidia/driver/root"),
|
||||||
|
)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
description string
|
description string
|
||||||
@ -143,6 +147,7 @@ func TestFactoryMethod(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.description, func(t *testing.T) {
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
|
|
||||||
bundleDir := t.TempDir()
|
bundleDir := t.TempDir()
|
||||||
|
|
||||||
specFile, err := os.Create(filepath.Join(bundleDir, "config.json"))
|
specFile, err := os.Create(filepath.Join(bundleDir, "config.json"))
|
||||||
@ -151,7 +156,7 @@ func TestFactoryMethod(t *testing.T) {
|
|||||||
|
|
||||||
argv := []string{"--bundle", bundleDir, "create"}
|
argv := []string{"--bundle", bundleDir, "create"}
|
||||||
|
|
||||||
_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv)
|
_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv, driver)
|
||||||
if tc.expectedError {
|
if tc.expectedError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
Reference in New Issue
Block a user