diff --git a/.golangci.yml b/.golangci.yml index e0fc3d2..1555093 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -3,6 +3,7 @@ linters: enable: - asciicheck - contextcheck + - gocritic - godot - gofmt - goimports @@ -10,7 +11,6 @@ linters: # TODO: re-enable once we have addressed the warnings disable: - unused - - gocritic - stylecheck - forcetypeassert diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index d20b69a..024ca5c 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -125,7 +125,7 @@ func (d *device) GetBrandAsString() (string, error) { case nvml.BRAND_NVIDIA_VWS: return "NvidiaVWS", nil // Deprecated in favor of nvml.BRAND_NVIDIA_CLOUD_GAMING - //case nvml.BRAND_NVIDIA_VGAMING: + // case nvml.BRAND_NVIDIA_VGAMING: // return "VGaming", nil case nvml.BRAND_NVIDIA_CLOUD_GAMING: return "NvidiaCloudGaming", nil diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go index 6ff197b..5caf034 100644 --- a/pkg/nvpci/nvpci.go +++ b/pkg/nvpci/nvpci.go @@ -280,27 +280,14 @@ func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevi return nil, fmt.Errorf("unable to convert device string to uint16: %v", deviceStr) } - driver, err := filepath.EvalSymlinks(path.Join(devicePath, "driver")) - if err == nil { - driver = filepath.Base(driver) - } else if os.IsNotExist(err) { - driver = "" - } else { - return nil, fmt.Errorf("unable to detect driver for %s: %v", address, err) + driver, err := getDriver(devicePath) + if err != nil { + return nil, fmt.Errorf("unable to detect driver for %s: %w", address, err) } - var iommuGroup int64 - iommu, err := filepath.EvalSymlinks(path.Join(devicePath, "iommu_group")) - if err == nil { - iommuGroupStr := strings.TrimSpace(filepath.Base(iommu)) - iommuGroup, err = strconv.ParseInt(iommuGroupStr, 0, 64) - if err != nil { - return nil, fmt.Errorf("unable to convert iommu_group string to int64: %v", iommuGroupStr) - } - } else if os.IsNotExist(err) { - iommuGroup = -1 - } else { - return nil, fmt.Errorf("unable to detect iommu_group for %s: %v", address, err) + iommuGroup, err := getIOMMUGroup(devicePath) + if err != nil { + return nil, fmt.Errorf("unable to detect IOMMU group for %s: %w", address, err) } numa, err := os.ReadFile(path.Join(devicePath, "numa_node")) @@ -359,7 +346,8 @@ func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevi var sriovInfo SriovInfo // Device is a virtual function (VF) if "physfn" symlink exists. physFnAddress, err := filepath.EvalSymlinks(path.Join(devicePath, "physfn")) - if err == nil { + switch { + case err == nil: physFn, err := p.getGPUByPciBusID(filepath.Base(physFnAddress), cache) if err != nil { return nil, fmt.Errorf("unable to detect physfn for %s: %v", address, err) @@ -369,12 +357,12 @@ func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevi PhysicalFunction: physFn, }, } - } else if os.IsNotExist(err) { + case os.IsNotExist(err): sriovInfo, err = p.getSriovInfoForPhysicalFunction(devicePath) if err != nil { return nil, fmt.Errorf("unable to read SRIOV physical function details for %s: %v", devicePath, err) } - } else { + default: return nil, fmt.Errorf("unable to read %s: %v", path.Join(devicePath, "physfn"), err) } @@ -521,3 +509,31 @@ func (p *nvpci) getSriovInfoForPhysicalFunction(devicePath string) (sriovInfo Sr } return sriovInfo, nil } + +func getDriver(devicePath string) (string, error) { + driver, err := filepath.EvalSymlinks(path.Join(devicePath, "driver")) + switch { + case os.IsNotExist(err): + return "", nil + case err == nil: + return filepath.Base(driver), nil + } + return "", err +} + +func getIOMMUGroup(devicePath string) (int64, error) { + var iommuGroup int64 + iommu, err := filepath.EvalSymlinks(path.Join(devicePath, "iommu_group")) + switch { + case os.IsNotExist(err): + return -1, nil + case err == nil: + iommuGroupStr := strings.TrimSpace(filepath.Base(iommu)) + iommuGroup, err = strconv.ParseInt(iommuGroupStr, 0, 64) + if err != nil { + return 0, fmt.Errorf("unable to convert iommu_group string to int64: %v", iommuGroupStr) + } + return iommuGroup, nil + } + return 0, err +} diff --git a/pkg/nvpci/resources.go b/pkg/nvpci/resources.go index b3b7d31..3d799ae 100644 --- a/pkg/nvpci/resources.go +++ b/pkg/nvpci/resources.go @@ -112,7 +112,7 @@ func (mrs MemoryResources) GetTotalAddressableMemory(roundUp bool) (uint64, uint if key >= pciIOVNumBAR || numBAR == pciIOVNumBAR { break } - numBAR = numBAR + 1 + numBAR++ region := mrs[key] @@ -123,10 +123,10 @@ func (mrs MemoryResources) GetTotalAddressableMemory(roundUp bool) (uint64, uint memSize := (region.End - region.Start) + 1 if memType32bit { - memSize32bit = memSize32bit + uint64(memSize) + memSize32bit += uint64(memSize) } if memType64bit { - memSize64bit = memSize64bit + uint64(memSize) + memSize64bit += uint64(memSize) } } diff --git a/pkg/pciids/pciids.go b/pkg/pciids/pciids.go index 343df08..79cb227 100644 --- a/pkg/pciids/pciids.go +++ b/pkg/pciids/pciids.go @@ -396,7 +396,7 @@ func (p *parser) parse() Interface { hkClass = db.classes[uint32(id)] hkFullID = uint32(id) << 16 - hkFullID = hkFullID & 0xFFFF0000 + hkFullID &= 0xFFFF0000 hkFullName[0] = fmt.Sprintf("%s (%02x)", lit.name, id) } @@ -408,11 +408,11 @@ func (p *parser) parse() Interface { } hkSubClass = hkClass.subClasses[uint32(id)] - // Clear the last detected sub class. - hkFullID = hkFullID & 0xFFFF0000 - hkFullID = hkFullID | uint32(id)<<8 + // Clear the last detected subclass. + hkFullID &= 0xFFFF0000 + hkFullID |= uint32(id) << 8 // Clear the last detected prog iface. - hkFullID = hkFullID & 0xFFFFFF00 + hkFullID &= 0xFFFFFF00 hkFullName[1] = fmt.Sprintf("%s (%02x)", lit.name, id) db.classes[uint32(hkFullID)] = class{