diff --git a/pkg/nvlib/device/api.go b/pkg/nvlib/device/api.go index cbbd4ae..1643fcc 100644 --- a/pkg/nvlib/device/api.go +++ b/pkg/nvlib/device/api.go @@ -26,7 +26,9 @@ type Interface interface { GetMigDevices() ([]MigDevice, error) GetMigProfiles() ([]MigProfile, error) NewDevice(d nvml.Device) (Device, error) + NewDeviceByUUID(uuid string) (Device, error) NewMigDevice(d nvml.Device) (MigDevice, error) + NewMigDeviceByUUID(uuid string) (MigDevice, error) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) ParseMigProfile(profile string) (MigProfile, error) VisitDevices(func(i int, d Device) error) error diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index bcc1409..3d549e4 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -46,6 +46,15 @@ func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) { return d.newDevice(dev) } +// NewDeviceByUUID builds a new Device from a UUID +func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) { + dev, ret := d.nvml.DeviceGetHandleByUUID(uuid) + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret) + } + return d.newDevice(dev) +} + // newDevice creates a device from an nvml.Device func (d *devicelib) newDevice(dev nvml.Device) (*device, error) { return &device{dev, d}, nil diff --git a/pkg/nvlib/device/mig_device.go b/pkg/nvlib/device/mig_device.go index 46af41e..0742a71 100644 --- a/pkg/nvlib/device/mig_device.go +++ b/pkg/nvlib/device/mig_device.go @@ -48,6 +48,15 @@ func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) { return &migdevice{handle, d, nil}, nil } +// NewMigDeviceByUUID builds a new MigDevice from a UUID +func (d *devicelib) NewMigDeviceByUUID(uuid string) (MigDevice, error) { + dev, ret := d.nvml.DeviceGetHandleByUUID(uuid) + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret) + } + return d.NewMigDevice(dev) +} + // GetProfile returns the MIG profile associated with a MIG device func (m *migdevice) GetProfile() (MigProfile, error) { if m.profile != nil {