From 64f554ef419485b4bcb6de1471a7bf98f9c087ca Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 21 Nov 2023 16:55:26 +0100 Subject: [PATCH 1/5] Add builder for file locator Signed-off-by: Evan Lezar --- internal/lookup/file.go | 78 +++++++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/internal/lookup/file.go b/internal/lookup/file.go index 20a1f164..636ac315 100644 --- a/internal/lookup/file.go +++ b/internal/lookup/file.go @@ -31,49 +31,55 @@ var ErrNotFound = errors.New("not found") // file can be used to locate file (or file-like elements) at a specified set of // prefixes. The validity of a file is determined by a filter function. type file struct { - logger logger.Interface - root string - prefixes []string - filter func(string) error - count int - isOptional bool + builder + prefixes []string } -// Option defines a function for passing options to the NewFileLocator() call -type Option func(*file) +// builder defines the builder for a file locator. +type builder struct { + logger logger.Interface + root string + searchPaths []string + filter func(string) error + count int + isOptional bool +} + +// Option defines a function for passing builder options to the NewFileLocator() call +type Option func(*builder) // WithRoot sets the root for the file locator func WithRoot(root string) Option { - return func(f *file) { + return func(f *builder) { f.root = root } } // WithLogger sets the logger for the file locator func WithLogger(logger logger.Interface) Option { - return func(f *file) { + return func(f *builder) { f.logger = logger } } // WithSearchPaths sets the search paths for the file locator. func WithSearchPaths(paths ...string) Option { - return func(f *file) { - f.prefixes = paths + return func(f *builder) { + f.searchPaths = paths } } // WithFilter sets the filter for the file locator // The filter is called for each candidate file and candidates that return nil are considered.
 func WithFilter(assert func(string) error) Option { - return func(f *file) { + return func(f *builder) { f.filter = assert } } // WithCount sets the maximum number of candidates to discover func WithCount(count int) Option { - return func(f *file) { + return func(f *builder) { f.count = count } } @@ -81,32 +87,42 @@ func WithCount(count int) Option { // WithOptional sets the optional flag for the file locator // If the optional flag is set, the locator will not return an error if the file is not found. func WithOptional(optional bool) Option { - return func(f *file) { + return func(f *builder) { f.isOptional = optional } } -// NewFileLocator creates a Locator that can be used to find files with the specified options. +func newBuilder(opts ...Option) *builder { + o := &builder{} + for _, opt := range opts { + opt(o) + } + if o.logger == nil { + o.logger = logger.New() + } + if o.filter == nil { + o.filter = assertFile + } + return o +} + +func (o builder) build() *file { + f := file{ + builder: o, + // Since the `Locate` implementations rely on the root already being specified we update + // the prefixes to include the root. + prefixes: getSearchPrefixes(o.root, o.searchPaths...), + } + return &f +} + +// NewFileLocator creates a Locator that can be used to find files with the specified options. func NewFileLocator(opts ...Option) Locator { return newFileLocator(opts...) } func newFileLocator(opts ...Option) *file { - f := &file{} - for _, opt := range opts { - opt(f) - } - if f.logger == nil { - f.logger = logger.New() - } - if f.filter == nil { - f.filter = assertFile - } - // Since the `Locate` implementations rely on the root already being specified we update - // the prefixes to include the root. - f.prefixes = getSearchPrefixes(f.root, f.prefixes...) - - return f + return newBuilder(opts...).build() } // getSearchPrefixes generates a list of unique paths to be searched by a file locator.
From 550588665505b7ec980c145b217d87585ffb46ae Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 15 Aug 2023 14:29:22 +0200 Subject: [PATCH 2/5] Use options for NewLibraryLocator Signed-off-by: Evan Lezar --- internal/discover/graphics.go | 8 ++++---- internal/lookup/cuda/cuda.go | 9 +++------ internal/lookup/library.go | 25 +++++++++++++------------ internal/lookup/library_test.go | 12 ++++++++---- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index 73145c89..aeaabe69 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -49,10 +49,10 @@ func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices // NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan. func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) (Discover, error) { - locator, err := lookup.NewLibraryLocator(logger, driverRoot) - if err != nil { - return nil, fmt.Errorf("failed to construct library locator: %v", err) - } + locator := lookup.NewLibraryLocator( + lookup.WithLogger(logger), + lookup.WithRoot(driverRoot), + ) libraries := NewMounts( logger, locator, diff --git a/internal/lookup/cuda/cuda.go b/internal/lookup/cuda/cuda.go index b95e81b4..0a8620bc 100644 --- a/internal/lookup/cuda/cuda.go +++ b/internal/lookup/cuda/cuda.go @@ -58,13 +58,10 @@ func New(opts ...Options) lookup.Locator { c.driverRoot = "/" } - // TODO: Do we want to set the Count to 1 here? 
- l, _ := lookup.NewLibraryLocator( - c.logger, - c.driverRoot, + c.Locator = lookup.NewLibraryLocator( + lookup.WithLogger(c.logger), + lookup.WithRoot(c.driverRoot), ) - - c.Locator = l return c } diff --git a/internal/lookup/library.go b/internal/lookup/library.go index 73ad6984..4a6bb722 100644 --- a/internal/lookup/library.go +++ b/internal/lookup/library.go @@ -30,12 +30,9 @@ type ldcacheLocator struct { var _ Locator = (*ldcacheLocator)(nil) -// NewLibraryLocator creates a library locator using the specified logger. -func NewLibraryLocator(logger logger.Interface, root string) (Locator, error) { - // We construct a symlink locator for expected library locations. - symlinkLocator := NewSymlinkLocator( - WithLogger(logger), - WithRoot(root), +// NewLibraryLocator creates a library locator using the specified options. +func NewLibraryLocator(opts ...Option) Locator { + opts = append(opts, WithSearchPaths([]string{ "/", "/usr/lib64", @@ -50,24 +47,28 @@ func NewLibraryLocator(logger logger.Interface, root string) (Locator, error) { "/lib/aarch64-linux-gnu/nvidia/current", }...), ) + // We construct a symlink locator for expected library locations. + symlinkLocator := NewSymlinkLocator(opts...) l := First( symlinkLocator, - newLdcacheLocator(logger, root), + newLdcacheLocator(opts...), ) - return l, nil + return l } -func newLdcacheLocator(logger logger.Interface, root string) Locator { - cache, err := ldcache.New(logger, root) +func newLdcacheLocator(opts ...Option) Locator { + b := newBuilder(opts...) + + cache, err := ldcache.New(b.logger, b.root) if err != nil { // If we failed to open the LDCache, we default to a symlink locator. 
- logger.Warningf("Failed to load ldcache: %v", err) + b.logger.Warningf("Failed to load ldcache: %v", err) return nil } return &ldcacheLocator{ - logger: logger, + logger: b.logger, cache: cache, } } diff --git a/internal/lookup/library_test.go b/internal/lookup/library_test.go index 5682281d..21485dca 100644 --- a/internal/lookup/library_test.go +++ b/internal/lookup/library_test.go @@ -43,7 +43,10 @@ func TestLDCacheLocator(t *testing.T) { require.NoError(t, os.Symlink(versionLib, sonameLink)) require.NoError(t, os.Symlink(sonameLink, soLink)) - lut := newLdcacheLocator(logger, testDir) + lut := newLdcacheLocator( + WithLogger(logger), + WithRoot(testDir), + ) testCases := []struct { description string @@ -94,7 +97,6 @@ func TestLDCacheLocator(t *testing.T) { require.EqualValues(t, tc.expected, cleanedCandidates) }) } - } func TestLibraryLocator(t *testing.T) { @@ -125,8 +127,10 @@ func TestLibraryLocator(t *testing.T) { require.NoError(t, os.Symlink(libTarget1, source1)) require.NoError(t, os.Symlink(source1, source2)) - lut, err := NewLibraryLocator(logger, testDir) - require.NoError(t, err) + lut := NewLibraryLocator( + WithLogger(logger), + WithRoot(testDir), + ) testCases := []struct { description string From e5391760e6e010558ec7a5f598eb5616164b12c8 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 21 Nov 2023 16:59:57 +0100 Subject: [PATCH 3/5] Remove duplicate not found error Signed-off-by: Evan Lezar --- internal/lookup/file.go | 4 ---- internal/lookup/library.go | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/internal/lookup/file.go b/internal/lookup/file.go index 636ac315..8f330273 100644 --- a/internal/lookup/file.go +++ b/internal/lookup/file.go @@ -17,7 +17,6 @@ package lookup import ( - "errors" "fmt" "os" "path/filepath" @@ -25,9 +24,6 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) -// ErrNotFound indicates that a specified pattern or file could not be found. 
-var ErrNotFound = errors.New("not found") - // file can be used to locate file (or file-like elements) at a specified set of // prefixes. The validity of a file is determined by a filter function. type file struct { diff --git a/internal/lookup/library.go b/internal/lookup/library.go index 4a6bb722..bdb2248e 100644 --- a/internal/lookup/library.go +++ b/internal/lookup/library.go @@ -83,7 +83,7 @@ func (l ldcacheLocator) Locate(libname string) ([]string, error) { } if len(paths64) == 0 { - return nil, fmt.Errorf("64-bit library %v: %w", libname, errNotFound) + return nil, fmt.Errorf("64-bit library %v: %w", libname, ErrNotFound) } return paths64, nil From f20ab793a20abdc8acf9a18dcd361c3c991e23fd Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 21 Nov 2023 17:01:10 +0100 Subject: [PATCH 4/5] Add support for specifying search paths Signed-off-by: Evan Lezar --- internal/lookup/library.go | 11 ++++++++++ internal/lookup/library_test.go | 39 +++++++++++++++++++++++---------- internal/lookup/locator.go | 3 ++- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/internal/lookup/library.go b/internal/lookup/library.go index bdb2248e..7f5cf7c8 100644 --- a/internal/lookup/library.go +++ b/internal/lookup/library.go @@ -32,6 +32,17 @@ var _ Locator = (*ldcacheLocator)(nil) // NewLibraryLocator creates a library locator using the specified options. func NewLibraryLocator(opts ...Option) Locator { + b := newBuilder(opts...) + + // If search paths are already specified, we return a locator for the specified search paths. 
+ if len(b.searchPaths) > 0 { + return NewSymlinkLocator( + WithLogger(b.logger), + WithSearchPaths(b.searchPaths...), + WithRoot("/"), + ) + } + opts = append(opts, WithSearchPaths([]string{ "/", diff --git a/internal/lookup/library_test.go b/internal/lookup/library_test.go index 21485dca..60dba77c 100644 --- a/internal/lookup/library_test.go +++ b/internal/lookup/library_test.go @@ -66,7 +66,7 @@ func TestLDCacheLocator(t *testing.T) { { description: "lib only not in LDCache returns error", libname: "libnotcuda.so", - expectedError: errNotFound, + expectedError: ErrNotFound, }, } @@ -127,16 +127,12 @@ func TestLibraryLocator(t *testing.T) { require.NoError(t, os.Symlink(libTarget1, source1)) require.NoError(t, os.Symlink(source1, source2)) - lut := NewLibraryLocator( - WithLogger(logger), - WithRoot(testDir), - ) - testCases := []struct { - description string - libname string - expected []string - expectedError error + description string + libname string + librarySearchPaths []string + expected []string + expectedError error }{ { description: "slash in path resoves symlink", @@ -156,7 +152,7 @@ func TestLibraryLocator(t *testing.T) { { description: "library not found returns error", libname: "/lib/symlink/libnotcuda.so", - expectedError: errNotFound, + expectedError: ErrNotFound, }, { description: "slash in path with pattern resoves symlink", @@ -176,10 +172,31 @@ func TestLibraryLocator(t *testing.T) { filepath.Join(testDir, "/lib/symlink/libtarget.so.1.2.3"), }, }, + { + description: "search paths are searched", + libname: "lib*.so.1.2.3", + librarySearchPaths: []string{filepath.Join(testDir, "/lib/symlink")}, + expected: []string{ + filepath.Join(testDir, "/lib/symlink/libcuda.so.1.2.3"), + filepath.Join(testDir, "/lib/symlink/libtarget.so.1.2.3"), + }, + }, + { + description: "search paths are absolute to root", + libname: "lib*.so.1.2.3", + librarySearchPaths: []string{"/lib/symlink"}, + expectedError: ErrNotFound, + }, } for _, tc := range testCases { 
t.Run(tc.description, func(t *testing.T) { + lut := NewLibraryLocator( + WithLogger(logger), + WithRoot(testDir), + WithSearchPaths(tc.librarySearchPaths...), + ) + candidates, err := lut.Locate(tc.libname) require.ErrorIs(t, err, tc.expectedError) diff --git a/internal/lookup/locator.go b/internal/lookup/locator.go index af5633fb..73ade232 100644 --- a/internal/lookup/locator.go +++ b/internal/lookup/locator.go @@ -25,4 +25,5 @@ type Locator interface { Locate(string) ([]string, error) } -var errNotFound = errors.New("not found") +// ErrNotFound indicates that a specified pattern or file could not be found. +var ErrNotFound = errors.New("not found") From bbd92222063d2211e33c5393b97415f60d23c021 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 21 Nov 2023 16:08:16 +0100 Subject: [PATCH 5/5] Add driver root abstraction This change adds a driver root abstraction that defines how libraries are located relative to the root. This allows for this driver root to be constructed once and passed to discovery code. 
Signed-off-by: Evan Lezar --- internal/discover/graphics.go | 34 +++++++--------- internal/lookup/cuda/cuda.go | 41 ++----------------- internal/lookup/cuda/cuda_test.go | 6 ++- internal/lookup/root/root.go | 65 +++++++++++++++++++++++++++++++ internal/modifier/graphics.go | 10 +++-- pkg/nvcdi/common-nvml.go | 4 +- pkg/nvcdi/driver-nvml.go | 32 +++++++-------- pkg/nvcdi/lib.go | 5 +++ pkg/nvcdi/management.go | 5 +-- 9 files changed, 119 insertions(+), 83 deletions(-) create mode 100644 internal/lookup/root/root.go diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index aeaabe69..76badbf0 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -28,6 +28,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" ) // NewDRMNodesDiscoverer returns a discoverrer for the DRM device nodes associated with the specified visible devices. @@ -48,15 +49,11 @@ func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices } // NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan. 
-func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) (Discover, error) { - locator := lookup.NewLibraryLocator( - lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), - ) +func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) (Discover, error) { libraries := NewMounts( logger, - locator, - driverRoot, + driver.Libraries(), + driver.Root, []string{ "libnvidia-egl-gbm.so", }, @@ -66,10 +63,10 @@ func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvi logger, lookup.NewFileLocator( lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithRoot(driver.Root), lookup.WithSearchPaths("/etc", "/usr/share"), ), - driverRoot, + driver.Root, []string{ "glvnd/egl_vendor.d/10_nvidia.json", "vulkan/icd.d/nvidia_icd.json", @@ -79,7 +76,7 @@ func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvi }, ) - xorg := optionalXorgDiscoverer(logger, driverRoot, nvidiaCTKPath) + xorg := optionalXorgDiscoverer(logger, driver, nvidiaCTKPath) discover := Merge( libraries, @@ -247,8 +244,8 @@ var _ Discover = (*xorgHooks)(nil) // optionalXorgDiscoverer creates a discoverer for Xorg libraries. // If the creation of the discoverer fails, a None discoverer is returned. 
-func optionalXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) Discover { - xorg, err := newXorgDiscoverer(logger, driverRoot, nvidiaCTKPath) +func optionalXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) Discover { + xorg, err := newXorgDiscoverer(logger, driver, nvidiaCTKPath) if err != nil { logger.Warningf("Failed to create Xorg discoverer: %v; skipping xorg libraries", err) return None{} @@ -256,10 +253,9 @@ func optionalXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCT return xorg } -func newXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) (Discover, error) { +func newXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) (Discover, error) { libCudaPaths, err := cuda.New( - cuda.WithLogger(logger), - cuda.WithDriverRoot(driverRoot), + driver.Libraries(), ).Locate(".*.*") if err != nil { return nil, fmt.Errorf("failed to locate libcuda.so: %v", err) @@ -276,11 +272,11 @@ func newXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath logger, lookup.NewFileLocator( lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithRoot(driver.Root), lookup.WithSearchPaths(libRoot, "/usr/lib/x86_64-linux-gnu"), lookup.WithCount(1), ), - driverRoot, + driver.Root, []string{ "nvidia/xorg/nvidia_drv.so", fmt.Sprintf("nvidia/xorg/libglxserver_nvidia.so.%s", version), @@ -296,10 +292,10 @@ func newXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath logger, lookup.NewFileLocator( lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithRoot(driver.Root), lookup.WithSearchPaths("/usr/share"), ), - driverRoot, + driver.Root, []string{"X11/xorg.conf.d/10-nvidia.conf"}, ) diff --git a/internal/lookup/cuda/cuda.go b/internal/lookup/cuda/cuda.go index 0a8620bc..68c4db35 100644 --- a/internal/lookup/cuda/cuda.go +++ b/internal/lookup/cuda/cuda.go @@ -17,52 +17,19 @@ package cuda 
import ( - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" ) type cudaLocator struct { lookup.Locator - logger logger.Interface - driverRoot string -} - -// Options is a function that configures a cudaLocator. -type Options func(*cudaLocator) - -// WithLogger is an option that configures the logger used by the locator. -func WithLogger(logger logger.Interface) Options { - return func(c *cudaLocator) { - c.logger = logger - } -} - -// WithDriverRoot is an option that configures the driver root used by the locator. -func WithDriverRoot(driverRoot string) Options { - return func(c *cudaLocator) { - c.driverRoot = driverRoot - } } // New creates a new CUDA library locator. -func New(opts ...Options) lookup.Locator { - c := &cudaLocator{} - for _, opt := range opts { - opt(c) +func New(libraries lookup.Locator) lookup.Locator { + c := cudaLocator{ + Locator: libraries, } - - if c.logger == nil { - c.logger = logger.New() - } - if c.driverRoot == "" { - c.driverRoot = "/" - } - - c.Locator = lookup.NewLibraryLocator( - lookup.WithLogger(c.logger), - lookup.WithRoot(c.driverRoot), - ) - return c + return &c } // Locate returns the path to the libcuda.so.RMVERSION file. 
diff --git a/internal/lookup/cuda/cuda_test.go b/internal/lookup/cuda/cuda_test.go index 7151acdb..5db69c25 100644 --- a/internal/lookup/cuda/cuda_test.go +++ b/internal/lookup/cuda/cuda_test.go @@ -57,8 +57,10 @@ func TestLocate(t *testing.T) { require.NoError(t, err) l := New( - WithLogger(logger), - WithDriverRoot(driverRoot), + lookup.NewLibraryLocator( + lookup.WithLogger(logger), + lookup.WithRoot(driverRoot), + ), ) candidates, err := l.Locate(".*") diff --git a/internal/lookup/root/root.go b/internal/lookup/root/root.go new file mode 100644 index 00000000..f96e6b99 --- /dev/null +++ b/internal/lookup/root/root.go @@ -0,0 +1,65 @@ +/** +# Copyright 2023 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package root + +import ( + "path/filepath" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" +) + +// Driver represents a filesystem in which a set of drivers or devices is defined. +type Driver struct { + logger logger.Interface + // Root represents the root from the perspective of the driver libraries and binaries. + Root string + // librarySearchPaths specifies explicit search paths for discovering libraries. + librarySearchPaths []string +} + +// New creates a new Driver root at the specified path. +// TODO: Use functional options here. 
+func New(logger logger.Interface, path string, librarySearchPaths []string) *Driver { + return &Driver{ + logger: logger, + Root: path, + librarySearchPaths: normalizeSearchPaths(librarySearchPaths...), + } +} + +// Libraries returns a Locator for driver libraries. +func (r *Driver) Libraries() lookup.Locator { + return lookup.NewLibraryLocator( + lookup.WithLogger(r.logger), + lookup.WithRoot(r.Root), + lookup.WithSearchPaths(r.librarySearchPaths...), + ) +} + +// normalizeSearchPaths takes a list of paths and normalizes them. +// Each of the elements in the list is expanded if it is a path list and the +// resultant list is returned. +// This allows, for example, for the contents of `PATH` or `LD_LIBRARY_PATH` to +// be passed as a search path directly. +func normalizeSearchPaths(paths ...string) []string { + var normalized []string + for _, path := range paths { + normalized = append(normalized, filepath.SplitList(path)...) + } + return normalized +} diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index b4cbb1f6..c0a98d1c 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -23,6 +23,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) @@ -34,20 +35,21 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image imag return nil, nil } - driverRoot := cfg.NVIDIAContainerCLIConfig.Root + // TODO: We should not just pass `nil` as the search path here.
+ driver := root.New(logger, cfg.NVIDIAContainerCLIConfig.Root, nil) nvidiaCTKPath := cfg.NVIDIACTKConfig.Path mounts, err := discover.NewGraphicsMountsDiscoverer( logger, - driverRoot, + driver, nvidiaCTKPath, ) if err != nil { return nil, fmt.Errorf("failed to create mounts discoverer: %v", err) } - // In standard usage, the devRoot is the same as the driverRoot. - devRoot := driverRoot + // In standard usage, the devRoot is the same as the driver.Root. + devRoot := driver.Root drmNodes, err := discover.NewDRMNodesDiscoverer( logger, image.DevicesFromEnvvars(visibleDevicesEnvvar), diff --git a/pkg/nvcdi/common-nvml.go b/pkg/nvcdi/common-nvml.go index a33cfc61..4c634a72 100644 --- a/pkg/nvcdi/common-nvml.go +++ b/pkg/nvcdi/common-nvml.go @@ -36,12 +36,12 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) { }, ) - graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath) + graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.nvidiaCTKPath) if err != nil { l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err) } - driverFiles, err := NewDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib) + driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCTKPath, l.nvmllib) if err != nil { return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err) } diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index cbc892af..28bd0704 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -27,12 +27,13 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "golang.org/x/sys/unix" ) // NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver 
installation. // The supplied NVML Library is used to query the expected driver version. -func NewDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { +func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { if r := nvmllib.Init(); r != nvml.SUCCESS { return nil, fmt.Errorf("failed to initialize NVML: %v", r) } @@ -47,26 +48,26 @@ func NewDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPa return nil, fmt.Errorf("failed to determine driver version: %v", r) } - return newDriverVersionDiscoverer(logger, driverRoot, nvidiaCTKPath, version) + return newDriverVersionDiscoverer(logger, driver, nvidiaCTKPath, version) } -func newDriverVersionDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, version string) (discover.Discover, error) { - libraries, err := NewDriverLibraryDiscoverer(logger, driverRoot, nvidiaCTKPath, version) +func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string) (discover.Discover, error) { + libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCTKPath, version) if err != nil { return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err) } - ipcs, err := discover.NewIPCDiscoverer(logger, driverRoot) + ipcs, err := discover.NewIPCDiscoverer(logger, driver.Root) if err != nil { return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err) } - firmwares, err := NewDriverFirmwareDiscoverer(logger, driverRoot, version) + firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version) if err != nil { return nil, fmt.Errorf("failed to create discoverer for GSP firmware: %v", err) } - binaries := NewDriverBinariesDiscoverer(logger, driverRoot) + binaries := NewDriverBinariesDiscoverer(logger, driver.Root) d := 
discover.Merge( libraries, @@ -79,8 +80,8 @@ func newDriverVersionDiscoverer(logger logger.Interface, driverRoot string, nvid } // NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version. -func NewDriverLibraryDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, version string) (discover.Discover, error) { - libraryPaths, err := getVersionLibs(logger, driverRoot, version) +func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string) (discover.Discover, error) { + libraryPaths, err := getVersionLibs(logger, driver, version) if err != nil { return nil, fmt.Errorf("failed to get libraries for driver version: %v", err) } @@ -89,9 +90,9 @@ func NewDriverLibraryDiscoverer(logger logger.Interface, driverRoot string, nvid logger, lookup.NewFileLocator( lookup.WithLogger(logger), - lookup.WithRoot(driverRoot), + lookup.WithRoot(driver.Root), ), - driverRoot, + driver.Root, libraryPaths, ) @@ -185,12 +186,11 @@ func NewDriverBinariesDiscoverer(logger logger.Interface, driverRoot string) dis // getVersionLibs checks the LDCache for libraries ending in the specified driver version. // Although the ldcache at the specified driverRoot is queried, the paths are returned relative to this driverRoot. // This allows the standard mount location logic to be used for resolving the mounts. -func getVersionLibs(logger logger.Interface, driverRoot string, version string) ([]string, error) { +func getVersionLibs(logger logger.Interface, driver *root.Driver, version string) ([]string, error) { logger.Infof("Using driver version %v", version) libCudaPaths, err := cuda.New( - cuda.WithLogger(logger), - cuda.WithDriverRoot(driverRoot), + driver.Libraries(), ).Locate("." 
+ version) if err != nil { return nil, fmt.Errorf("failed to locate libcuda.so.%v: %v", version, err) @@ -208,13 +208,13 @@ func getVersionLibs(logger logger.Interface, driverRoot string, version string) return nil, fmt.Errorf("failed to locate libraries for driver version %v: %v", version, err) } - if driverRoot == "/" || driverRoot == "" { + if driver.Root == "/" || driver.Root == "" { return libs, nil } var relative []string for _, l := range libs { - relative = append(relative, strings.TrimPrefix(l, driverRoot)) + relative = append(relative, strings.TrimPrefix(l, driver.Root)) } return relative, nil diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 8a60ac6b..3839697c 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -23,6 +23,7 @@ import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvlib/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" @@ -54,6 +55,7 @@ type nvcdilib struct { vendor string class string + driver *root.Driver infolib info.Interface mergedDeviceOptions []transform.MergedDeviceOption @@ -87,6 +89,9 @@ func New(opts ...Option) (Interface, error) { l.infolib = info.New() } + // TODO: We need to improve the construction of this driver root. 
+ l.driver = root.New(l.logger, l.driverRoot, l.librarySearchPaths) + var lib Interface switch l.resolveMode() { case ModeCSV: diff --git a/pkg/nvcdi/management.go b/pkg/nvcdi/management.go index 9352e110..460a4873 100644 --- a/pkg/nvcdi/management.go +++ b/pkg/nvcdi/management.go @@ -65,7 +65,7 @@ func (m *managementlib) GetCommonEdits() (*cdi.ContainerEdits, error) { return nil, fmt.Errorf("failed to get CUDA version: %v", err) } - driver, err := newDriverVersionDiscoverer(m.logger, m.driverRoot, m.nvidiaCTKPath, version) + driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCTKPath, version) if err != nil { return nil, fmt.Errorf("failed to create driver library discoverer: %v", err) } @@ -86,8 +86,7 @@ func (m *managementlib) getCudaVersion() (string, error) { } libCudaPaths, err := cuda.New( - cuda.WithLogger(m.logger), - cuda.WithDriverRoot(m.driverRoot), + m.driver.Libraries(), ).Locate(".*.*") if err != nil { return "", fmt.Errorf("failed to locate libcuda.so: %v", err)