mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Create root.Driver instance at first usage
This allows for testing through injection of the driver root. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
		
							parent
							
								
									413da20838
								
							
						
					
					
						commit
						011c658945
					
				| @ -29,16 +29,12 @@ import ( | ||||
| 
 | ||||
| // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
 | ||||
| // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
 | ||||
| func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { | ||||
| func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver) (oci.SpecModifier, error) { | ||||
| 	if required, reason := requiresGraphicsModifier(image); !required { | ||||
| 		logger.Infof("No graphics modifier required: %v", reason) | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	driver := root.New( | ||||
| 		root.WithLogger(logger), | ||||
| 		root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), | ||||
| 	) | ||||
| 	nvidiaCTKPath := cfg.NVIDIACTKConfig.Path | ||||
| 
 | ||||
| 	mounts, err := discover.NewGraphicsMountsDiscoverer( | ||||
|  | ||||
| @ -26,6 +26,7 @@ import ( | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/config" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/info" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" | ||||
| ) | ||||
| 
 | ||||
| // Run is an entry point that allows for idiomatic handling of errors
 | ||||
| @ -76,8 +77,13 @@ func (r rt) Run(argv []string) (rerr error) { | ||||
| 		r.logger.Infof("Running with config:\n%+v", cfg) | ||||
| 	} | ||||
| 
 | ||||
| 	driver := root.New( | ||||
| 		root.WithLogger(r.logger), | ||||
| 		root.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), | ||||
| 	) | ||||
| 
 | ||||
| 	r.logger.Debugf("Command line arguments: %v", argv) | ||||
| 	runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv) | ||||
| 	runtime, err := newNVIDIAContainerRuntime(r.logger, cfg, argv, driver) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to create NVIDIA Container Runtime: %v", err) | ||||
| 	} | ||||
|  | ||||
| @ -23,12 +23,13 @@ import ( | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/info" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" | ||||
| ) | ||||
| 
 | ||||
| // newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger
 | ||||
| func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string) (oci.Runtime, error) { | ||||
| func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv []string, driver *root.Driver) (oci.Runtime, error) { | ||||
| 	lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error constructing low-level runtime: %v", err) | ||||
| @ -44,7 +45,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv | ||||
| 		return nil, fmt.Errorf("error constructing OCI specification: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	specModifier, err := newSpecModifier(logger, cfg, ociSpec) | ||||
| 	specModifier, err := newSpecModifier(logger, cfg, ociSpec, driver) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err) | ||||
| 	} | ||||
| @ -61,7 +62,7 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv | ||||
| } | ||||
| 
 | ||||
| // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
 | ||||
| func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { | ||||
| func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) { | ||||
| 	rawSpec, err := ociSpec.Load() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to load OCI spec: %v", err) | ||||
| @ -82,7 +83,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp | ||||
| 		return modeModifier, nil | ||||
| 	} | ||||
| 
 | ||||
| 	graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image) | ||||
| 	graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| @ -29,6 +29,7 @@ import ( | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/config" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/test" | ||||
| ) | ||||
| 
 | ||||
| @ -63,6 +64,9 @@ func TestMain(m *testing.M) { | ||||
| 
 | ||||
| func TestFactoryMethod(t *testing.T) { | ||||
| 	logger, _ := testlog.NewNullLogger() | ||||
| 	driver := root.New( | ||||
| 		root.WithDriverRoot("/nvidia/driver/root"), | ||||
| 	) | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		description   string | ||||
| @ -143,6 +147,7 @@ func TestFactoryMethod(t *testing.T) { | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.description, func(t *testing.T) { | ||||
| 
 | ||||
| 			bundleDir := t.TempDir() | ||||
| 
 | ||||
| 			specFile, err := os.Create(filepath.Join(bundleDir, "config.json")) | ||||
| @ -151,7 +156,7 @@ func TestFactoryMethod(t *testing.T) { | ||||
| 
 | ||||
| 			argv := []string{"--bundle", bundleDir, "create"} | ||||
| 
 | ||||
| 			_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv) | ||||
| 			_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv, driver) | ||||
| 			if tc.expectedError { | ||||
| 				require.Error(t, err) | ||||
| 			} else { | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user