diff --git a/cmd/nvidia-ctk/cdi/generate/driver.go b/cmd/nvidia-ctk/cdi/generate/driver.go index c89718cb..bea3068c 100644 --- a/cmd/nvidia-ctk/cdi/generate/driver.go +++ b/cmd/nvidia-ctk/cdi/generate/driver.go @@ -28,15 +28,6 @@ import ( "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) -type driverLibraries struct { - logger *logrus.Logger - root string - nvidiaCTKPath string - libraries []string -} - -var _ discover.Discover = (*driverLibraries)(nil) - // NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation. // The supplied NVML Library is used to query the expected driver version. func NewDriverDiscoverer(logger *logrus.Logger, root string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { @@ -64,20 +55,34 @@ func NewDriverDiscoverer(logger *logrus.Logger, root string, nvidiaCTKPath strin } // NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version. -func NewDriverLibraryDiscoverer(logger *logrus.Logger, root string, nvidiaCTKPath, version string) (discover.Discover, error) { - libraries, err := findVersionLibs(logger, root, version) +func NewDriverLibraryDiscoverer(logger *logrus.Logger, root string, nvidiaCTKPath string, version string) (discover.Discover, error) { + libraryPaths, err := getVersionLibs(logger, root, version) if err != nil { return nil, fmt.Errorf("failed to get libraries for driver version: %v", err) } - d := driverLibraries{ - logger: logger, - root: root, - nvidiaCTKPath: nvidiaCTKPath, - libraries: libraries, - } + libraries := discover.NewMounts( + logger, + lookup.NewFileLocator( + lookup.WithLogger(logger), + lookup.WithRoot(root), + ), + root, + libraryPaths, + ) - return &d, nil + cfg := &discover.Config{ + Root: root, + NvidiaCTKPath: nvidiaCTKPath, + } + hooks, _ := discover.NewLDCacheUpdateHook(logger, libraries, cfg) + + d := discover.Merge( + libraries, + hooks, + ) + + return d, nil } // NewDriverFirmwareDiscoverer creates a discoverer for GSP firmware associated with the specified driver version. @@ -110,36 +115,10 @@ func NewDriverBinariesDiscoverer(logger *logrus.Logger, root string) discover.Di ) } -// Devices are empty for this discoverer -func (d *driverLibraries) Devices() ([]discover.Device, error) { - return nil, nil -} - -// Mounts returns the mounts for the driver libraries -func (d *driverLibraries) Mounts() ([]discover.Mount, error) { - var mounts []discover.Mount - for _, d := range d.libraries { - mount := discover.Mount{ - HostPath: d, - Path: d, - } - mounts = append(mounts, mount) - } - - return mounts, nil -} - -// Hooks returns a hook that updates the LDCache for the specified driver library paths. -func (d *driverLibraries) Hooks() ([]discover.Hook, error) { - hook := discover.CreateLDCacheUpdateHook( - d.nvidiaCTKPath, - d.libraries, - ) - - return []discover.Hook{hook}, nil -} - -func findVersionLibs(logger *logrus.Logger, root string, version string) ([]string, error) { +// getVersionLibs checks the LDCache for libraries ending in the specified driver version. +// Although the ldcache at the specified root is queried, the paths are returned relative to this root. +// This allows the standard mount location logic to be used for resolving the mounts. +func getVersionLibs(logger *logrus.Logger, root string, version string) ([]string, error) { logger.Infof("Using driver version %v", version) cache, err := ldcache.New(logger, root) @@ -164,5 +143,14 @@ func findVersionLibs(logger *logrus.Logger, root string, version string) ([]stri } } - return libs, nil + if root == "/" || root == "" { + return libs, nil + } + + var relative []string + for _, l := range libs { + relative = append(relative, strings.TrimPrefix(l, root)) + } + + return relative, nil }