Merge branch 'library-search-path-cdi-generate' into 'main'

Allow search paths when locating libcuda.so

See merge request nvidia/container-toolkit/container-toolkit!462
This commit is contained in:
Evan Lezar 2023-11-22 19:49:15 +00:00
commit 1909b1fe60
13 changed files with 225 additions and 146 deletions

View File

@ -28,6 +28,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
) )
// NewDRMNodesDiscoverer returns a discoverrer for the DRM device nodes associated with the specified visible devices. // NewDRMNodesDiscoverer returns a discoverrer for the DRM device nodes associated with the specified visible devices.
@ -48,15 +49,11 @@ func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices
} }
// NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan. // NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan.
func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) (Discover, error) { func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) (Discover, error) {
locator, err := lookup.NewLibraryLocator(logger, driverRoot)
if err != nil {
return nil, fmt.Errorf("failed to construct library locator: %v", err)
}
libraries := NewMounts( libraries := NewMounts(
logger, logger,
locator, driver.Libraries(),
driverRoot, driver.Root,
[]string{ []string{
"libnvidia-egl-gbm.so", "libnvidia-egl-gbm.so",
}, },
@ -66,10 +63,10 @@ func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvi
logger, logger,
lookup.NewFileLocator( lookup.NewFileLocator(
lookup.WithLogger(logger), lookup.WithLogger(logger),
lookup.WithRoot(driverRoot), lookup.WithRoot(driver.Root),
lookup.WithSearchPaths("/etc", "/usr/share"), lookup.WithSearchPaths("/etc", "/usr/share"),
), ),
driverRoot, driver.Root,
[]string{ []string{
"glvnd/egl_vendor.d/10_nvidia.json", "glvnd/egl_vendor.d/10_nvidia.json",
"vulkan/icd.d/nvidia_icd.json", "vulkan/icd.d/nvidia_icd.json",
@ -81,7 +78,7 @@ func NewGraphicsMountsDiscoverer(logger logger.Interface, driverRoot string, nvi
}, },
) )
xorg := optionalXorgDiscoverer(logger, driverRoot, nvidiaCTKPath) xorg := optionalXorgDiscoverer(logger, driver, nvidiaCTKPath)
discover := Merge( discover := Merge(
libraries, libraries,
@ -249,8 +246,8 @@ var _ Discover = (*xorgHooks)(nil)
// optionalXorgDiscoverer creates a discoverer for Xorg libraries. // optionalXorgDiscoverer creates a discoverer for Xorg libraries.
// If the creation of the discoverer fails, a None discoverer is returned. // If the creation of the discoverer fails, a None discoverer is returned.
func optionalXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) Discover { func optionalXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) Discover {
xorg, err := newXorgDiscoverer(logger, driverRoot, nvidiaCTKPath) xorg, err := newXorgDiscoverer(logger, driver, nvidiaCTKPath)
if err != nil { if err != nil {
logger.Warningf("Failed to create Xorg discoverer: %v; skipping xorg libraries", err) logger.Warningf("Failed to create Xorg discoverer: %v; skipping xorg libraries", err)
return None{} return None{}
@ -258,10 +255,9 @@ func optionalXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCT
return xorg return xorg
} }
func newXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string) (Discover, error) { func newXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) (Discover, error) {
libCudaPaths, err := cuda.New( libCudaPaths, err := cuda.New(
cuda.WithLogger(logger), driver.Libraries(),
cuda.WithDriverRoot(driverRoot),
).Locate(".*.*") ).Locate(".*.*")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to locate libcuda.so: %v", err) return nil, fmt.Errorf("failed to locate libcuda.so: %v", err)
@ -278,11 +274,11 @@ func newXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath
logger, logger,
lookup.NewFileLocator( lookup.NewFileLocator(
lookup.WithLogger(logger), lookup.WithLogger(logger),
lookup.WithRoot(driverRoot), lookup.WithRoot(driver.Root),
lookup.WithSearchPaths(libRoot, "/usr/lib/x86_64-linux-gnu"), lookup.WithSearchPaths(libRoot, "/usr/lib/x86_64-linux-gnu"),
lookup.WithCount(1), lookup.WithCount(1),
), ),
driverRoot, driver.Root,
[]string{ []string{
"nvidia/xorg/nvidia_drv.so", "nvidia/xorg/nvidia_drv.so",
fmt.Sprintf("nvidia/xorg/libglxserver_nvidia.so.%s", version), fmt.Sprintf("nvidia/xorg/libglxserver_nvidia.so.%s", version),
@ -298,10 +294,10 @@ func newXorgDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath
logger, logger,
lookup.NewFileLocator( lookup.NewFileLocator(
lookup.WithLogger(logger), lookup.WithLogger(logger),
lookup.WithRoot(driverRoot), lookup.WithRoot(driver.Root),
lookup.WithSearchPaths("/usr/share"), lookup.WithSearchPaths("/usr/share"),
), ),
driverRoot, driver.Root,
[]string{"X11/xorg.conf.d/10-nvidia.conf"}, []string{"X11/xorg.conf.d/10-nvidia.conf"},
) )

View File

@ -17,55 +17,19 @@
package cuda package cuda
import ( import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
) )
type cudaLocator struct { type cudaLocator struct {
lookup.Locator lookup.Locator
logger logger.Interface
driverRoot string
}
// Options is a function that configures a cudaLocator.
type Options func(*cudaLocator)
// WithLogger is an option that configures the logger used by the locator.
func WithLogger(logger logger.Interface) Options {
return func(c *cudaLocator) {
c.logger = logger
}
}
// WithDriverRoot is an option that configures the driver root used by the locator.
func WithDriverRoot(driverRoot string) Options {
return func(c *cudaLocator) {
c.driverRoot = driverRoot
}
} }
// New creates a new CUDA library locator. // New creates a new CUDA library locator.
func New(opts ...Options) lookup.Locator { func New(libraries lookup.Locator) lookup.Locator {
c := &cudaLocator{} c := cudaLocator{
for _, opt := range opts { Locator: libraries,
opt(c)
} }
return &c
if c.logger == nil {
c.logger = logger.New()
}
if c.driverRoot == "" {
c.driverRoot = "/"
}
// TODO: Do we want to set the Count to 1 here?
l, _ := lookup.NewLibraryLocator(
c.logger,
c.driverRoot,
)
c.Locator = l
return c
} }
// Locate returns the path to the libcuda.so.RMVERSION file. // Locate returns the path to the libcuda.so.RMVERSION file.

View File

@ -57,8 +57,10 @@ func TestLocate(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
l := New( l := New(
WithLogger(logger), lookup.NewLibraryLocator(
WithDriverRoot(driverRoot), lookup.WithLogger(logger),
lookup.WithRoot(driverRoot),
),
) )
candidates, err := l.Locate(".*") candidates, err := l.Locate(".*")

View File

@ -17,7 +17,6 @@
package lookup package lookup
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -25,55 +24,58 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
) )
// ErrNotFound indicates that a specified pattern or file could not be found.
var ErrNotFound = errors.New("not found")
// file can be used to locate file (or file-like elements) at a specified set of // file can be used to locate file (or file-like elements) at a specified set of
// prefixes. The validity of a file is determined by a filter function. // prefixes. The validity of a file is determined by a filter function.
type file struct { type file struct {
logger logger.Interface builder
root string prefixes []string
prefixes []string
filter func(string) error
count int
isOptional bool
} }
// Option defines a function for passing options to the NewFileLocator() call // builder defines the builder for a file locator.
type Option func(*file) type builder struct {
logger logger.Interface
root string
searchPaths []string
filter func(string) error
count int
isOptional bool
}
// Option defines a function for passing builder to the NewFileLocator() call
type Option func(*builder)
// WithRoot sets the root for the file locator // WithRoot sets the root for the file locator
func WithRoot(root string) Option { func WithRoot(root string) Option {
return func(f *file) { return func(f *builder) {
f.root = root f.root = root
} }
} }
// WithLogger sets the logger for the file locator // WithLogger sets the logger for the file locator
func WithLogger(logger logger.Interface) Option { func WithLogger(logger logger.Interface) Option {
return func(f *file) { return func(f *builder) {
f.logger = logger f.logger = logger
} }
} }
// WithSearchPaths sets the search paths for the file locator. // WithSearchPaths sets the search paths for the file locator.
func WithSearchPaths(paths ...string) Option { func WithSearchPaths(paths ...string) Option {
return func(f *file) { return func(f *builder) {
f.prefixes = paths f.searchPaths = paths
} }
} }
// WithFilter sets the filter for the file locator // WithFilter sets the filter for the file locator
// The filter is called for each candidate file and candidates that return nil are considered. // The filter is called for each candidate file and candidates that return nil are considered.
func WithFilter(assert func(string) error) Option { func WithFilter(assert func(string) error) Option {
return func(f *file) { return func(f *builder) {
f.filter = assert f.filter = assert
} }
} }
// WithCount sets the maximum number of candidates to discover // WithCount sets the maximum number of candidates to discover
func WithCount(count int) Option { func WithCount(count int) Option {
return func(f *file) { return func(f *builder) {
f.count = count f.count = count
} }
} }
@ -81,32 +83,42 @@ func WithCount(count int) Option {
// WithOptional sets the optional flag for the file locator // WithOptional sets the optional flag for the file locator
// If the optional flag is set, the locator will not return an error if the file is not found. // If the optional flag is set, the locator will not return an error if the file is not found.
func WithOptional(optional bool) Option { func WithOptional(optional bool) Option {
return func(f *file) { return func(f *builder) {
f.isOptional = optional f.isOptional = optional
} }
} }
// NewFileLocator creates a Locator that can be used to find files with the specified options. func newBuilder(opts ...Option) *builder {
o := &builder{}
for _, opt := range opts {
opt(o)
}
if o.logger == nil {
o.logger = logger.New()
}
if o.filter == nil {
o.filter = assertFile
}
return o
}
func (o builder) build() *file {
f := file{
builder: o,
// Since the `Locate` implementations rely on the root already being specified we update
// the prefixes to include the root.
prefixes: getSearchPrefixes(o.root, o.searchPaths...),
}
return &f
}
// NewFileLocator creates a Locator that can be used to find files with the specified builder.
func NewFileLocator(opts ...Option) Locator { func NewFileLocator(opts ...Option) Locator {
return newFileLocator(opts...) return newFileLocator(opts...)
} }
func newFileLocator(opts ...Option) *file { func newFileLocator(opts ...Option) *file {
f := &file{} return newBuilder(opts...).build()
for _, opt := range opts {
opt(f)
}
if f.logger == nil {
f.logger = logger.New()
}
if f.filter == nil {
f.filter = assertFile
}
// Since the `Locate` implementations rely on the root already being specified we update
// the prefixes to include the root.
f.prefixes = getSearchPrefixes(f.root, f.prefixes...)
return f
} }
// getSearchPrefixes generates a list of unique paths to be searched by a file locator. // getSearchPrefixes generates a list of unique paths to be searched by a file locator.

View File

@ -30,12 +30,20 @@ type ldcacheLocator struct {
var _ Locator = (*ldcacheLocator)(nil) var _ Locator = (*ldcacheLocator)(nil)
// NewLibraryLocator creates a library locator using the specified logger. // NewLibraryLocator creates a library locator using the specified options.
func NewLibraryLocator(logger logger.Interface, root string) (Locator, error) { func NewLibraryLocator(opts ...Option) Locator {
// We construct a symlink locator for expected library locations. b := newBuilder(opts...)
symlinkLocator := NewSymlinkLocator(
WithLogger(logger), // If search paths are already specified, we return a locator for the specified search paths.
WithRoot(root), if len(b.searchPaths) > 0 {
return NewSymlinkLocator(
WithLogger(b.logger),
WithSearchPaths(b.searchPaths...),
WithRoot("/"),
)
}
opts = append(opts,
WithSearchPaths([]string{ WithSearchPaths([]string{
"/", "/",
"/usr/lib64", "/usr/lib64",
@ -50,24 +58,28 @@ func NewLibraryLocator(logger logger.Interface, root string) (Locator, error) {
"/lib/aarch64-linux-gnu/nvidia/current", "/lib/aarch64-linux-gnu/nvidia/current",
}...), }...),
) )
// We construct a symlink locator for expected library locations.
symlinkLocator := NewSymlinkLocator(opts...)
l := First( l := First(
symlinkLocator, symlinkLocator,
newLdcacheLocator(logger, root), newLdcacheLocator(opts...),
) )
return l, nil return l
} }
func newLdcacheLocator(logger logger.Interface, root string) Locator { func newLdcacheLocator(opts ...Option) Locator {
cache, err := ldcache.New(logger, root) b := newBuilder(opts...)
cache, err := ldcache.New(b.logger, b.root)
if err != nil { if err != nil {
// If we failed to open the LDCache, we default to a symlink locator. // If we failed to open the LDCache, we default to a symlink locator.
logger.Warningf("Failed to load ldcache: %v", err) b.logger.Warningf("Failed to load ldcache: %v", err)
return nil return nil
} }
return &ldcacheLocator{ return &ldcacheLocator{
logger: logger, logger: b.logger,
cache: cache, cache: cache,
} }
} }
@ -82,7 +94,7 @@ func (l ldcacheLocator) Locate(libname string) ([]string, error) {
} }
if len(paths64) == 0 { if len(paths64) == 0 {
return nil, fmt.Errorf("64-bit library %v: %w", libname, errNotFound) return nil, fmt.Errorf("64-bit library %v: %w", libname, ErrNotFound)
} }
return paths64, nil return paths64, nil

View File

@ -43,7 +43,10 @@ func TestLDCacheLocator(t *testing.T) {
require.NoError(t, os.Symlink(versionLib, sonameLink)) require.NoError(t, os.Symlink(versionLib, sonameLink))
require.NoError(t, os.Symlink(sonameLink, soLink)) require.NoError(t, os.Symlink(sonameLink, soLink))
lut := newLdcacheLocator(logger, testDir) lut := newLdcacheLocator(
WithLogger(logger),
WithRoot(testDir),
)
testCases := []struct { testCases := []struct {
description string description string
@ -63,7 +66,7 @@ func TestLDCacheLocator(t *testing.T) {
{ {
description: "lib only not in LDCache returns error", description: "lib only not in LDCache returns error",
libname: "libnotcuda.so", libname: "libnotcuda.so",
expectedError: errNotFound, expectedError: ErrNotFound,
}, },
} }
@ -94,7 +97,6 @@ func TestLDCacheLocator(t *testing.T) {
require.EqualValues(t, tc.expected, cleanedCandidates) require.EqualValues(t, tc.expected, cleanedCandidates)
}) })
} }
} }
func TestLibraryLocator(t *testing.T) { func TestLibraryLocator(t *testing.T) {
@ -125,14 +127,12 @@ func TestLibraryLocator(t *testing.T) {
require.NoError(t, os.Symlink(libTarget1, source1)) require.NoError(t, os.Symlink(libTarget1, source1))
require.NoError(t, os.Symlink(source1, source2)) require.NoError(t, os.Symlink(source1, source2))
lut, err := NewLibraryLocator(logger, testDir)
require.NoError(t, err)
testCases := []struct { testCases := []struct {
description string description string
libname string libname string
expected []string librarySearchPaths []string
expectedError error expected []string
expectedError error
}{ }{
{ {
description: "slash in path resoves symlink", description: "slash in path resoves symlink",
@ -152,7 +152,7 @@ func TestLibraryLocator(t *testing.T) {
{ {
description: "library not found returns error", description: "library not found returns error",
libname: "/lib/symlink/libnotcuda.so", libname: "/lib/symlink/libnotcuda.so",
expectedError: errNotFound, expectedError: ErrNotFound,
}, },
{ {
description: "slash in path with pattern resoves symlink", description: "slash in path with pattern resoves symlink",
@ -172,10 +172,31 @@ func TestLibraryLocator(t *testing.T) {
filepath.Join(testDir, "/lib/symlink/libtarget.so.1.2.3"), filepath.Join(testDir, "/lib/symlink/libtarget.so.1.2.3"),
}, },
}, },
{
description: "search paths are searched",
libname: "lib*.so.1.2.3",
librarySearchPaths: []string{filepath.Join(testDir, "/lib/symlink")},
expected: []string{
filepath.Join(testDir, "/lib/symlink/libcuda.so.1.2.3"),
filepath.Join(testDir, "/lib/symlink/libtarget.so.1.2.3"),
},
},
{
description: "search paths are absolute to root",
libname: "lib*.so.1.2.3",
librarySearchPaths: []string{"/lib/symlink"},
expectedError: ErrNotFound,
},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
lut := NewLibraryLocator(
WithLogger(logger),
WithRoot(testDir),
WithSearchPaths(tc.librarySearchPaths...),
)
candidates, err := lut.Locate(tc.libname) candidates, err := lut.Locate(tc.libname)
require.ErrorIs(t, err, tc.expectedError) require.ErrorIs(t, err, tc.expectedError)

View File

@ -25,4 +25,5 @@ type Locator interface {
Locate(string) ([]string, error) Locate(string) ([]string, error)
} }
var errNotFound = errors.New("not found") // ErrNotFound indicates that a specified pattern or file could not be found.
var ErrNotFound = errors.New("not found")

View File

@ -0,0 +1,65 @@
/**
# Copyright 2023 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package root
import (
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
)
// Driver represents a filesystem in which a set of drivers or devices is defined.
type Driver struct {
logger logger.Interface
// Root represents the root from the perspective of the driver libraries and binaries.
Root string
// librarySearchPaths specifies explicit search paths for discovering libraries.
librarySearchPaths []string
}
// New creates a new Driver root at the specified path.
// TODO: Use functional options here.
func New(logger logger.Interface, path string, librarySearchPaths []string) *Driver {
return &Driver{
logger: logger,
Root: path,
librarySearchPaths: normalizeSearchPaths(librarySearchPaths...),
}
}
// Drivers returns a Locator for driver libraries.
func (r *Driver) Libraries() lookup.Locator {
return lookup.NewLibraryLocator(
lookup.WithLogger(r.logger),
lookup.WithRoot(r.Root),
lookup.WithSearchPaths(r.librarySearchPaths...),
)
}
// normalizeSearchPaths takes a list of paths and normalized these.
// Each of the elements in the list is expanded if it is a path list and the
// resultant list is returned.
// This allows, for example, for the contents of `PATH` or `LD_LIBRARY_PATH` to
// be passed as a search path directly.
func normalizeSearchPaths(paths ...string) []string {
var normalized []string
for _, path := range paths {
normalized = append(normalized, filepath.SplitList(path)...)
}
return normalized
}

View File

@ -23,6 +23,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
) )
@ -34,20 +35,21 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image imag
return nil, nil return nil, nil
} }
driverRoot := cfg.NVIDIAContainerCLIConfig.Root // TODO: We should not just pass `nil` as the search path here.
driver := root.New(logger, cfg.NVIDIAContainerCLIConfig.Root, nil)
nvidiaCTKPath := cfg.NVIDIACTKConfig.Path nvidiaCTKPath := cfg.NVIDIACTKConfig.Path
mounts, err := discover.NewGraphicsMountsDiscoverer( mounts, err := discover.NewGraphicsMountsDiscoverer(
logger, logger,
driverRoot, driver,
nvidiaCTKPath, nvidiaCTKPath,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create mounts discoverer: %v", err) return nil, fmt.Errorf("failed to create mounts discoverer: %v", err)
} }
// In standard usage, the devRoot is the same as the driverRoot. // In standard usage, the devRoot is the same as the driver.Root.
devRoot := driverRoot devRoot := driver.Root
drmNodes, err := discover.NewDRMNodesDiscoverer( drmNodes, err := discover.NewDRMNodesDiscoverer(
logger, logger,
image.DevicesFromEnvvars(visibleDevicesEnvvar), image.DevicesFromEnvvars(visibleDevicesEnvvar),

View File

@ -36,12 +36,12 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) {
}, },
) )
graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath) graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.nvidiaCTKPath)
if err != nil { if err != nil {
l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err) l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err)
} }
driverFiles, err := NewDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib) driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCTKPath, l.nvmllib)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err) return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err)
} }

View File

@ -27,12 +27,13 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation. // NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation.
// The supplied NVML Library is used to query the expected driver version. // The supplied NVML Library is used to query the expected driver version.
func NewDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) {
if r := nvmllib.Init(); r != nvml.SUCCESS { if r := nvmllib.Init(); r != nvml.SUCCESS {
return nil, fmt.Errorf("failed to initialize NVML: %v", r) return nil, fmt.Errorf("failed to initialize NVML: %v", r)
} }
@ -47,26 +48,26 @@ func NewDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPa
return nil, fmt.Errorf("failed to determine driver version: %v", r) return nil, fmt.Errorf("failed to determine driver version: %v", r)
} }
return newDriverVersionDiscoverer(logger, driverRoot, nvidiaCTKPath, version) return newDriverVersionDiscoverer(logger, driver, nvidiaCTKPath, version)
} }
func newDriverVersionDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, version string) (discover.Discover, error) { func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string) (discover.Discover, error) {
libraries, err := NewDriverLibraryDiscoverer(logger, driverRoot, nvidiaCTKPath, version) libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCTKPath, version)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err) return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err)
} }
ipcs, err := discover.NewIPCDiscoverer(logger, driverRoot) ipcs, err := discover.NewIPCDiscoverer(logger, driver.Root)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err) return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err)
} }
firmwares, err := NewDriverFirmwareDiscoverer(logger, driverRoot, version) firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create discoverer for GSP firmware: %v", err) return nil, fmt.Errorf("failed to create discoverer for GSP firmware: %v", err)
} }
binaries := NewDriverBinariesDiscoverer(logger, driverRoot) binaries := NewDriverBinariesDiscoverer(logger, driver.Root)
d := discover.Merge( d := discover.Merge(
libraries, libraries,
@ -79,8 +80,8 @@ func newDriverVersionDiscoverer(logger logger.Interface, driverRoot string, nvid
} }
// NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version. // NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version.
func NewDriverLibraryDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, version string) (discover.Discover, error) { func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string) (discover.Discover, error) {
libraryPaths, err := getVersionLibs(logger, driverRoot, version) libraryPaths, err := getVersionLibs(logger, driver, version)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get libraries for driver version: %v", err) return nil, fmt.Errorf("failed to get libraries for driver version: %v", err)
} }
@ -89,9 +90,9 @@ func NewDriverLibraryDiscoverer(logger logger.Interface, driverRoot string, nvid
logger, logger,
lookup.NewFileLocator( lookup.NewFileLocator(
lookup.WithLogger(logger), lookup.WithLogger(logger),
lookup.WithRoot(driverRoot), lookup.WithRoot(driver.Root),
), ),
driverRoot, driver.Root,
libraryPaths, libraryPaths,
) )
@ -185,12 +186,11 @@ func NewDriverBinariesDiscoverer(logger logger.Interface, driverRoot string) dis
// getVersionLibs checks the LDCache for libraries ending in the specified driver version. // getVersionLibs checks the LDCache for libraries ending in the specified driver version.
// Although the ldcache at the specified driverRoot is queried, the paths are returned relative to this driverRoot. // Although the ldcache at the specified driverRoot is queried, the paths are returned relative to this driverRoot.
// This allows the standard mount location logic to be used for resolving the mounts. // This allows the standard mount location logic to be used for resolving the mounts.
func getVersionLibs(logger logger.Interface, driverRoot string, version string) ([]string, error) { func getVersionLibs(logger logger.Interface, driver *root.Driver, version string) ([]string, error) {
logger.Infof("Using driver version %v", version) logger.Infof("Using driver version %v", version)
libCudaPaths, err := cuda.New( libCudaPaths, err := cuda.New(
cuda.WithLogger(logger), driver.Libraries(),
cuda.WithDriverRoot(driverRoot),
).Locate("." + version) ).Locate("." + version)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to locate libcuda.so.%v: %v", version, err) return nil, fmt.Errorf("failed to locate libcuda.so.%v: %v", version, err)
@ -208,13 +208,13 @@ func getVersionLibs(logger logger.Interface, driverRoot string, version string)
return nil, fmt.Errorf("failed to locate libraries for driver version %v: %v", version, err) return nil, fmt.Errorf("failed to locate libraries for driver version %v: %v", version, err)
} }
if driverRoot == "/" || driverRoot == "" { if driver.Root == "/" || driver.Root == "" {
return libs, nil return libs, nil
} }
var relative []string var relative []string
for _, l := range libs { for _, l := range libs {
relative = append(relative, strings.TrimPrefix(l, driverRoot)) relative = append(relative, strings.TrimPrefix(l, driver.Root))
} }
return relative, nil return relative, nil

View File

@ -23,6 +23,7 @@ import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
"github.com/NVIDIA/go-nvlib/pkg/nvml" "github.com/NVIDIA/go-nvlib/pkg/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
@ -54,6 +55,7 @@ type nvcdilib struct {
vendor string vendor string
class string class string
driver *root.Driver
infolib info.Interface infolib info.Interface
mergedDeviceOptions []transform.MergedDeviceOption mergedDeviceOptions []transform.MergedDeviceOption
@ -87,6 +89,9 @@ func New(opts ...Option) (Interface, error) {
l.infolib = info.New() l.infolib = info.New()
} }
// TODO: We need to improve the construction of this driver root.
l.driver = root.New(l.logger, l.driverRoot, l.librarySearchPaths)
var lib Interface var lib Interface
switch l.resolveMode() { switch l.resolveMode() {
case ModeCSV: case ModeCSV:

View File

@ -65,7 +65,7 @@ func (m *managementlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("failed to get CUDA version: %v", err) return nil, fmt.Errorf("failed to get CUDA version: %v", err)
} }
driver, err := newDriverVersionDiscoverer(m.logger, m.driverRoot, m.nvidiaCTKPath, version) driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCTKPath, version)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create driver library discoverer: %v", err) return nil, fmt.Errorf("failed to create driver library discoverer: %v", err)
} }
@ -86,8 +86,7 @@ func (m *managementlib) getCudaVersion() (string, error) {
} }
libCudaPaths, err := cuda.New( libCudaPaths, err := cuda.New(
cuda.WithLogger(m.logger), m.driver.Libraries(),
cuda.WithDriverRoot(m.driverRoot),
).Locate(".*.*") ).Locate(".*.*")
if err != nil { if err != nil {
return "", fmt.Errorf("failed to locate libcuda.so: %v", err) return "", fmt.Errorf("failed to locate libcuda.so: %v", err)