diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go index d461aa51..01c22ff3 100644 --- a/pkg/nvcdi/lib-nvml.go +++ b/pkg/nvcdi/lib-nvml.go @@ -103,7 +103,7 @@ func (l *nvmllib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) { return l.GetDeviceSpecsBy(identifiers...) } -// GetDeviceSpecsBy is not supported for the gdslib specs. +// GetDeviceSpecsBy returns the device specs for devices with the specified identifiers. func (l *nvmllib) GetDeviceSpecsBy(identifiers ...device.Identifier) ([]specs.Device, error) { for _, id := range identifiers { if id == "all" { @@ -118,10 +118,23 @@ func (l *nvmllib) GetDeviceSpecsBy(identifiers ...device.Identifier) ([]specs.De } defer func() { if r := l.nvmllib.Shutdown(); r != nvml.SUCCESS { - l.logger.Warningf("failed to shutdown NVML: %w", r) + l.logger.Warningf("failed to shutdown NVML: %v", r) } }() + if l.nvsandboxutilslib != nil { + if r := l.nvsandboxutilslib.Init(l.driverRoot); r != nvsandboxutils.SUCCESS { + l.logger.Warningf("Failed to init nvsandboxutils: %v; ignoring", r) + l.nvsandboxutilslib = nil + } + defer func() { + if l.nvsandboxutilslib == nil { + return + } + _ = l.nvsandboxutilslib.Shutdown() + }() + } + nvmlDevices, err := l.getNVMLDevicesByID(identifiers...) if err != nil { return nil, fmt.Errorf("failed to get NVML device handles: %w", err)