From 655eb9795c5438eb44626aba910b0fa79ba83bf4 Mon Sep 17 00:00:00 2001
From: Evan Lezar <elezar@nvidia.com>
Date: Wed, 16 Nov 2022 16:07:58 +0100
Subject: [PATCH] Skip display devices based on device names

This allows devices to be skipped based on device names and
skips "NVIDIA DGX Display" devices by default.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
---
 pkg/nvlib/device/api.go    | 22 ++++++++++------------
 pkg/nvlib/device/device.go | 20 +++++++++++++++++---
 2 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/pkg/nvlib/device/api.go b/pkg/nvlib/device/api.go
index 498bda8..7741915 100644
--- a/pkg/nvlib/device/api.go
+++ b/pkg/nvlib/device/api.go
@@ -35,8 +35,8 @@ type Interface interface {
 }
 
 type devicelib struct {
-	nvml                  nvml.Interface
-	selectedDeviceClasses map[Class]struct{}
+	nvml           nvml.Interface
+	skippedDevices map[string]struct{}
 }
 
 var _ Interface = &devicelib{}
@@ -50,10 +50,8 @@ func New(opts ...Option) Interface {
 	if d.nvml == nil {
 		d.nvml = nvml.New()
 	}
-	if d.selectedDeviceClasses == nil {
-		d.selectedDeviceClasses = map[Class]struct{}{
-			ClassCompute: {},
-		}
+	if d.skippedDevices == nil {
+		WithSkippedDevices("NVIDIA DGX Display")(d)
 	}
 	return d
 }
@@ -65,14 +63,14 @@ func WithNvml(nvml nvml.Interface) Option {
 	}
 }
 
-// WithSelectedDeviceClasses selects the specified device classes when filtering devices
-func WithSelectedDeviceClasses(classes ...Class) Option {
+// WithSkippedDevices provides an Option to set devices to be skipped by model name
+func WithSkippedDevices(names ...string) Option {
 	return func(d *devicelib) {
-		if d.selectedDeviceClasses == nil {
-			d.selectedDeviceClasses = make(map[Class]struct{})
+		if d.skippedDevices == nil {
+			d.skippedDevices = make(map[string]struct{})
 		}
-		for _, c := range classes {
-			d.selectedDeviceClasses[c] = struct{}{}
+		for _, name := range names {
+			d.skippedDevices[name] = struct{}{}
 		}
 	}
 }
diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go
index bb0eee1..640c01f 100644
--- a/pkg/nvlib/device/device.go
+++ b/pkg/nvlib/device/device.go
@@ -238,6 +238,20 @@ func (d *device) getClass() (Class, error) {
 	return Class(device.Class), nil
 }
 
+// isSkipped checks whether the device should be skipped.
+func (d *device) isSkipped() (bool, error) {
+	name, ret := d.GetName()
+	if ret != nvml.SUCCESS {
+		return false, fmt.Errorf("error getting device name: %v", ret)
+	}
+
+	if _, exists := d.lib.skippedDevices[name]; exists {
+		return true, nil
+	}
+
+	return false, nil
+}
+
 // VisitDevices visits each top-level device and invokes a callback function for it
 func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
 	count, ret := d.nvml.DeviceGetCount()
@@ -255,11 +269,11 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
 			return fmt.Errorf("error creating new device wrapper: %v", err)
 		}
 
-		class, err := dev.getClass()
+		isSkipped, err := dev.isSkipped()
 		if err != nil {
-			return fmt.Errorf("error getting PCI device class for device: %v", err)
+			return fmt.Errorf("error checking whether device is skipped: %v", err)
 		}
-		if !d.classIsSelected(class) {
+		if isSkipped {
 			continue
 		}