diff --git a/pkg/nvlib/device/api.go b/pkg/nvlib/device/api.go index c6605fc..9886d7b 100644 --- a/pkg/nvlib/device/api.go +++ b/pkg/nvlib/device/api.go @@ -40,6 +40,7 @@ type devicelib struct { nvml nvml.Interface skippedDevices map[string]struct{} verifySymbols *bool + migProfiles []MigProfile } var _ Interface = &devicelib{} diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index 40325dc..e2603cf 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -212,7 +212,10 @@ func (d *device) VisitMigProfiles(visit func(MigProfile) error) error { // physically constructed. In the future we should do this via // NVML once a proper API for this exists. pi := p.GetInfo() - if (pi.C * 2) > (pi.G + 1) { + if pi.C > pi.G { + continue + } + if (pi.C < pi.G) && ((pi.C * 2) > (pi.G + 1)) { continue } @@ -385,6 +388,12 @@ func (d *devicelib) GetMigDevices() ([]MigDevice, error) { // GetMigProfiles gets the set of unique MIG profiles across all top-level devices func (d *devicelib) GetMigProfiles() ([]MigProfile, error) { + // Return the cached list if available + if d.migProfiles != nil { + return d.migProfiles, nil + } + + // Otherwise generate it... var profiles []MigProfile err := d.VisitMigProfiles(func(p MigProfile) error { profiles = append(profiles, p) @@ -393,6 +402,9 @@ func (d *devicelib) GetMigProfiles() ([]MigProfile, error) { if err != nil { return nil, err } + + // And cache it before returning + d.migProfiles = profiles return profiles, nil }