diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 4753cf71..ff44b6dc 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -81,14 +81,25 @@ func New(opts ...Option) (Interface, error) { indexNamer, _ := NewDeviceNamer(DeviceNameStrategyIndex) l.deviceNamers = []DeviceNamer{indexNamer} } + if l.nvidiaCDIHookPath == "" { + l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook" + } if l.driverRoot == "" { l.driverRoot = "/" } if l.devRoot == "" { l.devRoot = l.driverRoot } - if l.nvidiaCDIHookPath == "" { - l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook" + l.driver = root.New( + root.WithLogger(l.logger), + root.WithDriverRoot(l.driverRoot), + root.WithLibrarySearchPaths(l.librarySearchPaths...), + ) + if l.nvmllib == nil { + l.nvmllib = nvml.New() + } + if l.devicelib == nil { + l.devicelib = device.New(device.WithNvml(l.nvmllib)) } if l.infolib == nil { l.infolib = info.New( @@ -99,12 +110,6 @@ func New(opts ...Option) (Interface, error) { ) } - l.driver = root.New( - root.WithLogger(l.logger), - root.WithDriverRoot(l.driverRoot), - root.WithLibrarySearchPaths(l.librarySearchPaths...), - ) - var lib Interface switch l.resolveMode() { case ModeCSV: @@ -118,13 +123,6 @@ func New(opts ...Option) (Interface, error) { } lib = (*managementlib)(l) case ModeNvml: - if l.nvmllib == nil { - l.nvmllib = nvml.New() - } - if l.devicelib == nil { - l.devicelib = device.New(device.WithNvml(l.nvmllib)) - } - lib = (*nvmllib)(l) case ModeWsl: lib = (*wsllib)(l)