From 1b3ef9bd641e2d8a5b6154c65cb08e681a1326b8 Mon Sep 17 00:00:00 2001 From: Christopher Desiniotis Date: Fri, 9 Jun 2023 16:56:08 -0700 Subject: [PATCH] Update pciids interface to return errors for invalid vendor / device ids Signed-off-by: Christopher Desiniotis --- pkg/nvpci/nvpci.go | 13 +++++++++++-- pkg/pciids/pciids.go | 26 ++++++++++++++++++++------ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go index 1640590..6227b0c 100644 --- a/pkg/nvpci/nvpci.go +++ b/pkg/nvpci/nvpci.go @@ -302,6 +302,15 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { pciDB := pciids.NewDB() + deviceName, err := pciDB.GetDeviceName(uint16(vendorID), uint16(deviceID)) + if err != nil { + return nil, fmt.Errorf("unable to get device name: %v", err) + } + className, err := pciDB.GetClassName(uint32(classID)) + if err != nil { + return nil, fmt.Errorf("unable to get class name for device: %v", err) + } + nvdevice := &NvidiaPCIDevice{ Path: devicePath, Address: address, @@ -314,8 +323,8 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { Config: config, Resources: resources, IsVF: isVF, - DeviceName: pciDB.GetDeviceName(uint16(vendorID), uint16(deviceID)), - ClassName: pciDB.GetClassName(uint32(classID)), + DeviceName: deviceName, + ClassName: className, } return nvdevice, nil diff --git a/pkg/pciids/pciids.go b/pkg/pciids/pciids.go index 7a446ab..f3b7f85 100644 --- a/pkg/pciids/pciids.go +++ b/pkg/pciids/pciids.go @@ -253,18 +253,32 @@ var _ Interface = (*pcidb)(nil) // Interface returns textual description of specific attributes of PCI devices type Interface interface { - GetDeviceName(uint16, uint16) string - GetClassName(uint32) string + GetDeviceName(uint16, uint16) (string, error) + GetClassName(uint32) (string, error) } // GetDeviceName return the textual description of the PCI device -func (d *pcidb) GetDeviceName(vendorID uint16, deviceID uint16) string { - return d.vendors[vendorID].devices[deviceID].name +func (d *pcidb) GetDeviceName(vendorID uint16, deviceID uint16) (string, error) { + vendor, ok := d.vendors[vendorID] + if !ok { + return "", fmt.Errorf("failed to find vendor with id '%x'", vendorID) + } + + device, ok := vendor.devices[deviceID] + if !ok { + return "", fmt.Errorf("failed to find device with id '%x'", deviceID) + } + + return device.name, nil } // GetClassName resturn the textual description of the PCI device class -func (d *pcidb) GetClassName(classID uint32) string { - return d.classes[classID].name +func (d *pcidb) GetClassName(classID uint32) (string, error) { + class, ok := d.classes[classID] + if !ok { + return "", fmt.Errorf("failed to find class with id '%x'", classID) + } + return class.name, nil } // pcidb The complete set of PCI vendors and PCI classes