mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Merge branch 'fix-load-kernel-modules' into 'main'
Split internal system package See merge request nvidia/container-toolkit/container-toolkit!420
This commit is contained in:
		
							parent
							
								
									7c807c2c22
								
							
						
					
					
						commit
						cc06766f25
					
				| @ -6,6 +6,7 @@ | ||||
| * Fix bug causing incorrect nvidia-smi symlink to be created on WSL2 systems with multiple driver roots. | ||||
| * Fix bug when using driver versions that do not include a patch component in their version number. | ||||
| * Skip additional modifications in CDI mode. | ||||
| * Fix loading of kernel modules and creation of device nodes in containerized use cases. | ||||
| 
 | ||||
| * [toolkit-container] Allow same envars for all runtime configs | ||||
| 
 | ||||
|  | ||||
| @ -28,14 +28,14 @@ import ( | ||||
| 
 | ||||
| type allPossible struct { | ||||
| 	logger       *logrus.Logger | ||||
| 	driverRoot   string | ||||
| 	devRoot      string | ||||
| 	deviceMajors devices.Devices | ||||
| 	migCaps      nvcaps.MigCaps | ||||
| } | ||||
| 
 | ||||
| // newAllPossible returns a new allPossible device node lister.
 | ||||
| // This lister lists all possible device nodes for NVIDIA GPUs, control devices, and capability devices.
 | ||||
| func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error) { | ||||
| func newAllPossible(logger *logrus.Logger, devRoot string) (nodeLister, error) { | ||||
| 	deviceMajors, err := devices.GetNVIDIADevices() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed reading device majors: %v", err) | ||||
| @ -61,7 +61,7 @@ func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error | ||||
| 
 | ||||
| 	l := allPossible{ | ||||
| 		logger:       logger, | ||||
| 		driverRoot:   driverRoot, | ||||
| 		devRoot:      devRoot, | ||||
| 		deviceMajors: deviceMajors, | ||||
| 		migCaps:      migCaps, | ||||
| 	} | ||||
| @ -72,7 +72,7 @@ func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error | ||||
| // DeviceNodes returns a list of all possible device nodes for NVIDIA GPUs, control devices, and capability devices.
 | ||||
| func (m allPossible) DeviceNodes() ([]deviceNode, error) { | ||||
| 	gpus, err := nvpci.NewFrom( | ||||
| 		filepath.Join(m.driverRoot, nvpci.PCIDevicesRoot), | ||||
| 		filepath.Join(m.devRoot, nvpci.PCIDevicesRoot), | ||||
| 	).GetGPUs() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to get GPU information: %v", err) | ||||
| @ -80,7 +80,7 @@ func (m allPossible) DeviceNodes() ([]deviceNode, error) { | ||||
| 
 | ||||
| 	count := len(gpus) | ||||
| 	if count == 0 { | ||||
| 		m.logger.Infof("No NVIDIA devices found in %s", m.driverRoot) | ||||
| 		m.logger.Infof("No NVIDIA devices found in %s", m.devRoot) | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| @ -179,7 +179,7 @@ func (m allPossible) newDeviceNode(deviceName devices.Name, path string, minor i | ||||
| 	major, _ := m.deviceMajors.Get(deviceName) | ||||
| 
 | ||||
| 	return deviceNode{ | ||||
| 		path:  filepath.Join(m.driverRoot, path), | ||||
| 		path:  filepath.Join(m.devRoot, path), | ||||
| 		major: uint32(major), | ||||
| 		minor: uint32(minor), | ||||
| 	} | ||||
|  | ||||
| @ -24,7 +24,8 @@ import ( | ||||
| 	"strings" | ||||
| 	"syscall" | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/system" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules" | ||||
| 	"github.com/fsnotify/fsnotify" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/urfave/cli/v2" | ||||
| @ -216,6 +217,7 @@ type linkCreator struct { | ||||
| 	logger            *logrus.Logger | ||||
| 	lister            nodeLister | ||||
| 	driverRoot        string | ||||
| 	devRoot           string | ||||
| 	devCharPath       string | ||||
| 	dryRun            bool | ||||
| 	createAll         bool | ||||
| @ -243,6 +245,9 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) { | ||||
| 	if c.driverRoot == "" { | ||||
| 		c.driverRoot = "/" | ||||
| 	} | ||||
| 	if c.devRoot == "" { | ||||
| 		c.devRoot = "/" | ||||
| 	} | ||||
| 	if c.devCharPath == "" { | ||||
| 		c.devCharPath = defaultDevCharPath | ||||
| 	} | ||||
| @ -252,13 +257,13 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) { | ||||
| 	} | ||||
| 
 | ||||
| 	if c.createAll { | ||||
| 		lister, err := newAllPossible(c.logger, c.driverRoot) | ||||
| 		lister, err := newAllPossible(c.logger, c.devRoot) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to create all possible device lister: %v", err) | ||||
| 		} | ||||
| 		c.lister = lister | ||||
| 	} else { | ||||
| 		c.lister = existing{c.logger, c.driverRoot} | ||||
| 		c.lister = existing{c.logger, c.devRoot} | ||||
| 	} | ||||
| 	return c, nil | ||||
| } | ||||
| @ -268,36 +273,48 @@ func (m linkCreator) setup() error { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	s, err := system.New( | ||||
| 		system.WithLogger(m.logger), | ||||
| 		system.WithDryRun(m.dryRun), | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if m.loadKernelModules { | ||||
| 		if err := s.LoadNVIDIAKernelModules(); err != nil { | ||||
| 		modules := nvmodules.New( | ||||
| 			nvmodules.WithLogger(m.logger), | ||||
| 			nvmodules.WithDryRun(m.dryRun), | ||||
| 			nvmodules.WithRoot(m.driverRoot), | ||||
| 		) | ||||
| 		if err := modules.LoadAll(); err != nil { | ||||
| 			return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if m.createDeviceNodes { | ||||
| 		if err := s.CreateNVIDIAControlDeviceNodesAt(m.driverRoot); err != nil { | ||||
| 		devices, err := nvdevices.New( | ||||
| 			nvdevices.WithLogger(m.logger), | ||||
| 			nvdevices.WithDryRun(m.dryRun), | ||||
| 			nvdevices.WithDevRoot(m.devRoot), | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if err := devices.CreateNVIDIAControlDevices(); err != nil { | ||||
| 			return fmt.Errorf("failed to create NVIDIA device nodes: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // WithDriverRoot sets the driver root path.
 | ||||
| // This is the path in which kernel modules must be loaded.
 | ||||
| func WithDriverRoot(root string) Option { | ||||
| 	return func(c *linkCreator) { | ||||
| 		c.driverRoot = root | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WithDevRoot sets the root path for the /dev directory.
 | ||||
| func WithDevRoot(root string) Option { | ||||
| 	return func(c *linkCreator) { | ||||
| 		c.devRoot = root | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WithDevCharPath sets the path at which the symlinks will be created.
 | ||||
| func WithDevCharPath(path string) Option { | ||||
| 	return func(c *linkCreator) { | ||||
|  | ||||
| @ -31,7 +31,7 @@ type nodeLister interface { | ||||
| 
 | ||||
| type existing struct { | ||||
| 	logger  *logrus.Logger | ||||
| 	driverRoot string | ||||
| 	devRoot string | ||||
| } | ||||
| 
 | ||||
| // DeviceNodes returns a list of NVIDIA device nodes in the specified root.
 | ||||
| @ -39,7 +39,7 @@ type existing struct { | ||||
| func (m existing) DeviceNodes() ([]deviceNode, error) { | ||||
| 	locator := lookup.NewCharDeviceLocator( | ||||
| 		lookup.WithLogger(m.logger), | ||||
| 		lookup.WithRoot(m.driverRoot), | ||||
| 		lookup.WithRoot(m.devRoot), | ||||
| 		lookup.WithOptional(true), | ||||
| 	) | ||||
| 
 | ||||
| @ -54,7 +54,7 @@ func (m existing) DeviceNodes() ([]deviceNode, error) { | ||||
| 	} | ||||
| 
 | ||||
| 	if len(devices) == 0 && len(capDevices) == 0 { | ||||
| 		m.logger.Infof("No NVIDIA devices found in %s", m.driverRoot) | ||||
| 		m.logger.Infof("No NVIDIA devices found in %s", m.devRoot) | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -19,7 +19,8 @@ package createdevicenodes | ||||
| import ( | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/system" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/urfave/cli/v2" | ||||
| ) | ||||
| @ -96,19 +97,29 @@ func (m command) validateFlags(r *cli.Context, opts *options) error { | ||||
| } | ||||
| 
 | ||||
| func (m command) run(c *cli.Context, opts *options) error { | ||||
| 	s, err := system.New( | ||||
| 		system.WithLogger(m.logger), | ||||
| 		system.WithDryRun(opts.dryRun), | ||||
| 		system.WithLoadKernelModules(opts.loadKernelModules), | ||||
| 	if opts.loadKernelModules { | ||||
| 		modules := nvmodules.New( | ||||
| 			nvmodules.WithLogger(m.logger), | ||||
| 			nvmodules.WithDryRun(opts.dryRun), | ||||
| 			nvmodules.WithRoot(opts.driverRoot), | ||||
| 		) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to create library: %v", err) | ||||
| 		if err := modules.LoadAll(); err != nil { | ||||
| 			return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if opts.control { | ||||
| 		devices, err := nvdevices.New( | ||||
| 			nvdevices.WithLogger(m.logger), | ||||
| 			nvdevices.WithDryRun(opts.dryRun), | ||||
| 			nvdevices.WithDevRoot(opts.driverRoot), | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		m.logger.Infof("Creating control device nodes at %s", opts.driverRoot) | ||||
| 		if err := s.CreateNVIDIAControlDeviceNodesAt(opts.driverRoot); err != nil { | ||||
| 			return fmt.Errorf("failed to create control device nodes: %v", err) | ||||
| 		if err := devices.CreateNVIDIAControlDevices(); err != nil { | ||||
| 			return fmt.Errorf("failed to create NVIDIA control device nodes: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
|  | ||||
							
								
								
									
										154
									
								
								internal/system/nvdevices/devices.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								internal/system/nvdevices/devices.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,154 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATIOm.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package nvdevices | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
| 
 | ||||
| var errInvalidDeviceNode = errors.New("invalid device node") | ||||
| 
 | ||||
| // Interface provides a set of utilities for interacting with NVIDIA devices on the system.
 | ||||
| type Interface struct { | ||||
| 	devices.Devices | ||||
| 
 | ||||
| 	logger *logrus.Logger | ||||
| 
 | ||||
| 	dryRun bool | ||||
| 	// devRoot is the root directory where device nodes are expected to exist.
 | ||||
| 	devRoot string | ||||
| 
 | ||||
| 	mknoder | ||||
| } | ||||
| 
 | ||||
| // New constructs a new Interface struct with the specified options.
 | ||||
| func New(opts ...Option) (*Interface, error) { | ||||
| 	i := &Interface{} | ||||
| 	for _, opt := range opts { | ||||
| 		opt(i) | ||||
| 	} | ||||
| 
 | ||||
| 	if i.logger == nil { | ||||
| 		i.logger = logrus.StandardLogger() | ||||
| 	} | ||||
| 	if i.devRoot == "" { | ||||
| 		i.devRoot = "/" | ||||
| 	} | ||||
| 	if i.Devices == nil { | ||||
| 		devices, err := devices.GetNVIDIADevices() | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to create devices info: %v", err) | ||||
| 		} | ||||
| 		i.Devices = devices | ||||
| 	} | ||||
| 
 | ||||
| 	if i.dryRun { | ||||
| 		i.mknoder = &mknodLogger{i.logger} | ||||
| 	} else { | ||||
| 		i.mknoder = &mknodUnix{} | ||||
| 	} | ||||
| 	return i, nil | ||||
| } | ||||
| 
 | ||||
| // CreateNVIDIAControlDevices creates the NVIDIA control device nodes at the configured devRoot.
 | ||||
| func (m *Interface) CreateNVIDIAControlDevices() error { | ||||
| 	controlNodes := []string{"nvidiactl", "nvidia-modeset", "nvidia-uvm", "nvidia-uvm-tools"} | ||||
| 	for _, node := range controlNodes { | ||||
| 		err := m.CreateNVIDIADevice(node) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("failed to create device node %s: %w", node, err) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // CreateNVIDIADevice creates the specified NVIDIA device node at the configured devRoot.
 | ||||
| func (m *Interface) CreateNVIDIADevice(node string) error { | ||||
| 	node = filepath.Base(node) | ||||
| 	if !strings.HasPrefix(node, "nvidia") { | ||||
| 		return fmt.Errorf("invalid device node %q: %w", node, errInvalidDeviceNode) | ||||
| 	} | ||||
| 
 | ||||
| 	major, err := m.Major(node) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to determine major: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	minor, err := m.Minor(node) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to determine minor: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return m.createDeviceNode(filepath.Join("dev", node), int(major), int(minor)) | ||||
| } | ||||
| 
 | ||||
| // createDeviceNode creates the specified device node with the require major and minor numbers.
 | ||||
| // If a devRoot is configured, this is prepended to the path.
 | ||||
| func (m *Interface) createDeviceNode(path string, major int, minor int) error { | ||||
| 	path = filepath.Join(m.devRoot, path) | ||||
| 	if _, err := os.Stat(path); err == nil { | ||||
| 		m.logger.Infof("Skipping: %s already exists", path) | ||||
| 		return nil | ||||
| 	} else if !os.IsNotExist(err) { | ||||
| 		return fmt.Errorf("failed to stat %s: %v", path, err) | ||||
| 	} | ||||
| 
 | ||||
| 	return m.Mknode(path, major, minor) | ||||
| } | ||||
| 
 | ||||
| // Major returns the major number for the specified NVIDIA device node.
 | ||||
| // If the device node is not supported, an error is returned.
 | ||||
| func (m *Interface) Major(node string) (int64, error) { | ||||
| 	var valid bool | ||||
| 	var major devices.Major | ||||
| 	switch node { | ||||
| 	case "nvidia-uvm", "nvidia-uvm-tools": | ||||
| 		major, valid = m.Get(devices.NVIDIAUVM) | ||||
| 	case "nvidia-modeset", "nvidiactl": | ||||
| 		major, valid = m.Get(devices.NVIDIAGPU) | ||||
| 	} | ||||
| 
 | ||||
| 	if valid { | ||||
| 		return int64(major), nil | ||||
| 	} | ||||
| 
 | ||||
| 	return 0, errInvalidDeviceNode | ||||
| } | ||||
| 
 | ||||
| // Minor returns the minor number for the specified NVIDIA device node.
 | ||||
| // If the device node is not supported, an error is returned.
 | ||||
| func (m *Interface) Minor(node string) (int64, error) { | ||||
| 	switch node { | ||||
| 	case "nvidia-modeset": | ||||
| 		return devices.NVIDIAModesetMinor, nil | ||||
| 	case "nvidia-uvm-tools": | ||||
| 		return devices.NVIDIAUVMToolsMinor, nil | ||||
| 	case "nvidia-uvm": | ||||
| 		return devices.NVIDIAUVMMinor, nil | ||||
| 	case "nvidiactl": | ||||
| 		return devices.NVIDIACTLMinor, nil | ||||
| 	} | ||||
| 
 | ||||
| 	return 0, errInvalidDeviceNode | ||||
| } | ||||
							
								
								
									
										133
									
								
								internal/system/nvdevices/devices_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								internal/system/nvdevices/devices_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,133 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package nvdevices | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" | ||||
| 	testlog "github.com/sirupsen/logrus/hooks/test" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| func TestCreateControlDevices(t *testing.T) { | ||||
| 	logger, _ := testlog.NewNullLogger() | ||||
| 
 | ||||
| 	nvidiaDevices := &devices.DevicesMock{ | ||||
| 		GetFunc: func(name devices.Name) (devices.Major, bool) { | ||||
| 			devices := map[devices.Name]devices.Major{ | ||||
| 				"nvidia-frontend": 195, | ||||
| 				"nvidia-uvm":      243, | ||||
| 			} | ||||
| 			return devices[name], true | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	mknodeError := errors.New("mknode error") | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		description   string | ||||
| 		root          string | ||||
| 		devices       devices.Devices | ||||
| 		mknodeError   error | ||||
| 		expectedError error | ||||
| 		expectedCalls []struct { | ||||
| 			S  string | ||||
| 			N1 int | ||||
| 			N2 int | ||||
| 		} | ||||
| 	}{ | ||||
| 		{ | ||||
| 			description: "no root specified", | ||||
| 			root:        "", | ||||
| 			devices:     nvidiaDevices, | ||||
| 			mknodeError: nil, | ||||
| 			expectedCalls: []struct { | ||||
| 				S  string | ||||
| 				N1 int | ||||
| 				N2 int | ||||
| 			}{ | ||||
| 				{"/dev/nvidiactl", 195, 255}, | ||||
| 				{"/dev/nvidia-modeset", 195, 254}, | ||||
| 				{"/dev/nvidia-uvm", 243, 0}, | ||||
| 				{"/dev/nvidia-uvm-tools", 243, 1}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "some root specified", | ||||
| 			root:        "/some/root", | ||||
| 			devices:     nvidiaDevices, | ||||
| 			mknodeError: nil, | ||||
| 			expectedCalls: []struct { | ||||
| 				S  string | ||||
| 				N1 int | ||||
| 				N2 int | ||||
| 			}{ | ||||
| 				{"/some/root/dev/nvidiactl", 195, 255}, | ||||
| 				{"/some/root/dev/nvidia-modeset", 195, 254}, | ||||
| 				{"/some/root/dev/nvidia-uvm", 243, 0}, | ||||
| 				{"/some/root/dev/nvidia-uvm-tools", 243, 1}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description:   "mknod error returns error", | ||||
| 			devices:       nvidiaDevices, | ||||
| 			mknodeError:   mknodeError, | ||||
| 			expectedError: mknodeError, | ||||
| 			// We expect the first call to this to fail, and the rest to be skipped
 | ||||
| 			expectedCalls: []struct { | ||||
| 				S  string | ||||
| 				N1 int | ||||
| 				N2 int | ||||
| 			}{ | ||||
| 				{"/dev/nvidiactl", 195, 255}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "missing major returns error", | ||||
| 			devices: &devices.DevicesMock{ | ||||
| 				GetFunc: func(name devices.Name) (devices.Major, bool) { | ||||
| 					return 0, false | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedError: errInvalidDeviceNode, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.description, func(t *testing.T) { | ||||
| 			mknode := &mknoderMock{ | ||||
| 				MknodeFunc: func(string, int, int) error { | ||||
| 					return tc.mknodeError | ||||
| 				}, | ||||
| 			} | ||||
| 
 | ||||
| 			d, _ := New( | ||||
| 				WithLogger(logger), | ||||
| 				WithDevRoot(tc.root), | ||||
| 				WithDevices(tc.devices), | ||||
| 			) | ||||
| 			d.mknoder = mknode | ||||
| 
 | ||||
| 			err := d.CreateNVIDIAControlDevices() | ||||
| 			require.ErrorIs(t, err, tc.expectedError) | ||||
| 			require.EqualValues(t, tc.expectedCalls, mknode.MknodeCalls()) | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
							
								
								
									
										46
									
								
								internal/system/nvdevices/mknod.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								internal/system/nvdevices/mknod.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,46 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package nvdevices | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"golang.org/x/sys/unix" | ||||
| ) | ||||
| 
 | ||||
| //go:generate moq -stub -out mknod_mock.go . mknoder
 | ||||
| type mknoder interface { | ||||
| 	Mknode(string, int, int) error | ||||
| } | ||||
| 
 | ||||
| type mknodLogger struct { | ||||
| 	*logrus.Logger | ||||
| } | ||||
| 
 | ||||
| func (m *mknodLogger) Mknode(path string, major, minor int) error { | ||||
| 	m.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type mknodUnix struct{} | ||||
| 
 | ||||
| func (m *mknodUnix) Mknode(path string, major, minor int) error { | ||||
| 	err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(uint32(major), uint32(minor)))) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return unix.Chmod(path, 0666) | ||||
| } | ||||
							
								
								
									
										89
									
								
								internal/system/nvdevices/mknod_mock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								internal/system/nvdevices/mknod_mock.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,89 @@ | ||||
| // Code generated by moq; DO NOT EDIT.
 | ||||
| // github.com/matryer/moq
 | ||||
| 
 | ||||
| package nvdevices | ||||
| 
 | ||||
| import ( | ||||
| 	"sync" | ||||
| ) | ||||
| 
 | ||||
| // Ensure, that mknoderMock does implement mknoder.
 | ||||
| // If this is not the case, regenerate this file with moq.
 | ||||
| var _ mknoder = &mknoderMock{} | ||||
| 
 | ||||
| // mknoderMock is a mock implementation of mknoder.
 | ||||
| //
 | ||||
| //	func TestSomethingThatUsesmknoder(t *testing.T) {
 | ||||
| //
 | ||||
| //		// make and configure a mocked mknoder
 | ||||
| //		mockedmknoder := &mknoderMock{
 | ||||
| //			MknodeFunc: func(s string, n1 int, n2 int) error {
 | ||||
| //				panic("mock out the Mknode method")
 | ||||
| //			},
 | ||||
| //		}
 | ||||
| //
 | ||||
| //		// use mockedmknoder in code that requires mknoder
 | ||||
| //		// and then make assertions.
 | ||||
| //
 | ||||
| //	}
 | ||||
| type mknoderMock struct { | ||||
| 	// MknodeFunc mocks the Mknode method.
 | ||||
| 	MknodeFunc func(s string, n1 int, n2 int) error | ||||
| 
 | ||||
| 	// calls tracks calls to the methods.
 | ||||
| 	calls struct { | ||||
| 		// Mknode holds details about calls to the Mknode method.
 | ||||
| 		Mknode []struct { | ||||
| 			// S is the s argument value.
 | ||||
| 			S string | ||||
| 			// N1 is the n1 argument value.
 | ||||
| 			N1 int | ||||
| 			// N2 is the n2 argument value.
 | ||||
| 			N2 int | ||||
| 		} | ||||
| 	} | ||||
| 	lockMknode sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| // Mknode calls MknodeFunc.
 | ||||
| func (mock *mknoderMock) Mknode(s string, n1 int, n2 int) error { | ||||
| 	callInfo := struct { | ||||
| 		S  string | ||||
| 		N1 int | ||||
| 		N2 int | ||||
| 	}{ | ||||
| 		S:  s, | ||||
| 		N1: n1, | ||||
| 		N2: n2, | ||||
| 	} | ||||
| 	mock.lockMknode.Lock() | ||||
| 	mock.calls.Mknode = append(mock.calls.Mknode, callInfo) | ||||
| 	mock.lockMknode.Unlock() | ||||
| 	if mock.MknodeFunc == nil { | ||||
| 		var ( | ||||
| 			errOut error | ||||
| 		) | ||||
| 		return errOut | ||||
| 	} | ||||
| 	return mock.MknodeFunc(s, n1, n2) | ||||
| } | ||||
| 
 | ||||
| // MknodeCalls gets all the calls that were made to Mknode.
 | ||||
| // Check the length with:
 | ||||
| //
 | ||||
| //	len(mockedmknoder.MknodeCalls())
 | ||||
| func (mock *mknoderMock) MknodeCalls() []struct { | ||||
| 	S  string | ||||
| 	N1 int | ||||
| 	N2 int | ||||
| } { | ||||
| 	var calls []struct { | ||||
| 		S  string | ||||
| 		N1 int | ||||
| 		N2 int | ||||
| 	} | ||||
| 	mock.lockMknode.RLock() | ||||
| 	calls = mock.calls.Mknode | ||||
| 	mock.lockMknode.RUnlock() | ||||
| 	return calls | ||||
| } | ||||
							
								
								
									
										53
									
								
								internal/system/nvdevices/options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								internal/system/nvdevices/options.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package nvdevices | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
| 
 | ||||
| // Option is a function that sets an option on the Interface struct.
 | ||||
| type Option func(*Interface) | ||||
| 
 | ||||
| // WithDryRun sets the dry run option for the Interface struct.
 | ||||
| func WithDryRun(dryRun bool) Option { | ||||
| 	return func(i *Interface) { | ||||
| 		i.dryRun = dryRun | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WithLogger sets the logger for the Interface struct.
 | ||||
| func WithLogger(logger *logrus.Logger) Option { | ||||
| 	return func(i *Interface) { | ||||
| 		i.logger = logger | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WithDevRoot sets the root directory for the NVIDIA device nodes.
 | ||||
| func WithDevRoot(devRoot string) Option { | ||||
| 	return func(i *Interface) { | ||||
| 		i.devRoot = devRoot | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WithDevices sets the devices for the Interface struct.
 | ||||
| func WithDevices(devices devices.Devices) Option { | ||||
| 	return func(i *Interface) { | ||||
| 		i.Devices = devices | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										49
									
								
								internal/system/nvmodules/cmd.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								internal/system/nvmodules/cmd.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,49 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package nvmodules | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os/exec" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
| 
 | ||||
| //go:generate moq -stub -out cmd_mock.go . cmder
 | ||||
| type cmder interface { | ||||
| 	// Run executes the command and returns the stdout, stderr, and an error if any
 | ||||
| 	Run(string, ...string) error | ||||
| } | ||||
| 
 | ||||
| type cmderLogger struct { | ||||
| 	*logrus.Logger | ||||
| } | ||||
| 
 | ||||
| func (c *cmderLogger) Run(cmd string, args ...string) error { | ||||
| 	c.Infof("Running: %v %v", cmd, strings.Join(args, " ")) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type cmderExec struct{} | ||||
| 
 | ||||
| func (c *cmderExec) Run(cmd string, args ...string) error { | ||||
| 	if output, err := exec.Command(cmd, args...).CombinedOutput(); err != nil { | ||||
| 		return fmt.Errorf("%w; output=%v", err, string(output)) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										83
									
								
								internal/system/nvmodules/cmd_mock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								internal/system/nvmodules/cmd_mock.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,83 @@ | ||||
| // Code generated by moq; DO NOT EDIT.
 | ||||
| // github.com/matryer/moq
 | ||||
| 
 | ||||
| package nvmodules | ||||
| 
 | ||||
| import ( | ||||
| 	"sync" | ||||
| ) | ||||
| 
 | ||||
| // Ensure, that cmderMock does implement cmder.
 | ||||
| // If this is not the case, regenerate this file with moq.
 | ||||
| var _ cmder = &cmderMock{} | ||||
| 
 | ||||
| // cmderMock is a mock implementation of cmder.
 | ||||
| //
 | ||||
| //	func TestSomethingThatUsescmder(t *testing.T) {
 | ||||
| //
 | ||||
| //		// make and configure a mocked cmder
 | ||||
| //		mockedcmder := &cmderMock{
 | ||||
| //			RunFunc: func(s string, strings ...string) error {
 | ||||
| //				panic("mock out the Run method")
 | ||||
| //			},
 | ||||
| //		}
 | ||||
| //
 | ||||
| //		// use mockedcmder in code that requires cmder
 | ||||
| //		// and then make assertions.
 | ||||
| //
 | ||||
| //	}
 | ||||
| type cmderMock struct { | ||||
| 	// RunFunc mocks the Run method.
 | ||||
| 	RunFunc func(s string, strings ...string) error | ||||
| 
 | ||||
| 	// calls tracks calls to the methods.
 | ||||
| 	calls struct { | ||||
| 		// Run holds details about calls to the Run method.
 | ||||
| 		Run []struct { | ||||
| 			// S is the s argument value.
 | ||||
| 			S string | ||||
| 			// Strings is the strings argument value.
 | ||||
| 			Strings []string | ||||
| 		} | ||||
| 	} | ||||
| 	lockRun sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| // Run calls RunFunc.
 | ||||
| func (mock *cmderMock) Run(s string, strings ...string) error { | ||||
| 	callInfo := struct { | ||||
| 		S       string | ||||
| 		Strings []string | ||||
| 	}{ | ||||
| 		S:       s, | ||||
| 		Strings: strings, | ||||
| 	} | ||||
| 	mock.lockRun.Lock() | ||||
| 	mock.calls.Run = append(mock.calls.Run, callInfo) | ||||
| 	mock.lockRun.Unlock() | ||||
| 	if mock.RunFunc == nil { | ||||
| 		var ( | ||||
| 			errOut error | ||||
| 		) | ||||
| 		return errOut | ||||
| 	} | ||||
| 	return mock.RunFunc(s, strings...) | ||||
| } | ||||
| 
 | ||||
| // RunCalls gets all the calls that were made to Run.
 | ||||
| // Check the length with:
 | ||||
| //
 | ||||
| //	len(mockedcmder.RunCalls())
 | ||||
| func (mock *cmderMock) RunCalls() []struct { | ||||
| 	S       string | ||||
| 	Strings []string | ||||
| } { | ||||
| 	var calls []struct { | ||||
| 		S       string | ||||
| 		Strings []string | ||||
| 	} | ||||
| 	mock.lockRun.RLock() | ||||
| 	calls = mock.calls.Run | ||||
| 	mock.lockRun.RUnlock() | ||||
| 	return calls | ||||
| } | ||||
							
								
								
									
										93
									
								
								internal/system/nvmodules/modules.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								internal/system/nvmodules/modules.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,93 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package nvmodules | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
| 
 | ||||
| // Interface provides a set of utilities for interacting with NVIDIA modules on the system.
 | ||||
| type Interface struct { | ||||
| 	logger *logrus.Logger | ||||
| 
 | ||||
| 	dryRun bool | ||||
| 	root   string | ||||
| 
 | ||||
| 	cmder | ||||
| } | ||||
| 
 | ||||
| // New constructs a new Interface struct with the specified options.
 | ||||
| func New(opts ...Option) *Interface { | ||||
| 	m := &Interface{} | ||||
| 	for _, opt := range opts { | ||||
| 		opt(m) | ||||
| 	} | ||||
| 
 | ||||
| 	if m.logger == nil { | ||||
| 		m.logger = logrus.StandardLogger() | ||||
| 	} | ||||
| 	if m.root == "" { | ||||
| 		m.root = "/" | ||||
| 	} | ||||
| 
 | ||||
| 	if m.dryRun { | ||||
| 		m.cmder = &cmderLogger{m.logger} | ||||
| 	} else { | ||||
| 		m.cmder = &cmderExec{} | ||||
| 	} | ||||
| 	return m | ||||
| } | ||||
| 
 | ||||
| // LoadAll loads all the NVIDIA kernel modules.
 | ||||
| func (m *Interface) LoadAll() error { | ||||
| 	modules := []string{"nvidia", "nvidia-uvm", "nvidia-modeset"} | ||||
| 
 | ||||
| 	for _, module := range modules { | ||||
| 		if err := m.Load(module); err != nil { | ||||
| 			return fmt.Errorf("failed to load module %s: %w", module, err) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| var errInvalidModule = fmt.Errorf("invalid module") | ||||
| 
 | ||||
| // Load loads the specified NVIDIA kernel module.
 | ||||
| // If the root is specified we first chroot into this root.
 | ||||
| func (m *Interface) Load(module string) error { | ||||
| 	if !strings.HasPrefix(module, "nvidia") { | ||||
| 		return errInvalidModule | ||||
| 	} | ||||
| 
 | ||||
| 	var args []string | ||||
| 	if m.root != "/" { | ||||
| 		args = append(args, "chroot", m.root) | ||||
| 	} | ||||
| 	args = append(args, "/sbin/modprobe", module) | ||||
| 
 | ||||
| 	m.logger.Debugf("Loading kernel module %s: %v", module, args) | ||||
| 	err := m.Run(args[0], args[1:]...) | ||||
| 	if err != nil { | ||||
| 		m.logger.Debugf("Failed to load kernel module %s: %v", module, err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										178
									
								
								internal/system/nvmodules/modules_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								internal/system/nvmodules/modules_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,178 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package nvmodules | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	testlog "github.com/sirupsen/logrus/hooks/test" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| func TestLoadAll(t *testing.T) { | ||||
| 	logger, _ := testlog.NewNullLogger() | ||||
| 
 | ||||
| 	runError := errors.New("run error") | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		description   string | ||||
| 		root          string | ||||
| 		runError      error | ||||
| 		expectedError error | ||||
| 		expectedCalls []struct { | ||||
| 			S       string | ||||
| 			Strings []string | ||||
| 		} | ||||
| 	}{ | ||||
| 		{ | ||||
| 			description: "no root specified", | ||||
| 			root:        "", | ||||
| 			expectedCalls: []struct { | ||||
| 				S       string | ||||
| 				Strings []string | ||||
| 			}{ | ||||
| 				{"/sbin/modprobe", []string{"nvidia"}}, | ||||
| 				{"/sbin/modprobe", []string{"nvidia-uvm"}}, | ||||
| 				{"/sbin/modprobe", []string{"nvidia-modeset"}}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "root causes chroot", | ||||
| 			root:        "/some/root", | ||||
| 			expectedCalls: []struct { | ||||
| 				S       string | ||||
| 				Strings []string | ||||
| 			}{ | ||||
| 				{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia"}}, | ||||
| 				{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia-uvm"}}, | ||||
| 				{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia-modeset"}}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description:   "run failure is returned", | ||||
| 			root:          "", | ||||
| 			runError:      runError, | ||||
| 			expectedError: runError, | ||||
| 			expectedCalls: []struct { | ||||
| 				S       string | ||||
| 				Strings []string | ||||
| 			}{ | ||||
| 				{"/sbin/modprobe", []string{"nvidia"}}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.description, func(t *testing.T) { | ||||
| 			cmder := &cmderMock{ | ||||
| 				RunFunc: func(cmd string, args ...string) error { | ||||
| 					return tc.runError | ||||
| 				}, | ||||
| 			} | ||||
| 			m := New( | ||||
| 				WithLogger(logger), | ||||
| 				WithRoot(tc.root), | ||||
| 			) | ||||
| 			m.cmder = cmder | ||||
| 
 | ||||
| 			err := m.LoadAll() | ||||
| 			require.ErrorIs(t, err, tc.expectedError) | ||||
| 
 | ||||
| 			require.EqualValues(t, tc.expectedCalls, cmder.RunCalls()) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestLoad(t *testing.T) { | ||||
| 	logger, _ := testlog.NewNullLogger() | ||||
| 
 | ||||
| 	runError := errors.New("run error") | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		description   string | ||||
| 		root          string | ||||
| 		module        string | ||||
| 		runError      error | ||||
| 		expectedError error | ||||
| 		expectedCalls []struct { | ||||
| 			S       string | ||||
| 			Strings []string | ||||
| 		} | ||||
| 	}{ | ||||
| 		{ | ||||
| 			description: "no root specified", | ||||
| 			root:        "", | ||||
| 			module:      "nvidia", | ||||
| 			expectedCalls: []struct { | ||||
| 				S       string | ||||
| 				Strings []string | ||||
| 			}{ | ||||
| 				{"/sbin/modprobe", []string{"nvidia"}}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "root causes chroot", | ||||
| 			root:        "/some/root", | ||||
| 			module:      "nvidia", | ||||
| 			expectedCalls: []struct { | ||||
| 				S       string | ||||
| 				Strings []string | ||||
| 			}{ | ||||
| 				{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia"}}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description:   "run failure is returned", | ||||
| 			root:          "", | ||||
| 			module:        "nvidia", | ||||
| 			runError:      runError, | ||||
| 			expectedError: runError, | ||||
| 			expectedCalls: []struct { | ||||
| 				S       string | ||||
| 				Strings []string | ||||
| 			}{ | ||||
| 				{"/sbin/modprobe", []string{"nvidia"}}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			description:   "module prefis is checked", | ||||
| 			module:        "not-nvidia", | ||||
| 			expectedError: errInvalidModule, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.description, func(t *testing.T) { | ||||
| 			cmder := &cmderMock{ | ||||
| 				RunFunc: func(cmd string, args ...string) error { | ||||
| 					return tc.runError | ||||
| 				}, | ||||
| 			} | ||||
| 			m := New( | ||||
| 				WithLogger(logger), | ||||
| 				WithRoot(tc.root), | ||||
| 			) | ||||
| 			m.cmder = cmder | ||||
| 
 | ||||
| 			err := m.Load(tc.module) | ||||
| 			require.ErrorIs(t, err, tc.expectedError) | ||||
| 
 | ||||
| 			require.EqualValues(t, tc.expectedCalls, cmder.RunCalls()) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| @ -14,30 +14,30 @@ | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package system | ||||
| package nvmodules | ||||
| 
 | ||||
| import "github.com/sirupsen/logrus" | ||||
| 
 | ||||
| // Option is a functional option for the system command
 | ||||
| // Option is a function that sets an option on the Interface struct.
 | ||||
| type Option func(*Interface) | ||||
| 
 | ||||
| // WithLogger sets the logger for the system command
 | ||||
| func WithLogger(logger *logrus.Logger) Option { | ||||
| 	return func(i *Interface) { | ||||
| 		i.logger = logger | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WithDryRun sets the dry run flag
 | ||||
| // WithDryRun sets the dry run option for the Interface struct.
 | ||||
| func WithDryRun(dryRun bool) Option { | ||||
| 	return func(i *Interface) { | ||||
| 		i.dryRun = dryRun | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WithLoadKernelModules sets the load kernel modules flag
 | ||||
| func WithLoadKernelModules(loadKernelModules bool) Option { | ||||
| // WithLogger sets the logger for the Interface struct.
 | ||||
| func WithLogger(logger *logrus.Logger) Option { | ||||
| 	return func(i *Interface) { | ||||
| 		i.loadKernelModules = loadKernelModules | ||||
| 		i.logger = logger | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WithRoot sets the root directory for the NVIDIA device nodes.
 | ||||
| func WithRoot(root string) Option { | ||||
| 	return func(i *Interface) { | ||||
| 		i.root = root | ||||
| 	} | ||||
| } | ||||
| @ -1,176 +0,0 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package system | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"os/exec" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"golang.org/x/sys/unix" | ||||
| ) | ||||
| 
 | ||||
| // Interface is the interface for the system command
 | ||||
| type Interface struct { | ||||
| 	logger            *logrus.Logger | ||||
| 	dryRun            bool | ||||
| 	loadKernelModules bool | ||||
| 	nvidiaDevices     nvidiaDevices | ||||
| } | ||||
| 
 | ||||
| // New constructs a system command with the specified options
 | ||||
| func New(opts ...Option) (*Interface, error) { | ||||
| 	i := &Interface{ | ||||
| 		logger: logrus.StandardLogger(), | ||||
| 	} | ||||
| 	for _, opt := range opts { | ||||
| 		opt(i) | ||||
| 	} | ||||
| 
 | ||||
| 	if i.loadKernelModules { | ||||
| 		if err := i.LoadNVIDIAKernelModules(); err != nil { | ||||
| 			return nil, fmt.Errorf("failed to load kernel modules: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	devices, err := devices.GetNVIDIADevices() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create devices info: %v", err) | ||||
| 	} | ||||
| 	i.nvidiaDevices = nvidiaDevices{devices} | ||||
| 
 | ||||
| 	return i, nil | ||||
| } | ||||
| 
 | ||||
| // CreateNVIDIAControlDeviceNodesAt creates the NVIDIA control device nodes associated with the NVIDIA driver at the specified root.
 | ||||
| func (m *Interface) CreateNVIDIAControlDeviceNodesAt(root string) error { | ||||
| 	controlNodes := []string{"/dev/nvidiactl", "/dev/nvidia-modeset", "/dev/nvidia-uvm", "/dev/nvidia-uvm-tools"} | ||||
| 
 | ||||
| 	for _, node := range controlNodes { | ||||
| 		path := filepath.Join(root, node) | ||||
| 		err := m.CreateNVIDIADeviceNode(path) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("failed to create device node %s: %v", path, err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // CreateNVIDIADeviceNode creates a specified device node associated with the NVIDIA driver.
 | ||||
| func (m *Interface) CreateNVIDIADeviceNode(path string) error { | ||||
| 	node := filepath.Base(path) | ||||
| 	if !strings.HasPrefix(node, "nvidia") { | ||||
| 		return fmt.Errorf("invalid device node %q", node) | ||||
| 	} | ||||
| 
 | ||||
| 	major, err := m.nvidiaDevices.Major(node) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to determine major: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	minor, err := m.nvidiaDevices.Minor(node) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to determine minor: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return m.createDeviceNode(path, int(major), int(minor)) | ||||
| } | ||||
| 
 | ||||
| func (m *Interface) createDeviceNode(path string, major int, minor int) error { | ||||
| 	if m.dryRun { | ||||
| 		m.logger.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor) | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	if _, err := os.Stat(path); err == nil { | ||||
| 		m.logger.Infof("Skipping: %s already exists", path) | ||||
| 		return nil | ||||
| 	} else if !os.IsNotExist(err) { | ||||
| 		return fmt.Errorf("failed to stat %s: %v", path, err) | ||||
| 	} | ||||
| 
 | ||||
| 	err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(uint32(major), uint32(minor)))) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return unix.Chmod(path, 0666) | ||||
| } | ||||
| 
 | ||||
| // LoadNVIDIAKernelModules loads the NVIDIA kernel modules.
 | ||||
| func (m *Interface) LoadNVIDIAKernelModules() error { | ||||
| 	modules := []string{"nvidia", "nvidia-uvm", "nvidia-modeset"} | ||||
| 
 | ||||
| 	for _, module := range modules { | ||||
| 		if m.dryRun { | ||||
| 			m.logger.Infof("Running: /sbin/modprobe %s", module) | ||||
| 			continue | ||||
| 		} | ||||
| 		cmd := exec.Command("/sbin/modprobe", module) | ||||
| 
 | ||||
| 		if output, err := cmd.CombinedOutput(); err != nil { | ||||
| 			m.logger.Debugf("Failed to load kernel module %s: %v", module, string(output)) | ||||
| 			return fmt.Errorf("failed to load kernel module %s: %v", module, err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type nvidiaDevices struct { | ||||
| 	devices.Devices | ||||
| } | ||||
| 
 | ||||
| // Major returns the major number for the specified NVIDIA device node.
 | ||||
| // If the device node is not supported, an error is returned.
 | ||||
| func (n *nvidiaDevices) Major(node string) (int64, error) { | ||||
| 	var valid bool | ||||
| 	var major devices.Major | ||||
| 	switch node { | ||||
| 	case "nvidia-uvm", "nvidia-uvm-tools": | ||||
| 		major, valid = n.Get(devices.NVIDIAUVM) | ||||
| 	case "nvidia-modeset", "nvidiactl": | ||||
| 		major, valid = n.Get(devices.NVIDIAGPU) | ||||
| 	} | ||||
| 
 | ||||
| 	if !valid { | ||||
| 		return 0, fmt.Errorf("invalid device node %q", node) | ||||
| 	} | ||||
| 
 | ||||
| 	return int64(major), nil | ||||
| } | ||||
| 
 | ||||
| // Minor returns the minor number for the specified NVIDIA device node.
 | ||||
| // If the device node is not supported, an error is returned.
 | ||||
| func (n *nvidiaDevices) Minor(node string) (int64, error) { | ||||
| 	switch node { | ||||
| 	case "nvidia-modeset": | ||||
| 		return devices.NVIDIAModesetMinor, nil | ||||
| 	case "nvidia-uvm-tools": | ||||
| 		return devices.NVIDIAUVMToolsMinor, nil | ||||
| 	case "nvidia-uvm": | ||||
| 		return devices.NVIDIAUVMMinor, nil | ||||
| 	case "nvidiactl": | ||||
| 		return devices.NVIDIACTLMinor, nil | ||||
| 	} | ||||
| 
 | ||||
| 	return 0, fmt.Errorf("invalid device node %q", node) | ||||
| } | ||||
| @ -23,7 +23,7 @@ import ( | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/system" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" | ||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" | ||||
| 	"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" | ||||
| @ -683,11 +683,13 @@ func generateCDISpec(opts *options, nvidiaCTKPath string) error { | ||||
| 	} | ||||
| 
 | ||||
| 	log.Infof("Creating control device nodes at %v", opts.DriverRootCtrPath) | ||||
| 	s, err := system.New() | ||||
| 	devices, err := nvdevices.New( | ||||
| 		nvdevices.WithDevRoot(opts.DriverRootCtrPath), | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to create library: %v", err) | ||||
| 	} | ||||
| 	if err := s.CreateNVIDIAControlDeviceNodesAt(opts.DriverRootCtrPath); err != nil { | ||||
| 	if err := devices.CreateNVIDIAControlDevices(); err != nil { | ||||
| 		return fmt.Errorf("failed to create control device nodes: %v", err) | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user