diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index 1860715b..4e0c47b9 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -28,6 +28,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" ) // NewDRMNodesDiscoverer returns a discoverrer for the DRM device nodes associated with the specified visible devices. @@ -48,15 +49,11 @@ func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices } // NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan. -func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) (Discover, error) { - locator, err := lookup.NewLibraryLocator(logger, driverRoot) - if err != nil { - return nil, fmt.Errorf("failed to construct library locator: %v", err) - } +func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) (Discover, error) { libraries := NewMounts( logger, - locator, - driverRoot, + driver.Libraries(), + driver.Root, []string{ "libnvidia-egl-gbm.so", }, @@ -66,10 +63,10 @@ func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvi logger, lookup.NewFileLocator( lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithRoot(driver.Root), lookup.WithSearchPaths("/etc", "/usr/share"), ), - driverRoot, + driver.Root, []string{ "glvnd/egl_vendor.d/10_nvidia.json", "vulkan/icd.d/nvidia_icd.json", @@ -81,7 +78,7 @@ func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvi }, ) - xorg := optionalXorgDiscoverer(logger, driverRoot, nvidiaCTKPath) + xorg := optionalXorgDiscoverer(logger, driver, nvidiaCTKPath) discover := Merge( libraries, @@ -249,8 +246,8 @@ var _ Discover = (*xorgHooks)(nil) // optionalXorgDiscoverer creates a discoverer for Xorg libraries. // If the creation of the discoverer fails, a None discoverer is returned. -func optionalXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) Discover { - xorg, err := newXorgDiscoverer(logger, driverRoot, nvidiaCTKPath) +func optionalXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) Discover { + xorg, err := newXorgDiscoverer(logger, driver, nvidiaCTKPath) if err != nil { logger.Warningf("Failed to create Xorg discoverer: %v; skipping xorg libraries", err) return None{} @@ -258,10 +255,9 @@ func optionalXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCT return xorg } -func newXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) (Discover, error) { +func newXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) (Discover, error) { libCudaPaths, err := cuda.New( - cuda.WithLogger(logger), - cuda.WithDriverRoot(driverRoot), + driver.Libraries(), ).Locate(".*.*") if err != nil { return nil, fmt.Errorf("failed to locate libcuda.so: %v", err) @@ -278,11 +274,11 @@ func newXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath logger, lookup.NewFileLocator( lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithRoot(driver.Root), lookup.WithSearchPaths(libRoot, "/usr/lib/x86_64-linux-gnu"), lookup.WithCount(1), ), - driverRoot, + driver.Root, []string{ "nvidia/xorg/nvidia_drv.so", fmt.Sprintf("nvidia/xorg/libglxserver_nvidia.so.%s", version), @@ -298,10 +294,10 @@ func newXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath logger, lookup.NewFileLocator( lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithRoot(driver.Root), lookup.WithSearchPaths("/usr/share"), ), - driverRoot, + driver.Root, []string{"X11/xorg.conf.d/10-nvidia.conf"}, ) diff --git a/internal/lookup/cuda/cuda.go b/internal/lookup/cuda/cuda.go index b95e81b4..68c4db35 100644 --- a/internal/lookup/cuda/cuda.go +++ b/internal/lookup/cuda/cuda.go @@ -17,55 +17,19 @@ package cuda import ( - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" ) type cudaLocator struct { lookup.Locator - logger logger.Interface - driverRoot string -} - -// Options is a function that configures a cudaLocator. -type Options func(*cudaLocator) - -// WithLogger is an option that configures the logger used by the locator. -func WithLogger(logger logger.Interface) Options { - return func(c *cudaLocator) { - c.logger = logger - } -} - -// WithDriverRoot is an option that configures the driver root used by the locator. -func WithDriverRoot(driverRoot string) Options { - return func(c *cudaLocator) { - c.driverRoot = driverRoot - } } // New creates a new CUDA library locator. -func New(opts ...Options) lookup.Locator { - c := &cudaLocator{} - for _, opt := range opts { - opt(c) +func New(libraries lookup.Locator) lookup.Locator { + c := cudaLocator{ + Locator: libraries, } - - if c.logger == nil { - c.logger = logger.New() - } - if c.driverRoot == "" { - c.driverRoot = "/" - } - - // TODO: Do we want to set the Count to 1 here? - l, _ := lookup.NewLibraryLocator( - c.logger, - c.driverRoot, - ) - - c.Locator = l - return c + return &c } // Locate returns the path to the libcuda.so.RMVERSION file. diff --git a/internal/lookup/cuda/cuda_test.go b/internal/lookup/cuda/cuda_test.go index 7151acdb..5db69c25 100644 --- a/internal/lookup/cuda/cuda_test.go +++ b/internal/lookup/cuda/cuda_test.go @@ -57,8 +57,10 @@ func TestLocate(t *testing.T) { require.NoError(t, err) l := New( - WithLogger(logger), - WithDriverRoot(driverRoot), + lookup.NewLibraryLocator( + lookup.WithLogger(logger), + lookup.WithRoot(driverRoot), + ), ) candidates, err := l.Locate(".*") diff --git a/internal/lookup/file.go b/internal/lookup/file.go index 20a1f164..8f330273 100644 --- a/internal/lookup/file.go +++ b/internal/lookup/file.go @@ -17,7 +17,6 @@ package lookup import ( - "errors" "fmt" "os" "path/filepath" @@ -25,55 +24,58 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) -// ErrNotFound indicates that a specified pattern or file could not be found. -var ErrNotFound = errors.New("not found") - // file can be used to locate file (or file-like elements) at a specified set of // prefixes. The validity of a file is determined by a filter function. type file struct { - logger logger.Interface - root string - prefixes []string - filter func(string) error - count int - isOptional bool + builder + prefixes []string } -// Option defines a function for passing options to the NewFileLocator() call -type Option func(*file) +// builder defines the builder for a file locator. +type builder struct { + logger logger.Interface + root string + searchPaths []string + filter func(string) error + count int + isOptional bool +} + +// Option defines a function for passing builder to the NewFileLocator() call +type Option func(*builder) // WithRoot sets the root for the file locator func WithRoot(root string) Option { - return func(f *file) { + return func(f *builder) { f.root = root } } // WithLogger sets the logger for the file locator func WithLogger(logger logger.Interface) Option { - return func(f *file) { + return func(f *builder) { f.logger = logger } } // WithSearchPaths sets the search paths for the file locator. func WithSearchPaths(paths ...string) Option { - return func(f *file) { - f.prefixes = paths + return func(f *builder) { + f.searchPaths = paths } } // WithFilter sets the filter for the file locator // The filter is called for each candidate file and candidates that return nil are considered. func WithFilter(assert func(string) error) Option { - return func(f *file) { + return func(f *builder) { f.filter = assert } } // WithCount sets the maximum number of candidates to discover func WithCount(count int) Option { - return func(f *file) { + return func(f *builder) { f.count = count } } @@ -81,32 +83,42 @@ func WithCount(count int) Option { // WithOptional sets the optional flag for the file locator // If the optional flag is set, the locator will not return an error if the file is not found. func WithOptional(optional bool) Option { - return func(f *file) { + return func(f *builder) { f.isOptional = optional } } -// NewFileLocator creates a Locator that can be used to find files with the specified options. +func newBuilder(opts ...Option) *builder { + o := &builder{} + for _, opt := range opts { + opt(o) + } + if o.logger == nil { + o.logger = logger.New() + } + if o.filter == nil { + o.filter = assertFile + } + return o +} + +func (o builder) build() *file { + f := file{ + builder: o, + // Since the `Locate` implementations rely on the root already being specified we update + // the prefixes to include the root. + prefixes: getSearchPrefixes(o.root, o.searchPaths...), + } + return &f +} + +// NewFileLocator creates a Locator that can be used to find files with the specified builder. func NewFileLocator(opts ...Option) Locator { return newFileLocator(opts...) } func newFileLocator(opts ...Option) *file { - f := &file{} - for _, opt := range opts { - opt(f) - } - if f.logger == nil { - f.logger = logger.New() - } - if f.filter == nil { - f.filter = assertFile - } - // Since the `Locate` implementations rely on the root already being specified we update - // the prefixes to include the root. - f.prefixes = getSearchPrefixes(f.root, f.prefixes...) - - return f + return newBuilder(opts...).build() } // getSearchPrefixes generates a list of unique paths to be searched by a file locator. diff --git a/internal/lookup/library.go b/internal/lookup/library.go index 73ad6984..7f5cf7c8 100644 --- a/internal/lookup/library.go +++ b/internal/lookup/library.go @@ -30,12 +30,20 @@ type ldcacheLocator struct { var _ Locator = (*ldcacheLocator)(nil) -// NewLibraryLocator creates a library locator using the specified logger. -func NewLibraryLocator(logger logger.Interface, root string) (Locator, error) { - // We construct a symlink locator for expected library locations. - symlinkLocator := NewSymlinkLocator( - WithLogger(logger), - WithRoot(root), +// NewLibraryLocator creates a library locator using the specified options. +func NewLibraryLocator(opts ...Option) Locator { + b := newBuilder(opts...) + + // If search paths are already specified, we return a locator for the specified search paths. + if len(b.searchPaths) > 0 { + return NewSymlinkLocator( + WithLogger(b.logger), + WithSearchPaths(b.searchPaths...), + WithRoot("/"), + ) + } + + opts = append(opts, WithSearchPaths([]string{ "/", "/usr/lib64", @@ -50,24 +58,28 @@ func NewLibraryLocator(logger logger.Interface, root string) (Locator, error) { "/lib/aarch64-linux-gnu/nvidia/current", }...), ) + // We construct a symlink locator for expected library locations. + symlinkLocator := NewSymlinkLocator(opts...) l := First( symlinkLocator, - newLdcacheLocator(logger, root), + newLdcacheLocator(opts...), ) - return l, nil + return l } -func newLdcacheLocator(logger logger.Interface, root string) Locator { - cache, err := ldcache.New(logger, root) +func newLdcacheLocator(opts ...Option) Locator { + b := newBuilder(opts...) + + cache, err := ldcache.New(b.logger, b.root) if err != nil { // If we failed to open the LDCache, we default to a symlink locator. - logger.Warningf("Failed to load ldcache: %v", err) + b.logger.Warningf("Failed to load ldcache: %v", err) return nil } return &ldcacheLocator{ - logger: logger, + logger: b.logger, cache: cache, } } @@ -82,7 +94,7 @@ func (l ldcacheLocator) Locate(libname string) ([]string, error) { } if len(paths64) == 0 { - return nil, fmt.Errorf("64-bit library %v: %w", libname, errNotFound) + return nil, fmt.Errorf("64-bit library %v: %w", libname, ErrNotFound) } return paths64, nil diff --git a/internal/lookup/library_test.go b/internal/lookup/library_test.go index 5682281d..60dba77c 100644 --- a/internal/lookup/library_test.go +++ b/internal/lookup/library_test.go @@ -43,7 +43,10 @@ func TestLDCacheLocator(t *testing.T) { require.NoError(t, os.Symlink(versionLib, sonameLink)) require.NoError(t, os.Symlink(sonameLink, soLink)) - lut := newLdcacheLocator(logger, testDir) + lut := newLdcacheLocator( + WithLogger(logger), + WithRoot(testDir), + ) testCases := []struct { description string @@ -63,7 +66,7 @@ func TestLDCacheLocator(t *testing.T) { { description: "lib only not in LDCache returns error", libname: "libnotcuda.so", - expectedError: errNotFound, + expectedError: ErrNotFound, }, } @@ -94,7 +97,6 @@ func TestLDCacheLocator(t *testing.T) { require.EqualValues(t, tc.expected, cleanedCandidates) }) } - } func TestLibraryLocator(t *testing.T) { @@ -125,14 +127,12 @@ func TestLibraryLocator(t *testing.T) { require.NoError(t, os.Symlink(libTarget1, source1)) require.NoError(t, os.Symlink(source1, source2)) - lut, err := NewLibraryLocator(logger, testDir) - require.NoError(t, err) - testCases := []struct { - description string - libname string - expected []string - expectedError error + description string + libname string + librarySearchPaths []string + expected []string + expectedError error }{ { description: "slash in path resoves symlink", @@ -152,7 +152,7 @@ func TestLibraryLocator(t *testing.T) { { description: "library not found returns error", libname: "/lib/symlink/libnotcuda.so", - expectedError: errNotFound, + expectedError: ErrNotFound, }, { description: "slash in path with pattern resoves symlink", @@ -172,10 +172,31 @@ func TestLibraryLocator(t *testing.T) { filepath.Join(testDir, "/lib/symlink/libtarget.so.1.2.3"), }, }, + { + description: "search paths are searched", + libname: "lib*.so.1.2.3", + librarySearchPaths: []string{filepath.Join(testDir, "/lib/symlink")}, + expected: []string{ + filepath.Join(testDir, "/lib/symlink/libcuda.so.1.2.3"), + filepath.Join(testDir, "/lib/symlink/libtarget.so.1.2.3"), + }, + }, + { + description: "search paths are absolute to root", + libname: "lib*.so.1.2.3", + librarySearchPaths: []string{"/lib/symlink"}, + expectedError: ErrNotFound, + }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { + lut := NewLibraryLocator( + WithLogger(logger), + WithRoot(testDir), + WithSearchPaths(tc.librarySearchPaths...), + ) + candidates, err := lut.Locate(tc.libname) require.ErrorIs(t, err, tc.expectedError) diff --git a/internal/lookup/locator.go b/internal/lookup/locator.go index af5633fb..73ade232 100644 --- a/internal/lookup/locator.go +++ b/internal/lookup/locator.go @@ -25,4 +25,5 @@ type Locator interface { Locate(string) ([]string, error) } -var errNotFound = errors.New("not found") +// ErrNotFound indicates that a specified pattern or file could not be found. +var ErrNotFound = errors.New("not found") diff --git a/internal/lookup/root/root.go b/internal/lookup/root/root.go new file mode 100644 index 00000000..f96e6b99 --- /dev/null +++ b/internal/lookup/root/root.go @@ -0,0 +1,65 @@ +/** +# Copyright 2023 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package root + +import ( + "path/filepath" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" +) + +// Driver represents a filesystem in which a set of drivers or devices is defined. +type Driver struct { + logger logger.Interface + // Root represents the root from the perspective of the driver libraries and binaries. + Root string + // librarySearchPaths specifies explicit search paths for discovering libraries. + librarySearchPaths []string +} + +// New creates a new Driver root at the specified path. +// TODO: Use functional options here. +func New(logger logger.Interface, path string, librarySearchPaths []string) *Driver { + return &Driver{ + logger: logger, + Root: path, + librarySearchPaths: normalizeSearchPaths(librarySearchPaths...), + } +} + +// Drivers returns a Locator for driver libraries. +func (r *Driver) Libraries() lookup.Locator { + return lookup.NewLibraryLocator( + lookup.WithLogger(r.logger), + lookup.WithRoot(r.Root), + lookup.WithSearchPaths(r.librarySearchPaths...), + ) +} + +// normalizeSearchPaths takes a list of paths and normalized these. +// Each of the elements in the list is expanded if it is a path list and the +// resultant list is returned. +// This allows, for example, for the contents of `PATH` or `LD_LIBRARY_PATH` to +// be passed as a search path directly. +func normalizeSearchPaths(paths ...string) []string { + var normalized []string + for _, path := range paths { + normalized = append(normalized, filepath.SplitList(path)...) + } + return normalized +} diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index b4cbb1f6..c0a98d1c 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -23,6 +23,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) @@ -34,20 +35,21 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image imag return nil, nil } - driverRoot := cfg.NVIDIAContainerCLIConfig.Root + // TODO: We should not just pass `nil` as the search path here. + driver := root.New(logger, cfg.NVIDIAContainerCLIConfig.Root, nil) nvidiaCTKPath := cfg.NVIDIACTKConfig.Path mounts, err := discover.NewGraphicsMountsDiscoverer( logger, - driverRoot, + driver, nvidiaCTKPath, ) if err != nil { return nil, fmt.Errorf("failed to create mounts discoverer: %v", err) } - // In standard usage, the devRoot is the same as the driverRoot. - devRoot := driverRoot + // In standard usage, the devRoot is the same as the driver.Root. + devRoot := driver.Root drmNodes, err := discover.NewDRMNodesDiscoverer( logger, image.DevicesFromEnvvars(visibleDevicesEnvvar), diff --git a/pkg/nvcdi/common-nvml.go b/pkg/nvcdi/common-nvml.go index a33cfc61..4c634a72 100644 --- a/pkg/nvcdi/common-nvml.go +++ b/pkg/nvcdi/common-nvml.go @@ -36,12 +36,12 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) { }, ) - graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath) + graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.nvidiaCTKPath) if err != nil { l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err) } - driverFiles, err := NewDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib) + driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCTKPath, l.nvmllib) if err != nil { return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err) } diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index cbc892af..28bd0704 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -27,12 +27,13 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "golang.org/x/sys/unix" ) // NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation. // The supplied NVML Library is used to query the expected driver version. -func NewDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { +func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { if r := nvmllib.Init(); r != nvml.SUCCESS { return nil, fmt.Errorf("failed to initialize NVML: %v", r) } @@ -47,26 +48,26 @@ func NewDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPa return nil, fmt.Errorf("failed to determine driver version: %v", r) } - return newDriverVersionDiscoverer(logger, driverRoot, nvidiaCTKPath, version) + return newDriverVersionDiscoverer(logger, driver, nvidiaCTKPath, version) } -func newDriverVersionDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, version string) (discover.Discover, error) { - libraries, err := NewDriverLibraryDiscoverer(logger, driverRoot, nvidiaCTKPath, version) +func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string) (discover.Discover, error) { + libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCTKPath, version) if err != nil { return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err) } - ipcs, err := discover.NewIPCDiscoverer(logger, driverRoot) + ipcs, err := discover.NewIPCDiscoverer(logger, driver.Root) if err != nil { return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err) } - firmwares, err := NewDriverFirmwareDiscoverer(logger, driverRoot, version) + firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version) if err != nil { return nil, fmt.Errorf("failed to create discoverer for GSP firmware: %v", err) } - binaries := NewDriverBinariesDiscoverer(logger, driverRoot) + binaries := NewDriverBinariesDiscoverer(logger, driver.Root) d := discover.Merge( libraries, @@ -79,8 +80,8 @@ func newDriverVersionDiscoverer(logger logger.Interface, driverRoot string, nvid } // NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version. -func NewDriverLibraryDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, version string) (discover.Discover, error) { - libraryPaths, err := getVersionLibs(logger, driverRoot, version) +func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string) (discover.Discover, error) { + libraryPaths, err := getVersionLibs(logger, driver, version) if err != nil { return nil, fmt.Errorf("failed to get libraries for driver version: %v", err) } @@ -89,9 +90,9 @@ func NewDriverLibraryDiscoverer(logger logger.Interface, driverRoot string, nvid logger, lookup.NewFileLocator( lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithRoot(driver.Root), ), - driverRoot, + driver.Root, libraryPaths, ) @@ -185,12 +186,11 @@ func NewDriverBinariesDiscoverer(logger logger.Interface, driverRoot string) dis // getVersionLibs checks the LDCache for libraries ending in the specified driver version. // Although the ldcache at the specified driverRoot is queried, the paths are returned relative to this driverRoot. // This allows the standard mount location logic to be used for resolving the mounts. -func getVersionLibs(logger logger.Interface, driverRoot string, version string) ([]string, error) { +func getVersionLibs(logger logger.Interface, driver *root.Driver, version string) ([]string, error) { logger.Infof("Using driver version %v", version) libCudaPaths, err := cuda.New( - cuda.WithLogger(logger), - cuda.WithDriverRoot(driverRoot), + driver.Libraries(), ).Locate("." + version) if err != nil { return nil, fmt.Errorf("failed to locate libcuda.so.%v: %v", version, err) @@ -208,13 +208,13 @@ func getVersionLibs(logger logger.Interface, driverRoot string, version string) return nil, fmt.Errorf("failed to locate libraries for driver version %v: %v", version, err) } - if driverRoot == "/" || driverRoot == "" { + if driver.Root == "/" || driver.Root == "" { return libs, nil } var relative []string for _, l := range libs { - relative = append(relative, strings.TrimPrefix(l, driverRoot)) + relative = append(relative, strings.TrimPrefix(l, driver.Root)) } return relative, nil diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 8a60ac6b..3839697c 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -23,6 +23,7 @@ import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvlib/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" @@ -54,6 +55,7 @@ type nvcdilib struct { vendor string class string + driver *root.Driver infolib info.Interface mergedDeviceOptions []transform.MergedDeviceOption @@ -87,6 +89,9 @@ func New(opts ...Option) (Interface, error) { l.infolib = info.New() } + // TODO: We need to improve the construction of this driver root. + l.driver = root.New(l.logger, l.driverRoot, l.librarySearchPaths) + var lib Interface switch l.resolveMode() { case ModeCSV: diff --git a/pkg/nvcdi/management.go b/pkg/nvcdi/management.go index 9352e110..460a4873 100644 --- a/pkg/nvcdi/management.go +++ b/pkg/nvcdi/management.go @@ -65,7 +65,7 @@ func (m *managementlib) GetCommonEdits() (*cdi.ContainerEdits, error) { return nil, fmt.Errorf("failed to get CUDA version: %v", err) } - driver, err := newDriverVersionDiscoverer(m.logger, m.driverRoot, m.nvidiaCTKPath, version) + driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCTKPath, version) if err != nil { return nil, fmt.Errorf("failed to create driver library discoverer: %v", err) } @@ -86,8 +86,7 @@ func (m *managementlib) getCudaVersion() (string, error) { } libCudaPaths, err := cuda.New( - cuda.WithLogger(m.logger), - cuda.WithDriverRoot(m.driverRoot), + m.driver.Libraries(), ).Locate(".*.*") if err != nil { return "", fmt.Errorf("failed to locate libcuda.so: %v", err)