Skip display devices based on device names

This allows devices to be skipped based on device names and
skips "NVIDIA DGX Display" devices by default.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2022-11-16 16:07:58 +01:00
parent 0e10f084d1
commit 655eb9795c
2 changed files with 27 additions and 15 deletions

View File

@ -36,7 +36,7 @@ type Interface interface {
type devicelib struct { type devicelib struct {
nvml nvml.Interface nvml nvml.Interface
selectedDeviceClasses map[Class]struct{} skippedDevices map[string]struct{}
} }
var _ Interface = &devicelib{} var _ Interface = &devicelib{}
@ -50,10 +50,8 @@ func New(opts ...Option) Interface {
if d.nvml == nil { if d.nvml == nil {
d.nvml = nvml.New() d.nvml = nvml.New()
} }
if d.selectedDeviceClasses == nil { if d.skippedDevices == nil {
d.selectedDeviceClasses = map[Class]struct{}{ WithSkippedDevices("NVIDIA DGX Display")(d)
ClassCompute: {},
}
} }
return d return d
} }
@ -65,14 +63,14 @@ func WithNvml(nvml nvml.Interface) Option {
} }
} }
// WithSelectedDeviceClasses selects the specified device classes when filtering devices // WithSkippedDevices provides an Option to set devices to be skipped by model name
func WithSelectedDeviceClasses(classes ...Class) Option { func WithSkippedDevices(names ...string) Option {
return func(d *devicelib) { return func(d *devicelib) {
if d.selectedDeviceClasses == nil { if d.skippedDevices == nil {
d.selectedDeviceClasses = make(map[Class]struct{}) d.skippedDevices = make(map[string]struct{})
} }
for _, c := range classes { for _, name := range names {
d.selectedDeviceClasses[c] = struct{}{} d.skippedDevices[name] = struct{}{}
} }
} }
} }

View File

@ -238,6 +238,20 @@ func (d *device) getClass() (Class, error) {
return Class(device.Class), nil return Class(device.Class), nil
} }
// isSkipped checks whether the device should be skipped.
func (d *device) isSkipped() (bool, error) {
name, ret := d.GetName()
if ret != nvml.SUCCESS {
return false, fmt.Errorf("error getting device name: %v", ret)
}
if _, exists := d.lib.skippedDevices[name]; exists {
return true, nil
}
return false, nil
}
// VisitDevices visits each top-level device and invokes a callback function for it // VisitDevices visits each top-level device and invokes a callback function for it
func (d *devicelib) VisitDevices(visit func(int, Device) error) error { func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
count, ret := d.nvml.DeviceGetCount() count, ret := d.nvml.DeviceGetCount()
@ -255,11 +269,11 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
return fmt.Errorf("error creating new device wrapper: %v", err) return fmt.Errorf("error creating new device wrapper: %v", err)
} }
class, err := dev.getClass() isSkipped, err := dev.isSkipped()
if err != nil { if err != nil {
return fmt.Errorf("error getting PCI device class for device: %v", err) return fmt.Errorf("error checking whether device is skipped: %v", err)
} }
if !d.classIsSelected(class) { if isSkipped {
continue continue
} }