diff --git a/go.mod b/go.mod index b3443804..9f8db1a3 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/NVIDIA/nvidia-container-toolkit go 1.20 require ( - github.com/NVIDIA/go-nvlib v0.4.0 + github.com/NVIDIA/go-nvlib v0.5.0 github.com/NVIDIA/go-nvml v0.12.0-6 github.com/fsnotify/fsnotify v1.7.0 github.com/opencontainers/runtime-spec v1.2.0 diff --git a/go.sum b/go.sum index c24a226c..7e6c3baf 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/NVIDIA/go-nvlib v0.4.0 h1:dvuqjjSamBODFuxttPg4H/xtNVQRZOSlwFtuNKybcGI= -github.com/NVIDIA/go-nvlib v0.4.0/go.mod h1:87z49ULPr4GWPSGfSIp3taU4XENRYN/enIg88MzcL4k= +github.com/NVIDIA/go-nvlib v0.5.0 h1:951KGrfr+p3cs89alO9z/ZxPPWKxwht9tx9rxiADoLI= +github.com/NVIDIA/go-nvlib v0.5.0/go.mod h1:87z49ULPr4GWPSGfSIp3taU4XENRYN/enIg88MzcL4k= github.com/NVIDIA/go-nvml v0.12.0-6 h1:FJYc2KrpvX+VOC/8QQvMiQMmZ/nPMRpdJO/Ik4xfcr0= github.com/NVIDIA/go-nvml v0.12.0-6/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= diff --git a/internal/info/additional_info_test.go b/internal/info/additional_info_test.go index 0a071188..53741322 100644 --- a/internal/info/additional_info_test.go +++ b/internal/info/additional_info_test.go @@ -243,7 +243,7 @@ func TestUsesNVGPUModule(t *testing.T) { t.Run(tc.description, func(t *testing.T) { sut := additionalInfo{ nvmllib: tc.nvmllib, - devicelib: device.New(device.WithNvml(tc.nvmllib)), + devicelib: device.New(tc.nvmllib), } flag, _ := sut.UsesNVGPUModule() diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index ff44b6dc..249bd311 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -99,7 +99,7 @@ func New(opts ...Option) (Interface, error) { l.nvmllib = nvml.New() } if l.devicelib == nil { - l.devicelib = device.New(device.WithNvml(l.nvmllib)) + l.devicelib = device.New(l.nvmllib) } if l.infolib == nil { l.infolib = info.New( diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 5a490619..417687b9 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -18,6 +18,7 @@ package nvcdi import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" + "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" @@ -34,6 +35,13 @@ func WithDeviceLib(devicelib device.Interface) Option { } } +// WithInfoLib sets the info library for CDI spec generation. +func WithInfoLib(infolib info.Interface) Option { + return func(l *nvcdilib) { + l.infolib = infolib + } +} + // WithDeviceNamers sets the device namer for the library func WithDeviceNamers(namers ...DeviceNamer) Option { return func(l *nvcdilib) { diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/api.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/api.go index 11aa139d..c2a6517d 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/api.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/api.go @@ -38,7 +38,7 @@ type Interface interface { } type devicelib struct { - nvml nvml.Interface + nvmllib nvml.Interface skippedDevices map[string]struct{} verifySymbols *bool migProfiles []MigProfile @@ -47,14 +47,13 @@ type devicelib struct { var _ Interface = &devicelib{} // New creates a new instance of the 'device' interface. -func New(opts ...Option) Interface { - d := &devicelib{} +func New(nvmllib nvml.Interface, opts ...Option) Interface { + d := &devicelib{ + nvmllib: nvmllib, + } for _, opt := range opts { opt(d) } - if d.nvml == nil { - d.nvml = nvml.New() - } if d.verifySymbols == nil { verify := true d.verifySymbols = &verify @@ -68,13 +67,6 @@ func New(opts ...Option) Interface { return d } -// WithNvml provides an Option to set the NVML library used by the 'device' interface. -func WithNvml(nvml nvml.Interface) Option { - return func(d *devicelib) { - d.nvml = nvml - } -} - // WithVerifySymbols provides an option to toggle whether to verify select symbols exist in dynamic libraries before calling them. func WithVerifySymbols(verify bool) Option { return func(d *devicelib) { diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/device.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/device.go index 10514591..5e1510ca 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/device.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/device.go @@ -51,7 +51,7 @@ func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) { // NewDeviceByUUID builds a new Device from a UUID. func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) { - dev, ret := d.nvml.DeviceGetHandleByUUID(uuid) + dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid) if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret) } @@ -334,13 +334,13 @@ func (d *device) isSkipped() (bool, error) { // VisitDevices visits each top-level device and invokes a callback function for it. func (d *devicelib) VisitDevices(visit func(int, Device) error) error { - count, ret := d.nvml.DeviceGetCount() + count, ret := d.nvmllib.DeviceGetCount() if ret != nvml.SUCCESS { return fmt.Errorf("error getting device count: %v", ret) } for i := 0; i < count; i++ { - device, ret := d.nvml.DeviceGetHandleByIndex(i) + device, ret := d.nvmllib.DeviceGetHandleByIndex(i) if ret != nvml.SUCCESS { return fmt.Errorf("error getting device handle for index '%v': %v", i, ret) } @@ -469,5 +469,5 @@ func (d *devicelib) hasSymbol(symbol string) bool { return true } - return d.nvml.Extensions().LookupSymbol(symbol) == nil + return d.nvmllib.Extensions().LookupSymbol(symbol) == nil } diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/mig_device.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/mig_device.go index b02d4176..7145a06b 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/mig_device.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/mig_device.go @@ -50,7 +50,7 @@ func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) { // NewMigDeviceByUUID builds a new MigDevice from a UUID. func (d *devicelib) NewMigDeviceByUUID(uuid string) (MigDevice, error) { - dev, ret := d.nvml.DeviceGetHandleByUUID(uuid) + dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid) if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret) } diff --git a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/info/builder.go b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/info/builder.go index 87f20f02..61684407 100644 --- a/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/info/builder.go +++ b/vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/info/builder.go @@ -55,7 +55,7 @@ func New(opts ...Option) Interface { ) } if o.devicelib == nil { - o.devicelib = device.New(device.WithNvml(o.nvmllib)) + o.devicelib = device.New(o.nvmllib) } if o.platform == "" { o.platform = PlatformAuto diff --git a/vendor/modules.txt b/vendor/modules.txt index fc3b76d4..e1cbbd9d 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,4 +1,4 @@ -# github.com/NVIDIA/go-nvlib v0.4.0 +# github.com/NVIDIA/go-nvlib v0.5.0 ## explicit; go 1.20 github.com/NVIDIA/go-nvlib/pkg/nvlib/device github.com/NVIDIA/go-nvlib/pkg/nvlib/info