diff --git a/cmd/nvidia-ctk/cdi/generate/driver.go b/cmd/nvidia-ctk/cdi/generate/driver.go index c5aec328..0408514c 100644 --- a/cmd/nvidia-ctk/cdi/generate/driver.go +++ b/cmd/nvidia-ctk/cdi/generate/driver.go @@ -18,6 +18,7 @@ package generate import ( "fmt" + "path/filepath" "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" @@ -38,12 +39,62 @@ var _ discover.Discover = (*driverLibraries)(nil) // NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation. // The supplied NVML Library is used to query the expected driver version. func NewDriverDiscoverer(logger *logrus.Logger, root string, nvmllib nvml.Interface) (discover.Discover, error) { - libraries, err := NewDriverLibraryDiscoverer(logger, root, nvmllib) + version, r := nvmllib.SystemGetDriverVersion() + if r != nvml.SUCCESS { + return nil, fmt.Errorf("failed to determine driver version: %v", r) + } + + libraries, err := NewDriverLibraryDiscoverer(logger, root, version) if err != nil { return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err) } - binaries := discover.NewMounts( + firmwares := NewDriverFirmwareDiscoverer(logger, root, version) + + binaries := NewDriverBinariesDiscoverer(logger, root) + + d := discover.Merge( + libraries, + firmwares, + binaries, + ) + + return d, nil +} + +// NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version. +func NewDriverLibraryDiscoverer(logger *logrus.Logger, root string, version string) (discover.Discover, error) { + libraries, err := findVersionLibs(logger, root, version) + if err != nil { + return nil, fmt.Errorf("failed to get libraries for driver version: %v", err) + } + + d := driverLibraries{ + logger: logger, + root: root, + libraries: libraries, + } + + return &d, nil +} + +// NewDriverFirmwareDiscoverer creates a discoverer for GSP firmware associated with the specified driver version. +func NewDriverFirmwareDiscoverer(logger *logrus.Logger, root string, version string) discover.Discover { + gspFirmwarePath := filepath.Join("/lib/firmware/nvidia", version, "gsp.bin") + return discover.NewMounts( + logger, + lookup.NewFileLocator( + lookup.WithLogger(logger), + lookup.WithRoot(root), + ), + root, + []string{gspFirmwarePath}, + ) +} + +// NewDriverBinariesDiscoverer creates a discoverer for GSP firmware associated with the GPU driver. +func NewDriverBinariesDiscoverer(logger *logrus.Logger, root string) discover.Discover { + return discover.NewMounts( logger, lookup.NewExecutableLocator(logger, root), root, @@ -55,34 +106,6 @@ func NewDriverDiscoverer(logger *logrus.Logger, root string, nvmllib nvml.Interf "nvidia-cuda-mps-server", /* Multi process service server */ }, ) - - d := discover.Merge( - libraries, - binaries, - ) - - return d, nil -} - -// NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version. -func NewDriverLibraryDiscoverer(logger *logrus.Logger, root string, nvmllib nvml.Interface) (discover.Discover, error) { - version, r := nvmllib.SystemGetDriverVersion() - if r != nvml.SUCCESS { - return nil, fmt.Errorf("failed to determine driver version: %v", r) - } - - libraries, err := findVersionLibs(logger, root, version) - if err != nil { - return nil, fmt.Errorf("failed to get libraries for driver version: %v", r) - } - - d := driverLibraries{ - logger: logger, - root: root, - libraries: libraries, - } - - return &d, nil } // Devices are empty for this discoverer