diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go index 92ddc5f9..76acee41 100644 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go +++ b/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go @@ -28,14 +28,14 @@ import ( type allPossible struct { logger logger.Interface - driverRoot string + devRoot string deviceMajors devices.Devices migCaps nvcaps.MigCaps } // newAllPossible returns a new allPossible device node lister. // This lister lists all possible device nodes for NVIDIA GPUs, control devices, and capability devices. -func newAllPossible(logger logger.Interface, driverRoot string) (nodeLister, error) { +func newAllPossible(logger logger.Interface, devRoot string) (nodeLister, error) { deviceMajors, err := devices.GetNVIDIADevices() if err != nil { return nil, fmt.Errorf("failed reading device majors: %v", err) @@ -61,7 +61,7 @@ func newAllPossible(logger logger.Interface, driverRoot string) (nodeLister, err l := allPossible{ logger: logger, - driverRoot: driverRoot, + devRoot: devRoot, deviceMajors: deviceMajors, migCaps: migCaps, } @@ -72,7 +72,7 @@ func newAllPossible(logger logger.Interface, driverRoot string) (nodeLister, err // DeviceNodes returns a list of all possible device nodes for NVIDIA GPUs, control devices, and capability devices. func (m allPossible) DeviceNodes() ([]deviceNode, error) { gpus, err := nvpci.New( - nvpci.WithPCIDevicesRoot(filepath.Join(m.driverRoot, nvpci.PCIDevicesRoot)), + nvpci.WithPCIDevicesRoot(filepath.Join(m.devRoot, nvpci.PCIDevicesRoot)), ).GetGPUs() if err != nil { return nil, fmt.Errorf("failed to get GPU information: %v", err) @@ -80,7 +80,7 @@ func (m allPossible) DeviceNodes() ([]deviceNode, error) { count := len(gpus) if count == 0 { - m.logger.Infof("No NVIDIA devices found in %s", m.driverRoot) + m.logger.Infof("No NVIDIA devices found in %s", m.devRoot) return nil, nil } @@ -179,7 +179,7 @@ func (m allPossible) newDeviceNode(deviceName devices.Name, path string, minor i major, _ := m.deviceMajors.Get(deviceName) return deviceNode{ - path: filepath.Join(m.driverRoot, path), + path: filepath.Join(m.devRoot, path), major: uint32(major), minor: uint32(minor), } diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go index 9d0c6a8d..ed6455bd 100644 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go +++ b/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go @@ -25,7 +25,8 @@ import ( "syscall" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/system" + "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" + "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules" "github.com/fsnotify/fsnotify" "github.com/urfave/cli/v2" ) @@ -216,6 +217,7 @@ type linkCreator struct { logger logger.Interface lister nodeLister driverRoot string + devRoot string devCharPath string dryRun bool createAll bool @@ -243,6 +245,9 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) { if c.driverRoot == "" { c.driverRoot = "/" } + if c.devRoot == "" { + c.devRoot = "/" + } if c.devCharPath == "" { c.devCharPath = defaultDevCharPath } @@ -252,13 +257,13 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) { } if c.createAll { - lister, err := newAllPossible(c.logger, c.driverRoot) + lister, err := newAllPossible(c.logger, c.devRoot) if err != nil { return nil, fmt.Errorf("failed to create all possible device lister: %v", err) } c.lister = lister } else { - c.lister = existing{c.logger, c.driverRoot} + c.lister = existing{c.logger, c.devRoot} } return c, nil } @@ -268,36 +273,48 @@ func (m linkCreator) setup() error { return nil } - s, err := system.New( - system.WithLogger(m.logger), - system.WithDryRun(m.dryRun), - ) - if err != nil { - return err - } - if m.loadKernelModules { - if err := s.LoadNVIDIAKernelModules(); err != nil { + modules := nvmodules.New( + nvmodules.WithLogger(m.logger), + nvmodules.WithDryRun(m.dryRun), + nvmodules.WithRoot(m.driverRoot), + ) + if err := modules.LoadAll(); err != nil { return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) } } if m.createDeviceNodes { - if err := s.CreateNVIDIAControlDeviceNodesAt(m.driverRoot); err != nil { + devices, err := nvdevices.New( + nvdevices.WithLogger(m.logger), + nvdevices.WithDryRun(m.dryRun), + nvdevices.WithDevRoot(m.devRoot), + ) + if err != nil { + return err + } + if err := devices.CreateNVIDIAControlDevices(); err != nil { return fmt.Errorf("failed to create NVIDIA device nodes: %v", err) } } - return nil } // WithDriverRoot sets the driver root path. +// This is the path in which kernel modules must be loaded. func WithDriverRoot(root string) Option { return func(c *linkCreator) { c.driverRoot = root } } +// WithDevRoot sets the root path for the /dev directory. +func WithDevRoot(root string) Option { + return func(c *linkCreator) { + c.devRoot = root + } +} + // WithDevCharPath sets the path at which the symlinks will be created. func WithDevCharPath(path string) Option { return func(c *linkCreator) { diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing.go index a88da8f5..a1af8b20 100644 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing.go +++ b/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing.go @@ -30,8 +30,8 @@ type nodeLister interface { } type existing struct { - logger logger.Interface - driverRoot string + logger logger.Interface + devRoot string } // DeviceNodes returns a list of NVIDIA device nodes in the specified root. @@ -39,7 +39,7 @@ type existing struct { func (m existing) DeviceNodes() ([]deviceNode, error) { locator := lookup.NewCharDeviceLocator( lookup.WithLogger(m.logger), - lookup.WithRoot(m.driverRoot), + lookup.WithRoot(m.devRoot), lookup.WithOptional(true), ) @@ -54,7 +54,7 @@ func (m existing) DeviceNodes() ([]deviceNode, error) { } if len(devices) == 0 && len(capDevices) == 0 { - m.logger.Infof("No NVIDIA devices found in %s", m.driverRoot) + m.logger.Infof("No NVIDIA devices found in %s", m.devRoot) return nil, nil } diff --git a/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go b/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go index 94605b6c..4ca31131 100644 --- a/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go +++ b/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go @@ -20,7 +20,8 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/system" + "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" + "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules" "github.com/urfave/cli/v2" ) @@ -96,19 +97,29 @@ func (m command) validateFlags(r *cli.Context, opts *options) error { } func (m command) run(c *cli.Context, opts *options) error { - s, err := system.New( - system.WithLogger(m.logger), - system.WithDryRun(opts.dryRun), - system.WithLoadKernelModules(opts.loadKernelModules), - ) - if err != nil { - return fmt.Errorf("failed to create library: %v", err) + if opts.loadKernelModules { + modules := nvmodules.New( + nvmodules.WithLogger(m.logger), + nvmodules.WithDryRun(opts.dryRun), + nvmodules.WithRoot(opts.driverRoot), + ) + if err := modules.LoadAll(); err != nil { + return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) + } } if opts.control { + devices, err := nvdevices.New( + nvdevices.WithLogger(m.logger), + nvdevices.WithDryRun(opts.dryRun), + nvdevices.WithDevRoot(opts.driverRoot), + ) + if err != nil { + return err + } m.logger.Infof("Creating control device nodes at %s", opts.driverRoot) - if err := s.CreateNVIDIAControlDeviceNodesAt(opts.driverRoot); err != nil { - return fmt.Errorf("failed to create control device nodes: %v", err) + if err := devices.CreateNVIDIAControlDevices(); err != nil { + return fmt.Errorf("failed to create NVIDIA control device nodes: %v", err) } } return nil diff --git a/internal/system/nvdevices/devices.go b/internal/system/nvdevices/devices.go new file mode 100644 index 00000000..1da030dc --- /dev/null +++ b/internal/system/nvdevices/devices.go @@ -0,0 +1,154 @@ +/** +# Copyright (c) NVIDIA CORPORATIOm. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvdevices + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" +) + +var errInvalidDeviceNode = errors.New("invalid device node") + +// Interface provides a set of utilities for interacting with NVIDIA devices on the system. +type Interface struct { + devices.Devices + + logger logger.Interface + + dryRun bool + // devRoot is the root directory where device nodes are expected to exist. + devRoot string + + mknoder +} + +// New constructs a new Interface struct with the specified options. +func New(opts ...Option) (*Interface, error) { + i := &Interface{} + for _, opt := range opts { + opt(i) + } + + if i.logger == nil { + i.logger = logger.New() + } + if i.devRoot == "" { + i.devRoot = "/" + } + if i.Devices == nil { + devices, err := devices.GetNVIDIADevices() + if err != nil { + return nil, fmt.Errorf("failed to create devices info: %v", err) + } + i.Devices = devices + } + + if i.dryRun { + i.mknoder = &mknodLogger{i.logger} + } else { + i.mknoder = &mknodUnix{} + } + return i, nil +} + +// CreateNVIDIAControlDevices creates the NVIDIA control device nodes at the configured devRoot. +func (m *Interface) CreateNVIDIAControlDevices() error { + controlNodes := []string{"nvidiactl", "nvidia-modeset", "nvidia-uvm", "nvidia-uvm-tools"} + for _, node := range controlNodes { + err := m.CreateNVIDIADevice(node) + if err != nil { + return fmt.Errorf("failed to create device node %s: %w", node, err) + } + } + return nil +} + +// CreateNVIDIADevice creates the specified NVIDIA device node at the configured devRoot. +func (m *Interface) CreateNVIDIADevice(node string) error { + node = filepath.Base(node) + if !strings.HasPrefix(node, "nvidia") { + return fmt.Errorf("invalid device node %q: %w", node, errInvalidDeviceNode) + } + + major, err := m.Major(node) + if err != nil { + return fmt.Errorf("failed to determine major: %w", err) + } + + minor, err := m.Minor(node) + if err != nil { + return fmt.Errorf("failed to determine minor: %w", err) + } + + return m.createDeviceNode(filepath.Join("dev", node), int(major), int(minor)) +} + +// createDeviceNode creates the specified device node with the require major and minor numbers. +// If a devRoot is configured, this is prepended to the path. +func (m *Interface) createDeviceNode(path string, major int, minor int) error { + path = filepath.Join(m.devRoot, path) + if _, err := os.Stat(path); err == nil { + m.logger.Infof("Skipping: %s already exists", path) + return nil + } else if !os.IsNotExist(err) { + return fmt.Errorf("failed to stat %s: %v", path, err) + } + + return m.Mknode(path, major, minor) +} + +// Major returns the major number for the specified NVIDIA device node. +// If the device node is not supported, an error is returned. +func (m *Interface) Major(node string) (int64, error) { + var valid bool + var major devices.Major + switch node { + case "nvidia-uvm", "nvidia-uvm-tools": + major, valid = m.Get(devices.NVIDIAUVM) + case "nvidia-modeset", "nvidiactl": + major, valid = m.Get(devices.NVIDIAGPU) + } + + if valid { + return int64(major), nil + } + + return 0, errInvalidDeviceNode +} + +// Minor returns the minor number for the specified NVIDIA device node. +// If the device node is not supported, an error is returned. +func (m *Interface) Minor(node string) (int64, error) { + switch node { + case "nvidia-modeset": + return devices.NVIDIAModesetMinor, nil + case "nvidia-uvm-tools": + return devices.NVIDIAUVMToolsMinor, nil + case "nvidia-uvm": + return devices.NVIDIAUVMMinor, nil + case "nvidiactl": + return devices.NVIDIACTLMinor, nil + } + + return 0, errInvalidDeviceNode +} diff --git a/internal/system/nvdevices/devices_test.go b/internal/system/nvdevices/devices_test.go new file mode 100644 index 00000000..5d94bc75 --- /dev/null +++ b/internal/system/nvdevices/devices_test.go @@ -0,0 +1,133 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvdevices + +import ( + "errors" + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestCreateControlDevices(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + nvidiaDevices := &devices.DevicesMock{ + GetFunc: func(name devices.Name) (devices.Major, bool) { + devices := map[devices.Name]devices.Major{ + "nvidia-frontend": 195, + "nvidia-uvm": 243, + } + return devices[name], true + }, + } + + mknodeError := errors.New("mknode error") + + testCases := []struct { + description string + root string + devices devices.Devices + mknodeError error + expectedError error + expectedCalls []struct { + S string + N1 int + N2 int + } + }{ + { + description: "no root specified", + root: "", + devices: nvidiaDevices, + mknodeError: nil, + expectedCalls: []struct { + S string + N1 int + N2 int + }{ + {"/dev/nvidiactl", 195, 255}, + {"/dev/nvidia-modeset", 195, 254}, + {"/dev/nvidia-uvm", 243, 0}, + {"/dev/nvidia-uvm-tools", 243, 1}, + }, + }, + { + description: "some root specified", + root: "/some/root", + devices: nvidiaDevices, + mknodeError: nil, + expectedCalls: []struct { + S string + N1 int + N2 int + }{ + {"/some/root/dev/nvidiactl", 195, 255}, + {"/some/root/dev/nvidia-modeset", 195, 254}, + {"/some/root/dev/nvidia-uvm", 243, 0}, + {"/some/root/dev/nvidia-uvm-tools", 243, 1}, + }, + }, + { + description: "mknod error returns error", + devices: nvidiaDevices, + mknodeError: mknodeError, + expectedError: mknodeError, + // We expect the first call to this to fail, and the rest to be skipped + expectedCalls: []struct { + S string + N1 int + N2 int + }{ + {"/dev/nvidiactl", 195, 255}, + }, + }, + { + description: "missing major returns error", + devices: &devices.DevicesMock{ + GetFunc: func(name devices.Name) (devices.Major, bool) { + return 0, false + }, + }, + expectedError: errInvalidDeviceNode, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + mknode := &mknoderMock{ + MknodeFunc: func(string, int, int) error { + return tc.mknodeError + }, + } + + d, _ := New( + WithLogger(logger), + WithDevRoot(tc.root), + WithDevices(tc.devices), + ) + d.mknoder = mknode + + err := d.CreateNVIDIAControlDevices() + require.ErrorIs(t, err, tc.expectedError) + require.EqualValues(t, tc.expectedCalls, mknode.MknodeCalls()) + }) + } + +} diff --git a/internal/system/nvdevices/mknod.go b/internal/system/nvdevices/mknod.go new file mode 100644 index 00000000..e5990ea0 --- /dev/null +++ b/internal/system/nvdevices/mknod.go @@ -0,0 +1,46 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvdevices + +import ( + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "golang.org/x/sys/unix" +) + +//go:generate moq -stub -out mknod_mock.go . mknoder +type mknoder interface { + Mknode(string, int, int) error +} + +type mknodLogger struct { + logger.Interface +} + +func (m *mknodLogger) Mknode(path string, major, minor int) error { + m.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor) + return nil +} + +type mknodUnix struct{} + +func (m *mknodUnix) Mknode(path string, major, minor int) error { + err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(uint32(major), uint32(minor)))) + if err != nil { + return err + } + return unix.Chmod(path, 0666) +} diff --git a/internal/system/nvdevices/mknod_mock.go b/internal/system/nvdevices/mknod_mock.go new file mode 100644 index 00000000..4bb384fa --- /dev/null +++ b/internal/system/nvdevices/mknod_mock.go @@ -0,0 +1,89 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package nvdevices + +import ( + "sync" +) + +// Ensure, that mknoderMock does implement mknoder. +// If this is not the case, regenerate this file with moq. +var _ mknoder = &mknoderMock{} + +// mknoderMock is a mock implementation of mknoder. +// +// func TestSomethingThatUsesmknoder(t *testing.T) { +// +// // make and configure a mocked mknoder +// mockedmknoder := &mknoderMock{ +// MknodeFunc: func(s string, n1 int, n2 int) error { +// panic("mock out the Mknode method") +// }, +// } +// +// // use mockedmknoder in code that requires mknoder +// // and then make assertions. +// +// } +type mknoderMock struct { + // MknodeFunc mocks the Mknode method. + MknodeFunc func(s string, n1 int, n2 int) error + + // calls tracks calls to the methods. + calls struct { + // Mknode holds details about calls to the Mknode method. + Mknode []struct { + // S is the s argument value. + S string + // N1 is the n1 argument value. + N1 int + // N2 is the n2 argument value. + N2 int + } + } + lockMknode sync.RWMutex +} + +// Mknode calls MknodeFunc. +func (mock *mknoderMock) Mknode(s string, n1 int, n2 int) error { + callInfo := struct { + S string + N1 int + N2 int + }{ + S: s, + N1: n1, + N2: n2, + } + mock.lockMknode.Lock() + mock.calls.Mknode = append(mock.calls.Mknode, callInfo) + mock.lockMknode.Unlock() + if mock.MknodeFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.MknodeFunc(s, n1, n2) +} + +// MknodeCalls gets all the calls that were made to Mknode. +// Check the length with: +// +// len(mockedmknoder.MknodeCalls()) +func (mock *mknoderMock) MknodeCalls() []struct { + S string + N1 int + N2 int +} { + var calls []struct { + S string + N1 int + N2 int + } + mock.lockMknode.RLock() + calls = mock.calls.Mknode + mock.lockMknode.RUnlock() + return calls +} diff --git a/internal/system/nvdevices/options.go b/internal/system/nvdevices/options.go new file mode 100644 index 00000000..0bcf319d --- /dev/null +++ b/internal/system/nvdevices/options.go @@ -0,0 +1,53 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvdevices + +import ( + "github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" +) + +// Option is a function that sets an option on the Interface struct. +type Option func(*Interface) + +// WithDryRun sets the dry run option for the Interface struct. +func WithDryRun(dryRun bool) Option { + return func(i *Interface) { + i.dryRun = dryRun + } +} + +// WithLogger sets the logger for the Interface struct. +func WithLogger(logger logger.Interface) Option { + return func(i *Interface) { + i.logger = logger + } +} + +// WithDevRoot sets the root directory for the NVIDIA device nodes. +func WithDevRoot(devRoot string) Option { + return func(i *Interface) { + i.devRoot = devRoot + } +} + +// WithDevices sets the devices for the Interface struct. +func WithDevices(devices devices.Devices) Option { + return func(i *Interface) { + i.Devices = devices + } +} diff --git a/internal/system/nvmodules/cmd.go b/internal/system/nvmodules/cmd.go new file mode 100644 index 00000000..23df4f25 --- /dev/null +++ b/internal/system/nvmodules/cmd.go @@ -0,0 +1,49 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvmodules + +import ( + "fmt" + "os/exec" + "strings" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" +) + +//go:generate moq -stub -out cmd_mock.go . cmder +type cmder interface { + // Run executes the command and returns the stdout, stderr, and an error if any + Run(string, ...string) error +} + +type cmderLogger struct { + logger.Interface +} + +func (c *cmderLogger) Run(cmd string, args ...string) error { + c.Infof("Running: %v %v", cmd, strings.Join(args, " ")) + return nil +} + +type cmderExec struct{} + +func (c *cmderExec) Run(cmd string, args ...string) error { + if output, err := exec.Command(cmd, args...).CombinedOutput(); err != nil { + return fmt.Errorf("%w; output=%v", err, string(output)) + } + return nil +} diff --git a/internal/system/nvmodules/cmd_mock.go b/internal/system/nvmodules/cmd_mock.go new file mode 100644 index 00000000..077e58a7 --- /dev/null +++ b/internal/system/nvmodules/cmd_mock.go @@ -0,0 +1,83 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package nvmodules + +import ( + "sync" +) + +// Ensure, that cmderMock does implement cmder. +// If this is not the case, regenerate this file with moq. +var _ cmder = &cmderMock{} + +// cmderMock is a mock implementation of cmder. +// +// func TestSomethingThatUsescmder(t *testing.T) { +// +// // make and configure a mocked cmder +// mockedcmder := &cmderMock{ +// RunFunc: func(s string, strings ...string) error { +// panic("mock out the Run method") +// }, +// } +// +// // use mockedcmder in code that requires cmder +// // and then make assertions. +// +// } +type cmderMock struct { + // RunFunc mocks the Run method. + RunFunc func(s string, strings ...string) error + + // calls tracks calls to the methods. + calls struct { + // Run holds details about calls to the Run method. + Run []struct { + // S is the s argument value. + S string + // Strings is the strings argument value. + Strings []string + } + } + lockRun sync.RWMutex +} + +// Run calls RunFunc. +func (mock *cmderMock) Run(s string, strings ...string) error { + callInfo := struct { + S string + Strings []string + }{ + S: s, + Strings: strings, + } + mock.lockRun.Lock() + mock.calls.Run = append(mock.calls.Run, callInfo) + mock.lockRun.Unlock() + if mock.RunFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.RunFunc(s, strings...) +} + +// RunCalls gets all the calls that were made to Run. +// Check the length with: +// +// len(mockedcmder.RunCalls()) +func (mock *cmderMock) RunCalls() []struct { + S string + Strings []string +} { + var calls []struct { + S string + Strings []string + } + mock.lockRun.RLock() + calls = mock.calls.Run + mock.lockRun.RUnlock() + return calls +} diff --git a/internal/system/nvmodules/modules.go b/internal/system/nvmodules/modules.go new file mode 100644 index 00000000..9b81d945 --- /dev/null +++ b/internal/system/nvmodules/modules.go @@ -0,0 +1,93 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvmodules + +import ( + "fmt" + "strings" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" +) + +// Interface provides a set of utilities for interacting with NVIDIA modules on the system. +type Interface struct { + logger logger.Interface + + dryRun bool + root string + + cmder +} + +// New constructs a new Interface struct with the specified options. +func New(opts ...Option) *Interface { + m := &Interface{} + for _, opt := range opts { + opt(m) + } + + if m.logger == nil { + m.logger = logger.New() + } + if m.root == "" { + m.root = "/" + } + + if m.dryRun { + m.cmder = &cmderLogger{m.logger} + } else { + m.cmder = &cmderExec{} + } + return m +} + +// LoadAll loads all the NVIDIA kernel modules. +func (m *Interface) LoadAll() error { + modules := []string{"nvidia", "nvidia-uvm", "nvidia-modeset"} + + for _, module := range modules { + if err := m.Load(module); err != nil { + return fmt.Errorf("failed to load module %s: %w", module, err) + } + } + return nil +} + +var errInvalidModule = fmt.Errorf("invalid module") + +// Load loads the specified NVIDIA kernel module. +// If the root is specified we first chroot into this root. +func (m *Interface) Load(module string) error { + if !strings.HasPrefix(module, "nvidia") { + return errInvalidModule + } + + var args []string + if m.root != "/" { + args = append(args, "chroot", m.root) + } + args = append(args, "/sbin/modprobe", module) + + m.logger.Debugf("Loading kernel module %s: %v", module, args) + err := m.Run(args[0], args[1:]...) + if err != nil { + m.logger.Debugf("Failed to load kernel module %s: %v", module, err) + return err + } + + return nil +} diff --git a/internal/system/nvmodules/modules_test.go b/internal/system/nvmodules/modules_test.go new file mode 100644 index 00000000..6cf67059 --- /dev/null +++ b/internal/system/nvmodules/modules_test.go @@ -0,0 +1,178 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvmodules + +import ( + "errors" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestLoadAll(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + runError := errors.New("run error") + + testCases := []struct { + description string + root string + runError error + expectedError error + expectedCalls []struct { + S string + Strings []string + } + }{ + { + description: "no root specified", + root: "", + expectedCalls: []struct { + S string + Strings []string + }{ + {"/sbin/modprobe", []string{"nvidia"}}, + {"/sbin/modprobe", []string{"nvidia-uvm"}}, + {"/sbin/modprobe", []string{"nvidia-modeset"}}, + }, + }, + { + description: "root causes chroot", + root: "/some/root", + expectedCalls: []struct { + S string + Strings []string + }{ + {"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia"}}, + {"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia-uvm"}}, + {"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia-modeset"}}, + }, + }, + { + description: "run failure is returned", + root: "", + runError: runError, + expectedError: runError, + expectedCalls: []struct { + S string + Strings []string + }{ + {"/sbin/modprobe", []string{"nvidia"}}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + cmder := &cmderMock{ + RunFunc: func(cmd string, args ...string) error { + return tc.runError + }, + } + m := New( + WithLogger(logger), + WithRoot(tc.root), + ) + m.cmder = cmder + + err := m.LoadAll() + require.ErrorIs(t, err, tc.expectedError) + + require.EqualValues(t, tc.expectedCalls, cmder.RunCalls()) + }) + } +} + +func TestLoad(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + runError := errors.New("run error") + + testCases := []struct { + description string + root string + module string + runError error + expectedError error + expectedCalls []struct { + S string + Strings []string + } + }{ + { + description: "no root specified", + root: "", + module: "nvidia", + expectedCalls: []struct { + S string + Strings []string + }{ + {"/sbin/modprobe", []string{"nvidia"}}, + }, + }, + { + description: "root causes chroot", + root: "/some/root", + module: "nvidia", + expectedCalls: []struct { + S string + Strings []string + }{ + {"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia"}}, + }, + }, + { + description: "run failure is returned", + root: "", + module: "nvidia", + runError: runError, + expectedError: runError, + expectedCalls: []struct { + S string + Strings []string + }{ + {"/sbin/modprobe", []string{"nvidia"}}, + }, + }, + { + description: "module prefis is checked", + module: "not-nvidia", + expectedError: errInvalidModule, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + cmder := &cmderMock{ + RunFunc: func(cmd string, args ...string) error { + return tc.runError + }, + } + m := New( + WithLogger(logger), + WithRoot(tc.root), + ) + m.cmder = cmder + + err := m.Load(tc.module) + require.ErrorIs(t, err, tc.expectedError) + + require.EqualValues(t, tc.expectedCalls, cmder.RunCalls()) + }) + } +} diff --git a/internal/system/options.go b/internal/system/nvmodules/options.go similarity index 74% rename from internal/system/options.go rename to internal/system/nvmodules/options.go index 5d261f0d..4633f023 100644 --- a/internal/system/options.go +++ b/internal/system/nvmodules/options.go @@ -14,30 +14,30 @@ # limitations under the License. **/ -package system +package nvmodules import "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" -// Option is a functional option for the system command +// Option is a function that sets an option on the Interface struct. type Option func(*Interface) -// WithLogger sets the logger for the system command -func WithLogger(logger logger.Interface) Option { - return func(i *Interface) { - i.logger = logger - } -} - -// WithDryRun sets the dry run flag +// WithDryRun sets the dry run option for the Interface struct. func WithDryRun(dryRun bool) Option { return func(i *Interface) { i.dryRun = dryRun } } -// WithLoadKernelModules sets the load kernel modules flag -func WithLoadKernelModules(loadKernelModules bool) Option { +// WithLogger sets the logger for the Interface struct. +func WithLogger(logger logger.Interface) Option { return func(i *Interface) { - i.loadKernelModules = loadKernelModules + i.logger = logger + } +} + +// WithRoot sets the root directory for the NVIDIA device nodes. +func WithRoot(root string) Option { + return func(i *Interface) { + i.root = root } } diff --git a/internal/system/system.go b/internal/system/system.go deleted file mode 100644 index df39de9d..00000000 --- a/internal/system/system.go +++ /dev/null @@ -1,176 +0,0 @@ -/** -# Copyright (c) NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package system - -import ( - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - - "github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "golang.org/x/sys/unix" -) - -// Interface is the interface for the system command -type Interface struct { - logger logger.Interface - dryRun bool - loadKernelModules bool - nvidiaDevices nvidiaDevices -} - -// New constructs a system command with the specified options -func New(opts ...Option) (*Interface, error) { - i := &Interface{ - logger: logger.New(), - } - for _, opt := range opts { - opt(i) - } - - if i.loadKernelModules { - if err := i.LoadNVIDIAKernelModules(); err != nil { - return nil, fmt.Errorf("failed to load kernel modules: %v", err) - } - } - - devices, err := devices.GetNVIDIADevices() - if err != nil { - return nil, fmt.Errorf("failed to create devices info: %v", err) - } - i.nvidiaDevices = nvidiaDevices{devices} - - return i, nil -} - -// CreateNVIDIAControlDeviceNodesAt creates the NVIDIA control device nodes associated with the NVIDIA driver at the specified root. -func (m *Interface) CreateNVIDIAControlDeviceNodesAt(root string) error { - controlNodes := []string{"/dev/nvidiactl", "/dev/nvidia-modeset", "/dev/nvidia-uvm", "/dev/nvidia-uvm-tools"} - - for _, node := range controlNodes { - path := filepath.Join(root, node) - err := m.CreateNVIDIADeviceNode(path) - if err != nil { - return fmt.Errorf("failed to create device node %s: %v", path, err) - } - } - - return nil -} - -// CreateNVIDIADeviceNode creates a specified device node associated with the NVIDIA driver. -func (m *Interface) CreateNVIDIADeviceNode(path string) error { - node := filepath.Base(path) - if !strings.HasPrefix(node, "nvidia") { - return fmt.Errorf("invalid device node %q", node) - } - - major, err := m.nvidiaDevices.Major(node) - if err != nil { - return fmt.Errorf("failed to determine major: %v", err) - } - - minor, err := m.nvidiaDevices.Minor(node) - if err != nil { - return fmt.Errorf("failed to determine minor: %v", err) - } - - return m.createDeviceNode(path, int(major), int(minor)) -} - -func (m *Interface) createDeviceNode(path string, major int, minor int) error { - if m.dryRun { - m.logger.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor) - return nil - } - - if _, err := os.Stat(path); err == nil { - m.logger.Infof("Skipping: %s already exists", path) - return nil - } else if !os.IsNotExist(err) { - return fmt.Errorf("failed to stat %s: %v", path, err) - } - - err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(uint32(major), uint32(minor)))) - if err != nil { - return err - } - return unix.Chmod(path, 0666) -} - -// LoadNVIDIAKernelModules loads the NVIDIA kernel modules. -func (m *Interface) LoadNVIDIAKernelModules() error { - modules := []string{"nvidia", "nvidia-uvm", "nvidia-modeset"} - - for _, module := range modules { - if m.dryRun { - m.logger.Infof("Running: /sbin/modprobe %s", module) - continue - } - cmd := exec.Command("/sbin/modprobe", module) - - if output, err := cmd.CombinedOutput(); err != nil { - m.logger.Debugf("Failed to load kernel module %s: %v", module, string(output)) - return fmt.Errorf("failed to load kernel module %s: %v", module, err) - } - } - - return nil -} - -type nvidiaDevices struct { - devices.Devices -} - -// Major returns the major number for the specified NVIDIA device node. -// If the device node is not supported, an error is returned. -func (n *nvidiaDevices) Major(node string) (int64, error) { - var valid bool - var major devices.Major - switch node { - case "nvidia-uvm", "nvidia-uvm-tools": - major, valid = n.Get(devices.NVIDIAUVM) - case "nvidia-modeset", "nvidiactl": - major, valid = n.Get(devices.NVIDIAGPU) - } - - if !valid { - return 0, fmt.Errorf("invalid device node %q", node) - } - - return int64(major), nil -} - -// Minor returns the minor number for the specified NVIDIA device node. -// If the device node is not supported, an error is returned. -func (n *nvidiaDevices) Minor(node string) (int64, error) { - switch node { - case "nvidia-modeset": - return devices.NVIDIAModesetMinor, nil - case "nvidia-uvm-tools": - return devices.NVIDIAUVMToolsMinor, nil - case "nvidia-uvm": - return devices.NVIDIAUVMMinor, nil - case "nvidiactl": - return devices.NVIDIACTLMinor, nil - } - - return 0, fmt.Errorf("invalid device node %q", node) -} diff --git a/tools/container/toolkit/toolkit.go b/tools/container/toolkit/toolkit.go index 5f361f8f..b05062f3 100644 --- a/tools/container/toolkit/toolkit.go +++ b/tools/container/toolkit/toolkit.go @@ -23,7 +23,7 @@ import ( "path/filepath" "strings" - "github.com/NVIDIA/nvidia-container-toolkit/internal/system" + "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" @@ -683,11 +683,13 @@ func generateCDISpec(opts *options, nvidiaCTKPath string) error { } log.Infof("Creating control device nodes at %v", opts.DriverRootCtrPath) - s, err := system.New() + devices, err := nvdevices.New( + nvdevices.WithDevRoot(opts.DriverRootCtrPath), + ) if err != nil { return fmt.Errorf("failed to create library: %v", err) } - if err := s.CreateNVIDIAControlDeviceNodesAt(opts.DriverRootCtrPath); err != nil { + if err := devices.CreateNVIDIAControlDevices(); err != nil { return fmt.Errorf("failed to create control device nodes: %v", err) }