From 6a4886e49e964d9ca6de726dcd63084343b2e9c2 Mon Sep 17 00:00:00 2001 From: Kevin Klues <kklues@nvidia.com> Date: Thu, 8 Dec 2022 14:47:10 +0000 Subject: [PATCH] Add Placement related calls for GPUInstances in nvml wrapper Signed-off-by: Kevin Klues <kklues@nvidia.com> --- pkg/nvml/device.go | 16 ++++++++++++++++ pkg/nvml/types.go | 2 ++ 2 files changed, 18 insertions(+) diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index 90c3e67..faaac25 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -88,6 +88,16 @@ func (d nvmlDevice) GetGpuInstanceProfileInfo(profile int) (GpuInstanceProfileIn return GpuInstanceProfileInfo(p), Return(r) } +// GetGpuInstancePossiblePlacements returns the possible placements of a GPU Instance +func (d nvmlDevice) GetGpuInstancePossiblePlacements(info *GpuInstanceProfileInfo) ([]GpuInstancePlacement, Return) { + nvmlPlacements, r := nvml.Device(d).GetGpuInstancePossiblePlacements((*nvml.GpuInstanceProfileInfo)(info)) + var placements []GpuInstancePlacement + for _, p := range nvmlPlacements { + placements = append(placements, GpuInstancePlacement(p)) + } + return placements, Return(r) +} + // GetGpuInstances returns the set of GPU Instances associated with a Device func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance, Return) { nvmlGis, r := nvml.Device(d).GetGpuInstances((*nvml.GpuInstanceProfileInfo)(info)) @@ -98,6 +108,12 @@ func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance return gis, Return(r) } +// CreateGpuInstanceWithPlacement creates a GPU Instance with a specific placement +func (d nvmlDevice) CreateGpuInstanceWithPlacement(info *GpuInstanceProfileInfo, placement *GpuInstancePlacement) (GpuInstance, Return) { + gi, r := nvml.Device(d).CreateGpuInstanceWithPlacement((*nvml.GpuInstanceProfileInfo)(info), (*nvml.GpuInstancePlacement)(placement)) + return nvmlGpuInstance(gi), Return(r) +} + // GetMaxMigDeviceCount returns the maximum number of MIG devices that can be created on a Device func (d nvmlDevice) GetMaxMigDeviceCount() (int, Return) { m, r := nvml.Device(d).GetMaxMigDeviceCount() diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go index 4379c68..b1c97c0 100644 --- a/pkg/nvml/types.go +++ b/pkg/nvml/types.go @@ -39,12 +39,14 @@ type Interface interface { // //go:generate moq -out device_mock.go . Device type Device interface { + CreateGpuInstanceWithPlacement(*GpuInstanceProfileInfo, *GpuInstancePlacement) (GpuInstance, Return) GetAttributes() (DeviceAttributes, Return) GetComputeInstanceId() (int, Return) GetCudaComputeCapability() (int, int, Return) GetDeviceHandleFromMigDeviceHandle() (Device, Return) GetGpuInstanceById(ID int) (GpuInstance, Return) GetGpuInstanceId() (int, Return) + GetGpuInstancePossiblePlacements(*GpuInstanceProfileInfo) ([]GpuInstancePlacement, Return) GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return) GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return) GetIndex() (int, Return)