Merge branch 'skip-display-devices-on-name' into 'main'

Skip display devices based on model name

See merge request nvidia/cloud-native/go-nvlib!26
This commit is contained in:
Evan Lezar 2022-11-21 20:39:40 +00:00
commit a27e593595
4 changed files with 33 additions and 59 deletions

View File

@ -96,3 +96,15 @@ $(DOCKER_TARGETS): docker-%: .build-image
--user $$(id -u):$$(id -g) \ --user $$(id -u):$$(id -g) \
$(BUILDIMAGE) \ $(BUILDIMAGE) \
make $(*) make $(*)
# Start an interactive shell using the development image.
# The current directory is bind-mounted into the container at the same path and
# used as the working directory, and the container runs with the invoking
# user's uid:gid.
# GOCACHE is redirected to /tmp/.cache (presumably so the non-root user has a
# writable Go build cache -- confirm against the build image).
# NOTE: the special target is `.PHONY` (with a leading dot); a bare `PHONY:`
# would only define an ordinary target named "PHONY" and leave `.shell`
# non-phony.
.PHONY: .shell
.shell:
	$(DOCKER) run \
		--rm \
		-ti \
		-e GOCACHE=/tmp/.cache \
		-v $(PWD):$(PWD) \
		-w $(PWD) \
		--user $$(id -u):$$(id -g) \
		$(BUILDIMAGE)

View File

@ -15,4 +15,4 @@ ARG GOLANG_VERSION=1.16
FROM golang:${GOLANG_VERSION} FROM golang:${GOLANG_VERSION}
RUN go get -u golang.org/x/lint/golint RUN go get -u golang.org/x/lint/golint
RUN go install github.com/matryer/moq@latest RUN go install github.com/matryer/moq@v0.2.7

View File

@ -36,7 +36,7 @@ type Interface interface {
type devicelib struct { type devicelib struct {
nvml nvml.Interface nvml nvml.Interface
selectedDeviceClasses map[Class]struct{} skippedDevices map[string]struct{}
} }
var _ Interface = &devicelib{} var _ Interface = &devicelib{}
@ -50,10 +50,8 @@ func New(opts ...Option) Interface {
if d.nvml == nil { if d.nvml == nil {
d.nvml = nvml.New() d.nvml = nvml.New()
} }
if d.selectedDeviceClasses == nil { if d.skippedDevices == nil {
d.selectedDeviceClasses = map[Class]struct{}{ WithSkippedDevices("NVIDIA DGX Display")(d)
ClassCompute: {},
}
} }
return d return d
} }
@ -65,14 +63,14 @@ func WithNvml(nvml nvml.Interface) Option {
} }
} }
// WithSelectedDeviceClasses selects the specified device classes when filtering devices // WithSkippedDevices provides an Option to set devices to be skipped by model name
func WithSelectedDeviceClasses(classes ...Class) Option { func WithSkippedDevices(names ...string) Option {
return func(d *devicelib) { return func(d *devicelib) {
if d.selectedDeviceClasses == nil { if d.skippedDevices == nil {
d.selectedDeviceClasses = make(map[Class]struct{}) d.skippedDevices = make(map[string]struct{})
} }
for _, c := range classes { for _, name := range names {
d.selectedDeviceClasses[c] = struct{}{} d.skippedDevices[name] = struct{}{}
} }
} }
} }

View File

@ -18,11 +18,9 @@ package device
import ( import (
"fmt" "fmt"
"strings"
"github.com/NVIDIA/go-nvml/pkg/dl" "github.com/NVIDIA/go-nvml/pkg/dl"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
) )
// Device defines the set of extended functions associated with a device.Device // Device defines the set of extended functions associated with a device.Device
@ -41,15 +39,6 @@ type device struct {
lib *devicelib lib *devicelib
} }
// Class represents the PCI class for a device
type Class uint32
// Define constants for common device classes
const (
ClassCompute = Class(nvpci.PCI3dControllerClass)
ClassDisplay = Class(nvpci.PCIVgaControllerClass)
)
var _ Device = &device{} var _ Device = &device{}
// NewDevice builds a new Device from an nvml.Device // NewDevice builds a new Device from an nvml.Device
@ -62,16 +51,6 @@ func (d *devicelib) newDevice(dev nvml.Device) (*device, error) {
return &device{dev, d}, nil return &device{dev, d}, nil
} }
// classIsSelected checks whether the specified class has been selected when constructing the devicelib
func (d *devicelib) classIsSelected(c Class) bool {
if d.selectedDeviceClasses == nil {
return false
}
_, exists := d.selectedDeviceClasses[c]
return exists
}
// IsMigCapable checks if a device is capable of having MIG partitions created on it // IsMigCapable checks if a device is capable of having MIG partitions created on it
func (d *device) IsMigCapable() (bool, error) { func (d *device) IsMigCapable() (bool, error) {
err := nvmlLookupSymbol("nvmlDeviceGetMigMode") err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
@ -209,33 +188,18 @@ func (d *device) GetMigProfiles() ([]MigProfile, error) {
return profiles, nil return profiles, nil
} }
// getClass returns the PCI device class for the device // isSkipped checks whether the device should be skipped.
func (d *device) getClass() (Class, error) { func (d *device) isSkipped() (bool, error) {
info, ret := d.GetPciInfo() name, ret := d.GetName()
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return 0, fmt.Errorf("failed to get PCI info: %v", ret) return false, fmt.Errorf("error getting device name: %v", ret)
} }
// We convert the BusId to a string if _, exists := d.lib.skippedDevices[name]; exists {
var bytes []byte return true, nil
for _, b := range info.BusId {
if byte(b) == '\x00' {
break
}
bytes = append(bytes, byte(b))
}
id := strings.ToLower(string(bytes))
if id != "0000" {
id = strings.TrimPrefix(id, "0000")
} }
device, err := nvpci.New().GetGPUByPciBusID(id) return false, nil
if err != nil {
return 0, fmt.Errorf("failed to construct PCI device: %v", err)
}
return Class(device.Class), nil
} }
// VisitDevices visits each top-level device and invokes a callback function for it // VisitDevices visits each top-level device and invokes a callback function for it
@ -255,11 +219,11 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
return fmt.Errorf("error creating new device wrapper: %v", err) return fmt.Errorf("error creating new device wrapper: %v", err)
} }
class, err := dev.getClass() isSkipped, err := dev.isSkipped()
if err != nil { if err != nil {
return fmt.Errorf("error getting PCI device class for device: %v", err) return fmt.Errorf("error checking whether device is skipped: %v", err)
} }
if !d.classIsSelected(class) { if isSkipped {
continue continue
} }