mirror of
https://github.com/clearml/go-nvlib
synced 2025-01-31 02:47:02 +00:00
Skip display devices based on device names
This allows devices to be skipped based on device names and skips "NVIDIA DGX Display" devices by default. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
0e10f084d1
commit
655eb9795c
@ -36,7 +36,7 @@ type Interface interface {
|
|||||||
|
|
||||||
type devicelib struct {
|
type devicelib struct {
|
||||||
nvml nvml.Interface
|
nvml nvml.Interface
|
||||||
selectedDeviceClasses map[Class]struct{}
|
skippedDevices map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Interface = &devicelib{}
|
var _ Interface = &devicelib{}
|
||||||
@ -50,10 +50,8 @@ func New(opts ...Option) Interface {
|
|||||||
if d.nvml == nil {
|
if d.nvml == nil {
|
||||||
d.nvml = nvml.New()
|
d.nvml = nvml.New()
|
||||||
}
|
}
|
||||||
if d.selectedDeviceClasses == nil {
|
if d.skippedDevices == nil {
|
||||||
d.selectedDeviceClasses = map[Class]struct{}{
|
WithSkippedDevices("NVIDIA DGX Display")(d)
|
||||||
ClassCompute: {},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
@ -65,14 +63,14 @@ func WithNvml(nvml nvml.Interface) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithSelectedDeviceClasses selects the specified device classes when filtering devices
|
// WithSkippedDevices provides an Option to set devices to be skipped by model name
|
||||||
func WithSelectedDeviceClasses(classes ...Class) Option {
|
func WithSkippedDevices(names ...string) Option {
|
||||||
return func(d *devicelib) {
|
return func(d *devicelib) {
|
||||||
if d.selectedDeviceClasses == nil {
|
if d.skippedDevices == nil {
|
||||||
d.selectedDeviceClasses = make(map[Class]struct{})
|
d.skippedDevices = make(map[string]struct{})
|
||||||
}
|
}
|
||||||
for _, c := range classes {
|
for _, name := range names {
|
||||||
d.selectedDeviceClasses[c] = struct{}{}
|
d.skippedDevices[name] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -238,6 +238,20 @@ func (d *device) getClass() (Class, error) {
|
|||||||
return Class(device.Class), nil
|
return Class(device.Class), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isSkipped checks whether the device should be skipped.
|
||||||
|
func (d *device) isSkipped() (bool, error) {
|
||||||
|
name, ret := d.GetName()
|
||||||
|
if ret != nvml.SUCCESS {
|
||||||
|
return false, fmt.Errorf("error getting device name: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := d.lib.skippedDevices[name]; exists {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// VisitDevices visits each top-level device and invokes a callback function for it
|
// VisitDevices visits each top-level device and invokes a callback function for it
|
||||||
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
||||||
count, ret := d.nvml.DeviceGetCount()
|
count, ret := d.nvml.DeviceGetCount()
|
||||||
@ -255,11 +269,11 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
|||||||
return fmt.Errorf("error creating new device wrapper: %v", err)
|
return fmt.Errorf("error creating new device wrapper: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
class, err := dev.getClass()
|
isSkipped, err := dev.isSkipped()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error getting PCI device class for device: %v", err)
|
return fmt.Errorf("error checking whether device is skipped: %v", err)
|
||||||
}
|
}
|
||||||
if !d.classIsSelected(class) {
|
if isSkipped {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user