diff --git a/pkg/nvlib/mig/mig.go b/pkg/nvlib/device/api.go similarity index 54% rename from pkg/nvlib/mig/mig.go rename to pkg/nvlib/device/api.go index 041bbab..bcc37eb 100644 --- a/pkg/nvlib/mig/mig.go +++ b/pkg/nvlib/device/api.go @@ -14,43 +14,50 @@ * limitations under the License. */ -package mig +package device import ( "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) -// Interface provides the API to the mig package +// Interface provides the API to the 'device' package type Interface interface { - NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) - ParseProfile(profile string) (Profile, error) + GetDevices() ([]Device, error) + GetMigDevices() ([]MigDevice, error) + GetMigProfiles() ([]MigProfile, error) NewDevice(d nvml.Device) (Device, error) + NewMigDevice(d nvml.Device) (MigDevice, error) + NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) + ParseMigProfile(profile string) (MigProfile, error) + VisitDevices(func(i int, d Device) error) error + VisitMigDevices(func(i int, d Device, j int, m MigDevice) error) error + VisitMigProfiles(func(p MigProfile) error) error } -type miglib struct { +type devicelib struct { nvml nvml.Interface } -var _ Interface = &miglib{} +var _ Interface = &devicelib{} -// New creates a new instance of the 'mig' interface +// New creates a new instance of the 'device' interface func New(opts ...Option) Interface { - m := &miglib{} + d := &devicelib{} for _, opt := range opts { - opt(m) + opt(d) } - if m.nvml == nil { - m.nvml = nvml.New() + if d.nvml == nil { + d.nvml = nvml.New() } - return m + return d } -// WithNvml provides an Option to set the NVML library used by the 'mig' interface +// WithNvml provides an Option to set the NVML library used by the 'device' interface func WithNvml(nvml nvml.Interface) Option { - return func(m *miglib) { - m.nvml = nvml + return func(d *devicelib) { + d.nvml = nvml } } // Option defines a function for passing options to the New() call -type Option func(*miglib) +type Option func(*devicelib) diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go new file mode 100644 index 0000000..c10cfcc --- /dev/null +++ b/pkg/nvlib/device/device.go @@ -0,0 +1,304 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package device + +import ( + "fmt" + + "github.com/NVIDIA/go-nvml/pkg/dl" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" +) + +// Device defines the set of extended functions associated with a device.Device +type Device interface { + nvml.Device + GetMigDevices() ([]MigDevice, error) + GetMigProfiles() ([]MigProfile, error) + IsMigCapable() (bool, error) + IsMigEnabled() (bool, error) + VisitMigDevices(func(j int, m MigDevice) error) error + VisitMigProfiles(func(p MigProfile) error) error +} + +type device struct { + nvml.Device + lib *devicelib +} + +var _ Device = &device{} + +// NewDevice builds a new Device from an nvml.Device +func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) { + return &device{dev, d}, nil +} + +// IsMigCapable checks if a device is capable of having MIG paprtitions created on it +func (d *device) IsMigCapable() (bool, error) { + err := nvmlLookupSymbol("nvmlDeviceGetMigMode") + if err != nil { + return false, nil + } + + _, _, ret := nvml.Device(d).GetMigMode() + if ret == nvml.ERROR_NOT_SUPPORTED { + return false, nil + } + if ret != nvml.SUCCESS { + return false, fmt.Errorf("error getting MIG mode: %v", ret) + } + + return true, nil +} + +// IsMigEnabled checks if a device has MIG mode currently enabled on it +func (d *device) IsMigEnabled() (bool, error) { + err := nvmlLookupSymbol("nvmlDeviceGetMigMode") + if err != nil { + return false, nil + } + + mode, _, ret := nvml.Device(d).GetMigMode() + if ret == nvml.ERROR_NOT_SUPPORTED { + return false, nil + } + if ret != nvml.SUCCESS { + return false, fmt.Errorf("error getting MIG mode: %v", ret) + } + + return (mode == nvml.DEVICE_MIG_ENABLE), nil +} + +// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it +func (d *device) VisitMigDevices(visit func(int, MigDevice) error) error { + count, ret := nvml.Device(d).GetMaxMigDeviceCount() + if ret != nvml.SUCCESS { + return fmt.Errorf("error getting max MIG device count: %v", ret) + } + + for i := 0; i < count; i++ { + device, ret := nvml.Device(d).GetMigDeviceHandleByIndex(i) + if ret == nvml.ERROR_NOT_FOUND { + continue + } + if ret == nvml.ERROR_INVALID_ARGUMENT { + continue + } + if ret != nvml.SUCCESS { + return fmt.Errorf("error getting MIG device handle at index '%v': %v", i, ret) + } + mig, err := d.lib.NewMigDevice(device) + if err != nil { + return fmt.Errorf("error creating new MIG device wrapper: %v", err) + } + err = visit(i, mig) + if err != nil { + return fmt.Errorf("error visiting MIG device: %v", err) + } + } + return nil +} + +// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG Profile that can be configured on it +func (d *device) VisitMigProfiles(visit func(MigProfile) error) error { + capable, err := d.IsMigCapable() + if err != nil { + return fmt.Errorf("error checking if GPU is MIG capable: %v", err) + } + + if !capable { + return nil + } + + memory, ret := d.GetMemoryInfo() + if ret != nvml.SUCCESS { + return fmt.Errorf("error getting device memory info: %v", ret) + } + + for i := 0; i < nvml.GPU_INSTANCE_PROFILE_COUNT; i++ { + giProfileInfo, ret := d.GetGpuInstanceProfileInfo(i) + if ret != nvml.SUCCESS { + return fmt.Errorf("error getting GPU Instance profile info: %v", ret) + } + + for j := 0; j < nvml.COMPUTE_INSTANCE_PROFILE_COUNT; j++ { + for k := 0; k < nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT; k++ { + p, err := d.lib.NewMigProfile(i, j, k, giProfileInfo.MemorySizeMB, memory.Total) + if err != nil { + return fmt.Errorf("error creating MIG profile: %v", err) + } + + err = visit(p) + if err != nil { + return fmt.Errorf("error visiting MIG profile: %v", err) + } + } + } + } + return nil +} + +// GetMigDevices gets the set of MIG devices associated with a top-level device +func (d *device) GetMigDevices() ([]MigDevice, error) { + var migs []MigDevice + err := d.VisitMigDevices(func(j int, m MigDevice) error { + migs = append(migs, m) + return nil + }) + if err != nil { + return nil, err + } + return migs, nil +} + +// GetMigProfiles gets the set of unique MIG profiles associated with a top-level device +func (d *device) GetMigProfiles() ([]MigProfile, error) { + var profiles []MigProfile + err := d.VisitMigProfiles(func(p MigProfile) error { + profiles = append(profiles, p) + return nil + }) + if err != nil { + return nil, err + } + return profiles, nil +} + +// VisitDevices visits each top-level device and invokes a callback function for it +func (d *devicelib) VisitDevices(visit func(int, Device) error) error { + count, ret := d.nvml.DeviceGetCount() + if ret != nvml.SUCCESS { + return fmt.Errorf("error getting device count: %v", ret) + } + + for i := 0; i < count; i++ { + device, ret := d.nvml.DeviceGetHandleByIndex(i) + if ret != nvml.SUCCESS { + return fmt.Errorf("error getting device handle for index '%v': %v", i, ret) + } + dev, err := d.NewDevice(device) + if err != nil { + return fmt.Errorf("error creating new device wrapper: %v", err) + } + err = visit(i, dev) + if err != nil { + return fmt.Errorf("error visiting device: %v", err) + } + } + return nil +} + +// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it +func (d *devicelib) VisitMigDevices(visit func(int, Device, int, MigDevice) error) error { + err := d.VisitDevices(func(i int, dev Device) error { + err := dev.VisitMigDevices(func(j int, mig MigDevice) error { + err := visit(i, dev, j, mig) + if err != nil { + return fmt.Errorf("error visiting MIG device: %v", err) + } + return nil + }) + if err != nil { + return fmt.Errorf("error visiting device: %v", err) + } + return nil + }) + if err != nil { + return fmt.Errorf("error visiting devices: %v", err) + } + return nil +} + +// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG profile found on them +func (d *devicelib) VisitMigProfiles(visit func(MigProfile) error) error { + visited := make(map[string]bool) + err := d.VisitDevices(func(i int, dev Device) error { + err := dev.VisitMigProfiles(func(p MigProfile) error { + if visited[p.String()] { + return nil + } + + err := visit(p) + if err != nil { + return fmt.Errorf("error visiting MIG profile: %v", err) + } + + visited[p.String()] = true + return nil + }) + if err != nil { + return fmt.Errorf("error visiting device: %v", err) + } + return nil + }) + if err != nil { + return fmt.Errorf("error visiting devices: %v", err) + } + return nil +} + +// GetDevices gets the set of all top-level devices +func (d *devicelib) GetDevices() ([]Device, error) { + var devs []Device + err := d.VisitDevices(func(i int, dev Device) error { + devs = append(devs, dev) + return nil + }) + if err != nil { + return nil, err + } + return devs, nil +} + +// GetMigDevices gets the set of MIG devices across all top-level devices +func (d *devicelib) GetMigDevices() ([]MigDevice, error) { + var migs []MigDevice + err := d.VisitMigDevices(func(i int, dev Device, j int, m MigDevice) error { + migs = append(migs, m) + return nil + }) + if err != nil { + return nil, err + } + return migs, nil +} + +// GetMigProfiles gets the set of unique MIG profiles across all top-level devices +func (d *devicelib) GetMigProfiles() ([]MigProfile, error) { + var profiles []MigProfile + err := d.VisitMigProfiles(func(p MigProfile) error { + profiles = append(profiles, p) + return nil + }) + if err != nil { + return nil, err + } + return profiles, nil +} + +// nvmlLookupSymbol checks to see if the given symbol is present in the NVML library +func nvmlLookupSymbol(symbol string) error { + lib := dl.New("libnvidia-ml.so.1", dl.RTLD_LAZY|dl.RTLD_GLOBAL) + if lib == nil { + return fmt.Errorf("error instantiating DynamicLibrary for NVML") + } + err := lib.Open() + if err != nil { + return fmt.Errorf("error opening DynamicLibrary for NVML: %v", err) + } + defer lib.Close() + return lib.Lookup(symbol) +} diff --git a/pkg/nvlib/mig/device.go b/pkg/nvlib/device/mig_device.go similarity index 77% rename from pkg/nvlib/mig/device.go rename to pkg/nvlib/device/mig_device.go index 4e429f5..0d87c98 100644 --- a/pkg/nvlib/mig/device.go +++ b/pkg/nvlib/device/mig_device.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package mig +package device import ( "fmt" @@ -22,38 +22,39 @@ import ( "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) -// Device defines the set of extended functions associated with a mig.Device -type Device interface { - GetProfile() (Profile, error) -} - -type device struct { +// MigDevice defines the set of extended functions associated with a MIG device +type MigDevice interface { nvml.Device - miglib *miglib - profile Profile + GetProfile() (MigProfile, error) } -var _ Device = &device{} +type migdevice struct { + nvml.Device + lib *devicelib + profile MigProfile +} -// NewDevice builds a new Device from an nvml.Device -func (m *miglib) NewDevice(d nvml.Device) (Device, error) { - isMig, ret := d.IsMigDeviceHandle() +var _ MigDevice = &migdevice{} + +// NewMigDevice builds a new MigDevice from an nvml.Device +func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) { + isMig, ret := handle.IsMigDeviceHandle() if ret != nvml.SUCCESS { return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret) } if !isMig { return nil, fmt.Errorf("not a MIG device") } - return &device{d, m, nil}, nil + return &migdevice{handle, d, nil}, nil } // GetProfile returns the MIG profile associated with a MIG device -func (d *device) GetProfile() (Profile, error) { - if d.profile != nil { - return d.profile, nil +func (m *migdevice) GetProfile() (MigProfile, error) { + if m.profile != nil { + return m.profile, nil } - parent, ret := d.Device.GetDeviceHandleFromMigDeviceHandle() + parent, ret := m.Device.GetDeviceHandleFromMigDeviceHandle() if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting parent device handle: %v", ret) } @@ -63,17 +64,17 @@ func (d *device) GetProfile() (Profile, error) { return nil, fmt.Errorf("error getting parent memory info: %v", ret) } - attributes, ret := d.Device.GetAttributes() + attributes, ret := m.Device.GetAttributes() if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting MIG device attributes: %v", ret) } - giID, ret := d.Device.GetGpuInstanceId() + giID, ret := m.Device.GetGpuInstanceId() if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret) } - ciID, ret := d.Device.GetComputeInstanceId() + ciID, ret := m.Device.GetComputeInstanceId() if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret) } @@ -120,12 +121,12 @@ func (d *device) GetProfile() (Profile, error) { continue } - p, err := d.miglib.NewProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total) + p, err := m.lib.NewMigProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total) if err != nil { return nil, fmt.Errorf("error creating MIG profile: %v", err) } - d.profile = p + m.profile = p return p, nil } } diff --git a/pkg/nvlib/mig/profile.go b/pkg/nvlib/device/mig_profile.go similarity index 84% rename from pkg/nvlib/mig/profile.go rename to pkg/nvlib/device/mig_profile.go index f5b548e..5aa00e6 100644 --- a/pkg/nvlib/mig/profile.go +++ b/pkg/nvlib/device/mig_profile.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package mig +package device import ( "fmt" @@ -30,16 +30,16 @@ const ( AttributeMediaExtensions = "me" ) -// Profile represents a specific MIG profile. +// MigProfile represents a specific MIG profile. // Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc. -type Profile interface { +type MigProfile interface { String() string - GetInfo() ProfileInfo - Equals(other Profile) bool + GetInfo() MigProfileInfo + Equals(other MigProfile) bool } -// ProfileInfo holds all info associated with a specific MIG profile -type ProfileInfo struct { +// MigProfileInfo holds all info associated with a specific MIG profile +type MigProfileInfo struct { C int G int GB int @@ -49,10 +49,10 @@ type ProfileInfo struct { CIEngProfileID int } -var _ Profile = &ProfileInfo{} +var _ MigProfile = &MigProfileInfo{} // NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it. -func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (Profile, error) { +func (d *devicelib) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) { giSlices := 0 switch giProfileID { case nvml.GPU_INSTANCE_PROFILE_1_SLICE: @@ -101,7 +101,7 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem attrs = append(attrs, AttributeMediaExtensions) } - p := &ProfileInfo{ + p := &MigProfileInfo{ C: ciSlices, G: giSlices, GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)), @@ -114,8 +114,8 @@ func (m *miglib) NewProfile(giProfileID, ciProfileID, ciEngProfileID int, migMem return p, nil } -// ParseProfile converts a string representation of a Profile into an object. -func (m *miglib) ParseProfile(profile string) (Profile, error) { +// ParseMigProfile converts a string representation of a MigProfile into an object +func (d *devicelib) ParseMigProfile(profile string) (MigProfile, error) { var err error var c, g, gb int var attrs []string @@ -126,18 +126,18 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) { split := strings.SplitN(profile, "+", 2) if len(split) == 2 { - attrs, err = parseProfileAttributes(split[1]) + attrs, err = parseMigProfileAttributes(split[1]) if err != nil { return nil, fmt.Errorf("error parsing attributes following '+' in Profile string: %v", err) } } - c, g, gb, err = parseProfileFields(split[0]) + c, g, gb, err = parseMigProfileFields(split[0]) if err != nil { return nil, fmt.Errorf("error parsing '.' separated fields in Profile string: %v", err) } - p := &ProfileInfo{ + p := &MigProfileInfo{ C: c, G: g, GB: gb, @@ -197,7 +197,7 @@ func (m *miglib) ParseProfile(profile string) (Profile, error) { } // String returns the string representation of a Profile -func (p *ProfileInfo) String() string { +func (p *MigProfileInfo) String() string { var suffix string if len(p.Attributes) > 0 { suffix = "+" + strings.Join(p.Attributes, ",") @@ -209,14 +209,14 @@ func (p *ProfileInfo) String() string { } // GetInfo returns detailed info about a Profile -func (p *ProfileInfo) GetInfo() ProfileInfo { +func (p *MigProfileInfo) GetInfo() MigProfileInfo { return *p } // Equals checks if two Profiles are identical or not -func (p *ProfileInfo) Equals(other Profile) bool { +func (p *MigProfileInfo) Equals(other MigProfile) bool { switch o := other.(type) { - case *ProfileInfo: + case *MigProfileInfo: if p.C != o.C { return false } @@ -240,7 +240,7 @@ func (p *ProfileInfo) Equals(other Profile) bool { return false } -func parseProfileField(s string, field string) (int, error) { +func parseMigProfileField(s string, field string) (int, error) { if strings.TrimSpace(s) != s { return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field) } @@ -257,32 +257,32 @@ func parseProfileField(s string, field string) (int, error) { return v, nil } -func parseProfileFields(s string) (int, int, int, error) { +func parseMigProfileFields(s string) (int, int, int, error) { var err error var c, g, gb int split := strings.SplitN(s, ".", 3) if len(split) == 3 { - c, err = parseProfileField(split[0], "c") + c, err = parseMigProfileField(split[0], "c") if err != nil { return -1, -1, -1, err } - g, err = parseProfileField(split[1], "g") + g, err = parseMigProfileField(split[1], "g") if err != nil { return -1, -1, -1, err } - gb, err = parseProfileField(split[2], "gb") + gb, err = parseMigProfileField(split[2], "gb") if err != nil { return -1, -1, -1, err } return c, g, gb, err } if len(split) == 2 { - g, err = parseProfileField(split[0], "g") + g, err = parseMigProfileField(split[0], "g") if err != nil { return -1, -1, -1, err } - gb, err = parseProfileField(split[1], "gb") + gb, err = parseMigProfileField(split[1], "gb") if err != nil { return -1, -1, -1, err } @@ -292,7 +292,7 @@ func parseProfileFields(s string) (int, int, int, error) { return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3") } -func parseProfileAttributes(s string) ([]string, error) { +func parseMigProfileAttributes(s string) ([]string, error) { attr := strings.Split(s, ",") if len(attr) == 0 { return nil, fmt.Errorf("empty attribute list") diff --git a/pkg/nvlib/mig/profile_test.go b/pkg/nvlib/device/mig_profile_test.go similarity index 98% rename from pkg/nvlib/mig/profile_test.go rename to pkg/nvlib/device/mig_profile_test.go index a73f07f..ab19773 100644 --- a/pkg/nvlib/mig/profile_test.go +++ b/pkg/nvlib/device/mig_profile_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package mig +package device import ( "fmt" @@ -256,10 +256,10 @@ func TestParseMigProfile(t *testing.T) { }, } - m := New() + d := New() for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - _, err := m.ParseProfile(tc.device) + _, err := d.ParseMigProfile(tc.device) if tc.valid { require.Nil(t, err) } else {