diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index b52eb563..d70d4738 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -22,6 +22,7 @@ import ( "path/filepath" "strings" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/info/drm" "github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc" @@ -31,18 +32,23 @@ import ( ) // NewGraphicsDiscoverer returns the discoverer for graphics tools such as Vulkan. -func NewGraphicsDiscoverer(logger logger.Interface, devices image.VisibleDevices, driverRoot string, nvidiaCTKPath string) (Discover, error) { +func NewGraphicsDiscoverer(logger logger.Interface, cfg *config.Config, devices image.VisibleDevices) (Discover, error) { + driverRoot := cfg.NVIDIAContainerCLIConfig.Root + // In standard usage, the devRoot is the same as the driverRoot. + devRoot := driverRoot + nvidiaCTKPath := cfg.NVIDIACTKConfig.Path + mounts, err := NewGraphicsMountsDiscoverer(logger, driverRoot, nvidiaCTKPath) if err != nil { return nil, fmt.Errorf("failed to create mounts discoverer: %v", err) } - drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, driverRoot) + drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot) if err != nil { return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err) } - drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, driverRoot, nvidiaCTKPath) + drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, nvidiaCTKPath) discover := Merge( Merge(drmDeviceNodes, drmByPathSymlinks), @@ -99,16 +105,16 @@ type drmDevicesByPath struct { None logger logger.Interface nvidiaCTKPath string - driverRoot string + devRoot string devicesFrom Discover } // newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer -func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, driverRoot string, nvidiaCTKPath string) Discover { +func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, nvidiaCTKPath string) Discover { d := drmDevicesByPath{ logger: logger, nvidiaCTKPath: nvidiaCTKPath, - driverRoot: driverRoot, + devRoot: devRoot, devicesFrom: devices, } @@ -155,7 +161,7 @@ func (d drmDevicesByPath) getSpecificLinkArgs(devices []Device) ([]string, error linkLocator := lookup.NewFileLocator( lookup.WithLogger(d.logger), - lookup.WithRoot(d.driverRoot), + lookup.WithRoot(d.devRoot), ) candidates, err := linkLocator.Locate("/dev/dri/by-path/pci-*-*") if err != nil { @@ -181,21 +187,21 @@ func (d drmDevicesByPath) getSpecificLinkArgs(devices []Device) ([]string, error } // newDRMDeviceDiscoverer creates a discoverer for the DRM devices associated with the requested devices. -func newDRMDeviceDiscoverer(logger logger.Interface, devices image.VisibleDevices, driverRoot string) (Discover, error) { +func newDRMDeviceDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string) (Discover, error) { allDevices := NewDeviceDiscoverer( logger, lookup.NewCharDeviceLocator( lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithRoot(devRoot), ), - driverRoot, + devRoot, []string{ "/dev/dri/card*", "/dev/dri/renderD*", }, ) - filter, err := newDRMDeviceFilter(logger, devices, driverRoot) + filter, err := newDRMDeviceFilter(logger, devices, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct DRM device filter: %v", err) } @@ -211,8 +217,8 @@ func newDRMDeviceDiscoverer(logger logger.Interface, devices image.VisibleDevice } // newDRMDeviceFilter creates a filter that matches DRM devices nodes for the visible devices. -func newDRMDeviceFilter(logger logger.Interface, devices image.VisibleDevices, driverRoot string) (Filter, error) { - gpuInformationPaths, err := proc.GetInformationFilePaths(driverRoot) +func newDRMDeviceFilter(logger logger.Interface, devices image.VisibleDevices, devRoot string) (Filter, error) { + gpuInformationPaths, err := proc.GetInformationFilePaths(devRoot) if err != nil { return nil, fmt.Errorf("failed to read GPU information: %v", err) } diff --git a/internal/modifier/gds.go b/internal/modifier/gds.go index dd03a731..ac431405 100644 --- a/internal/modifier/gds.go +++ b/internal/modifier/gds.go @@ -42,7 +42,9 @@ func NewGDSModifier(logger logger.Interface, cfg *config.Config, image image.CUD return nil, nil } - d, err := discover.NewGDSDiscoverer(logger, cfg.NVIDIAContainerCLIConfig.Root) + driverRoot := cfg.NVIDIAContainerCLIConfig.Root + devRoot := cfg.NVIDIAContainerCLIConfig.Root + d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %v", err) } diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index 57776a72..2852ba2c 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -36,9 +36,8 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image imag d, err := discover.NewGraphicsDiscoverer( logger, + cfg, image.DevicesFromEnvvars(visibleDevicesEnvvar), - cfg.NVIDIAContainerCLIConfig.Root, - cfg.NVIDIACTKConfig.Path, ) if err != nil { return nil, fmt.Errorf("failed to construct discoverer: %v", err) diff --git a/internal/platform-support/tegra/tegra.go b/internal/platform-support/tegra/tegra.go index 563c6e28..6fb42674 100644 --- a/internal/platform-support/tegra/tegra.go +++ b/internal/platform-support/tegra/tegra.go @@ -29,6 +29,7 @@ type tegraOptions struct { logger logger.Interface csvFiles []string driverRoot string + devRoot string nvidiaCTKPath string librarySearchPaths []string ignorePatterns ignoreMountSpecPatterns @@ -50,6 +51,10 @@ func New(opts ...Option) (discover.Discover, error) { opt(o) } + if o.devRoot == "" { + o.devRoot = o.driverRoot + } + if o.symlinkLocator == nil { o.symlinkLocator = lookup.NewSymlinkLocator( lookup.WithLogger(o.logger), @@ -112,6 +117,14 @@ func WithDriverRoot(driverRoot string) Option { } } +// WithDevRoot sets the /dev root. +// If this is unset, the driver root is assumed. +func WithDevRoot(driverRoot string) Option { + return func(o *tegraOptions) { + o.driverRoot = driverRoot + } +} + // WithCSVFiles sets the CSV files for the discoverer. func WithCSVFiles(csvFiles []string) Option { return func(o *tegraOptions) { diff --git a/pkg/nvcdi/common-nvml.go b/pkg/nvcdi/common-nvml.go index 3c9fe6bf..f7a22c8a 100644 --- a/pkg/nvcdi/common-nvml.go +++ b/pkg/nvcdi/common-nvml.go @@ -20,22 +20,19 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" - - "github.com/NVIDIA/go-nvlib/pkg/nvml" ) // newCommonNVMLDiscoverer returns a discoverer for entities that are not associated with a specific CDI device. // This includes driver libraries and meta devices, for example. -func newCommonNVMLDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { +func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) { metaDevices := discover.NewDeviceDiscoverer( - logger, + l.logger, lookup.NewCharDeviceLocator( - lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithLogger(l.logger), + lookup.WithRoot(l.devRoot), ), - driverRoot, + l.devRoot, []string{ "/dev/nvidia-modeset", "/dev/nvidia-uvm-tools", @@ -44,12 +41,12 @@ func newCommonNVMLDiscoverer(logger logger.Interface, driverRoot string, nvidiaC }, ) - graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(logger, driverRoot, nvidiaCTKPath) + graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath) if err != nil { - logger.Warningf("failed to create discoverer for graphics mounts: %v", err) + l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err) } - driverFiles, err := NewDriverDiscoverer(logger, driverRoot, nvidiaCTKPath, nvmllib) + driverFiles, err := NewDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib) if err != nil { return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err) } diff --git a/pkg/nvcdi/device-wsl.go b/pkg/nvcdi/device-wsl.go index 418016b5..b1d93468 100644 --- a/pkg/nvcdi/device-wsl.go +++ b/pkg/nvcdi/device-wsl.go @@ -26,11 +26,11 @@ const ( ) // newDXGDeviceDiscoverer returns a Discoverer for DXG devices under WSL2. -func newDXGDeviceDiscoverer(logger logger.Interface, driverRoot string) discover.Discover { +func newDXGDeviceDiscoverer(logger logger.Interface, devRoot string) discover.Discover { deviceNodes := discover.NewCharDeviceDiscoverer( logger, []string{dxgDeviceNode}, - driverRoot, + devRoot, ) return deviceNodes diff --git a/pkg/nvcdi/full-gpu-nvml.go b/pkg/nvcdi/full-gpu-nvml.go index b808153b..f22e449c 100644 --- a/pkg/nvcdi/full-gpu-nvml.go +++ b/pkg/nvcdi/full-gpu-nvml.go @@ -54,7 +54,7 @@ func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, erro // GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'. func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) { - device, err := newFullGPUDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, d) + device, err := newFullGPUDiscoverer(l.logger, l.devRoot, l.nvidiaCTKPath, d) if err != nil { return nil, fmt.Errorf("failed to create device discoverer: %v", err) } @@ -70,7 +70,7 @@ func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error // byPathHookDiscoverer discovers the entities required for injecting by-path DRM device links type byPathHookDiscoverer struct { logger logger.Interface - driverRoot string + devRoot string nvidiaCTKPath string pciBusID string deviceNodes discover.Discover @@ -79,7 +79,7 @@ type byPathHookDiscoverer struct { var _ discover.Discover = (*byPathHookDiscoverer)(nil) // newFullGPUDiscoverer creates a discoverer for the full GPU defined by the specified device. -func newFullGPUDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, d device.Device) (discover.Discover, error) { +func newFullGPUDiscoverer(logger logger.Interface, devRoot string, nvidiaCTKPath string, d device.Device) (discover.Discover, error) { // TODO: The functionality to get device paths should be integrated into the go-nvlib/pkg/device.Device interface. // This will allow reuse here and in other code where the paths are queried such as the NVIDIA device plugin. minor, ret := d.GetMinorNumber() @@ -104,12 +104,12 @@ func newFullGPUDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKP deviceNodes := discover.NewCharDeviceDiscoverer( logger, deviceNodePaths, - driverRoot, + devRoot, ) byPathHooks := &byPathHookDiscoverer{ logger: logger, - driverRoot: driverRoot, + devRoot: devRoot, nvidiaCTKPath: nvidiaCTKPath, pciBusID: pciBusID, deviceNodes: deviceNodes, @@ -117,7 +117,7 @@ func newFullGPUDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKP deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer( logger, - driverRoot, + devRoot, nvidiaCTKPath, deviceNodes, ) @@ -189,7 +189,7 @@ func (d *byPathHookDiscoverer) deviceNodeLinks() ([]string, error) { var links []string for _, c := range candidates { - linkPath := filepath.Join(d.driverRoot, c) + linkPath := filepath.Join(d.devRoot, c) device, err := os.Readlink(linkPath) if err != nil { d.logger.Warningf("Failed to evaluate symlink %v; ignoring", linkPath) diff --git a/pkg/nvcdi/gds.go b/pkg/nvcdi/gds.go index 3cb5bcf2..c50ab5e5 100644 --- a/pkg/nvcdi/gds.go +++ b/pkg/nvcdi/gds.go @@ -33,7 +33,7 @@ var _ Interface = (*gdslib)(nil) // GetAllDeviceSpecs returns the device specs for all available devices. func (l *gdslib) GetAllDeviceSpecs() ([]specs.Device, error) { - discoverer, err := discover.NewGDSDiscoverer(l.logger, l.driverRoot) + discoverer, err := discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot) if err != nil { return nil, fmt.Errorf("failed to create GPUDirect Storage discoverer: %v", err) } diff --git a/pkg/nvcdi/lib-csv.go b/pkg/nvcdi/lib-csv.go index e0b7fbc6..14b6b5a0 100644 --- a/pkg/nvcdi/lib-csv.go +++ b/pkg/nvcdi/lib-csv.go @@ -42,6 +42,7 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) { d, err := tegra.New( tegra.WithLogger(l.logger), tegra.WithDriverRoot(l.driverRoot), + tegra.WithDevRoot(l.devRoot), tegra.WithNVIDIACTKPath(l.nvidiaCTKPath), tegra.WithCSVFiles(l.csvFiles), tegra.WithLibrarySearchPaths(l.librarySearchPaths...), diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go index c49833cf..82cc9b6e 100644 --- a/pkg/nvcdi/lib-nvml.go +++ b/pkg/nvcdi/lib-nvml.go @@ -66,7 +66,7 @@ func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) { // GetCommonEdits generates a CDI specification that can be used for ANY devices func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) { - common, err := newCommonNVMLDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib) + common, err := l.newCommonNVMLDiscoverer() if err != nil { return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err) } diff --git a/pkg/nvcdi/lib-wsl.go b/pkg/nvcdi/lib-wsl.go index 23ee79bd..b2ecf173 100644 --- a/pkg/nvcdi/lib-wsl.go +++ b/pkg/nvcdi/lib-wsl.go @@ -37,7 +37,7 @@ func (l *wsllib) GetSpec() (spec.Interface, error) { // GetAllDeviceSpecs returns the device specs for all available devices. func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) { - device := newDXGDeviceDiscoverer(l.logger, l.driverRoot) + device := newDXGDeviceDiscoverer(l.logger, l.devRoot) deviceEdits, err := edits.FromDiscoverer(device) if err != nil { return nil, fmt.Errorf("failed to create container edits for DXG device: %v", err) diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index f5b53245..8a60ac6b 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -44,6 +44,7 @@ type nvcdilib struct { devicelib device.Interface deviceNamer DeviceNamer driverRoot string + devRoot string nvidiaCTKPath string librarySearchPaths []string @@ -76,6 +77,9 @@ func New(opts ...Option) (Interface, error) { if l.driverRoot == "" { l.driverRoot = "/" } + if l.devRoot == "" { + l.devRoot = l.driverRoot + } if l.nvidiaCTKPath == "" { l.nvidiaCTKPath = "/usr/bin/nvidia-ctk" } diff --git a/pkg/nvcdi/management.go b/pkg/nvcdi/management.go index 626a6c2b..8050a5b6 100644 --- a/pkg/nvcdi/management.go +++ b/pkg/nvcdi/management.go @@ -117,12 +117,12 @@ func (m *managementlib) newManagementDeviceDiscoverer() (discover.Discover, erro "/dev/nvidia-uvm", "/dev/nvidiactl", }, - m.driverRoot, + m.devRoot, ) deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer( m.logger, - m.driverRoot, + m.devRoot, m.nvidiaCTKPath, deviceNodes, ) diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 03cfc362..86bb877d 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -47,6 +47,13 @@ func WithDriverRoot(root string) Option { } } +// WithDevRoot sets the root where /dev is located. +func WithDevRoot(root string) Option { + return func(l *nvcdilib) { + l.devRoot = root + } +} + // WithLogger sets the logger for the library func WithLogger(logger logger.Interface) Option { return func(l *nvcdilib) { diff --git a/pkg/nvcdi/workarounds-device-folder-permissions.go b/pkg/nvcdi/workarounds-device-folder-permissions.go index 812bbbaf..e183ed6e 100644 --- a/pkg/nvcdi/workarounds-device-folder-permissions.go +++ b/pkg/nvcdi/workarounds-device-folder-permissions.go @@ -26,7 +26,7 @@ import ( type deviceFolderPermissions struct { logger logger.Interface - driverRoot string + devRoot string nvidiaCTKPath string devices discover.Discover } @@ -39,10 +39,10 @@ var _ discover.Discover = (*deviceFolderPermissions)(nil) // The nested devices that are applicable to the NVIDIA GPU devices are: // - DRM devices at /dev/dri/* // - NVIDIA Caps devices at /dev/nvidia-caps/* -func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, devices discover.Discover) discover.Discover { +func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, nvidiaCTKPath string, devices discover.Discover) discover.Discover { d := &deviceFolderPermissions{ logger: logger, - driverRoot: driverRoot, + devRoot: devRoot, nvidiaCTKPath: nvidiaCTKPath, devices: devices, }