mirror of
https://github.com/clearml/go-nvlib
synced 2025-05-13 08:10:39 +00:00
Merge branch 'skip-display-devices-on-name' into 'main'
Skip display devices based on model name. See merge request nvidia/cloud-native/go-nvlib!26
This commit is contained in:
commit
a27e593595
12
Makefile
12
Makefile
@ -96,3 +96,15 @@ $(DOCKER_TARGETS): docker-%: .build-image
|
|||||||
--user $$(id -u):$$(id -g) \
|
--user $$(id -u):$$(id -g) \
|
||||||
$(BUILDIMAGE) \
|
$(BUILDIMAGE) \
|
||||||
make $(*)
|
make $(*)
|
||||||
|
|
||||||
|
# Start an interactive shell using the development image.
|
||||||
|
PHONY: .shell
|
||||||
|
.shell:
|
||||||
|
$(DOCKER) run \
|
||||||
|
--rm \
|
||||||
|
-ti \
|
||||||
|
-e GOCACHE=/tmp/.cache \
|
||||||
|
-v $(PWD):$(PWD) \
|
||||||
|
-w $(PWD) \
|
||||||
|
--user $$(id -u):$$(id -g) \
|
||||||
|
$(BUILDIMAGE)
|
||||||
|
@ -15,4 +15,4 @@ ARG GOLANG_VERSION=1.16
|
|||||||
FROM golang:${GOLANG_VERSION}
|
FROM golang:${GOLANG_VERSION}
|
||||||
|
|
||||||
RUN go get -u golang.org/x/lint/golint
|
RUN go get -u golang.org/x/lint/golint
|
||||||
RUN go install github.com/matryer/moq@latest
|
RUN go install github.com/matryer/moq@v0.2.7
|
||||||
|
@ -36,7 +36,7 @@ type Interface interface {
|
|||||||
|
|
||||||
type devicelib struct {
|
type devicelib struct {
|
||||||
nvml nvml.Interface
|
nvml nvml.Interface
|
||||||
selectedDeviceClasses map[Class]struct{}
|
skippedDevices map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Interface = &devicelib{}
|
var _ Interface = &devicelib{}
|
||||||
@ -50,10 +50,8 @@ func New(opts ...Option) Interface {
|
|||||||
if d.nvml == nil {
|
if d.nvml == nil {
|
||||||
d.nvml = nvml.New()
|
d.nvml = nvml.New()
|
||||||
}
|
}
|
||||||
if d.selectedDeviceClasses == nil {
|
if d.skippedDevices == nil {
|
||||||
d.selectedDeviceClasses = map[Class]struct{}{
|
WithSkippedDevices("NVIDIA DGX Display")(d)
|
||||||
ClassCompute: {},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
@ -65,14 +63,14 @@ func WithNvml(nvml nvml.Interface) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithSelectedDeviceClasses selects the specified device classes when filtering devices
|
// WithSkippedDevices provides an Option to set devices to be skipped by model name
|
||||||
func WithSelectedDeviceClasses(classes ...Class) Option {
|
func WithSkippedDevices(names ...string) Option {
|
||||||
return func(d *devicelib) {
|
return func(d *devicelib) {
|
||||||
if d.selectedDeviceClasses == nil {
|
if d.skippedDevices == nil {
|
||||||
d.selectedDeviceClasses = make(map[Class]struct{})
|
d.skippedDevices = make(map[string]struct{})
|
||||||
}
|
}
|
||||||
for _, c := range classes {
|
for _, name := range names {
|
||||||
d.selectedDeviceClasses[c] = struct{}{}
|
d.skippedDevices[name] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,11 +18,9 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/NVIDIA/go-nvml/pkg/dl"
|
"github.com/NVIDIA/go-nvml/pkg/dl"
|
||||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Device defines the set of extended functions associated with a device.Device
|
// Device defines the set of extended functions associated with a device.Device
|
||||||
@ -41,15 +39,6 @@ type device struct {
|
|||||||
lib *devicelib
|
lib *devicelib
|
||||||
}
|
}
|
||||||
|
|
||||||
// Class represents the PCI class for a device
|
|
||||||
type Class uint32
|
|
||||||
|
|
||||||
// Define constants for common device classes
|
|
||||||
const (
|
|
||||||
ClassCompute = Class(nvpci.PCI3dControllerClass)
|
|
||||||
ClassDisplay = Class(nvpci.PCIVgaControllerClass)
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ Device = &device{}
|
var _ Device = &device{}
|
||||||
|
|
||||||
// NewDevice builds a new Device from an nvml.Device
|
// NewDevice builds a new Device from an nvml.Device
|
||||||
@ -62,16 +51,6 @@ func (d *devicelib) newDevice(dev nvml.Device) (*device, error) {
|
|||||||
return &device{dev, d}, nil
|
return &device{dev, d}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// classIsSelected checks whether the specified class has been selected when constructing the devicelib
|
|
||||||
func (d *devicelib) classIsSelected(c Class) bool {
|
|
||||||
if d.selectedDeviceClasses == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
_, exists := d.selectedDeviceClasses[c]
|
|
||||||
|
|
||||||
return exists
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsMigCapable checks if a device is capable of having MIG partitions created on it
|
// IsMigCapable checks if a device is capable of having MIG partitions created on it
|
||||||
func (d *device) IsMigCapable() (bool, error) {
|
func (d *device) IsMigCapable() (bool, error) {
|
||||||
err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
|
err := nvmlLookupSymbol("nvmlDeviceGetMigMode")
|
||||||
@ -209,33 +188,18 @@ func (d *device) GetMigProfiles() ([]MigProfile, error) {
|
|||||||
return profiles, nil
|
return profiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getClass returns the PCI device class for the device
|
// isSkipped checks whether the device should be skipped.
|
||||||
func (d *device) getClass() (Class, error) {
|
func (d *device) isSkipped() (bool, error) {
|
||||||
info, ret := d.GetPciInfo()
|
name, ret := d.GetName()
|
||||||
if ret != nvml.SUCCESS {
|
if ret != nvml.SUCCESS {
|
||||||
return 0, fmt.Errorf("failed to get PCI info: %v", ret)
|
return false, fmt.Errorf("error getting device name: %v", ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We convert the BusId to a string
|
if _, exists := d.lib.skippedDevices[name]; exists {
|
||||||
var bytes []byte
|
return true, nil
|
||||||
for _, b := range info.BusId {
|
|
||||||
if byte(b) == '\x00' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
bytes = append(bytes, byte(b))
|
|
||||||
}
|
|
||||||
id := strings.ToLower(string(bytes))
|
|
||||||
|
|
||||||
if id != "0000" {
|
|
||||||
id = strings.TrimPrefix(id, "0000")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
device, err := nvpci.New().GetGPUByPciBusID(id)
|
return false, nil
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to construct PCI device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return Class(device.Class), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// VisitDevices visits each top-level device and invokes a callback function for it
|
// VisitDevices visits each top-level device and invokes a callback function for it
|
||||||
@ -255,11 +219,11 @@ func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
|
|||||||
return fmt.Errorf("error creating new device wrapper: %v", err)
|
return fmt.Errorf("error creating new device wrapper: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
class, err := dev.getClass()
|
isSkipped, err := dev.isSkipped()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error getting PCI device class for device: %v", err)
|
return fmt.Errorf("error checking whether device is skipped: %v", err)
|
||||||
}
|
}
|
||||||
if !d.classIsSelected(class) {
|
if isSkipped {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user