diff --git a/pkg/nvlib/device/api.go b/pkg/nvlib/device/api.go index 498bda8..7741915 100644 --- a/pkg/nvlib/device/api.go +++ b/pkg/nvlib/device/api.go @@ -35,8 +35,8 @@ type Interface interface { } type devicelib struct { - nvml nvml.Interface - selectedDeviceClasses map[Class]struct{} + nvml nvml.Interface + skippedDevices map[string]struct{} } var _ Interface = &devicelib{} @@ -50,10 +50,8 @@ func New(opts ...Option) Interface { if d.nvml == nil { d.nvml = nvml.New() } - if d.selectedDeviceClasses == nil { - d.selectedDeviceClasses = map[Class]struct{}{ - ClassCompute: {}, - } + if d.skippedDevices == nil { + WithSkippedDevices("NVIDIA DGX Display")(d) } return d } @@ -65,14 +63,14 @@ func WithNvml(nvml nvml.Interface) Option { } } -// WithSelectedDeviceClasses selects the specified device classes when filtering devices -func WithSelectedDeviceClasses(classes ...Class) Option { +// WithSkippedDevices provides an Option to set devices to be skipped by model name +func WithSkippedDevices(names ...string) Option { return func(d *devicelib) { - if d.selectedDeviceClasses == nil { - d.selectedDeviceClasses = make(map[Class]struct{}) + if d.skippedDevices == nil { + d.skippedDevices = make(map[string]struct{}) } - for _, c := range classes { - d.selectedDeviceClasses[c] = struct{}{} + for _, name := range names { + d.skippedDevices[name] = struct{}{} } } } diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index bb0eee1..640c01f 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -238,6 +238,20 @@ func (d *device) getClass() (Class, error) { return Class(device.Class), nil } +// isSkipped checks whether the device should be skipped. +func (d *device) isSkipped() (bool, error) { + name, ret := d.GetName() + if ret != nvml.SUCCESS { + return false, fmt.Errorf("error getting device name: %v", ret) + } + + if _, exists := d.lib.skippedDevices[name]; exists { + return true, nil + } + + return false, nil +} + // VisitDevices visits each top-level device and invokes a callback function for it func (d *devicelib) VisitDevices(visit func(int, Device) error) error { count, ret := d.nvml.DeviceGetCount() @@ -255,11 +269,11 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error { return fmt.Errorf("error creating new device wrapper: %v", err) } - class, err := dev.getClass() + isSkipped, err := dev.isSkipped() if err != nil { - return fmt.Errorf("error getting PCI device class for device: %v", err) + return fmt.Errorf("error checking whether device is skipped: %v", err) } - if !d.classIsSelected(class) { + if isSkipped { continue }