From 6a4886e49e964d9ca6de726dcd63084343b2e9c2 Mon Sep 17 00:00:00 2001
From: Kevin Klues <kklues@nvidia.com>
Date: Thu, 8 Dec 2022 14:47:10 +0000
Subject: [PATCH] Add Placement related calls for GPUInstances in nvml wrapper

Signed-off-by: Kevin Klues <kklues@nvidia.com>
---
 pkg/nvml/device.go | 16 ++++++++++++++++
 pkg/nvml/types.go  |  2 ++
 2 files changed, 18 insertions(+)

diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go
index 90c3e67..faaac25 100644
--- a/pkg/nvml/device.go
+++ b/pkg/nvml/device.go
@@ -88,6 +88,16 @@ func (d nvmlDevice) GetGpuInstanceProfileInfo(profile int) (GpuInstanceProfileIn
 	return GpuInstanceProfileInfo(p), Return(r)
 }
 
+// GetGpuInstancePossiblePlacements returns the possible placements of a GPU Instance
+func (d nvmlDevice) GetGpuInstancePossiblePlacements(info *GpuInstanceProfileInfo) ([]GpuInstancePlacement, Return) {
+	nvmlPlacements, r := nvml.Device(d).GetGpuInstancePossiblePlacements((*nvml.GpuInstanceProfileInfo)(info))
+	var placements []GpuInstancePlacement
+	for _, p := range nvmlPlacements {
+		placements = append(placements, GpuInstancePlacement(p))
+	}
+	return placements, Return(r)
+}
+
 // GetGpuInstances returns the set of GPU Instances associated with a Device
 func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance, Return) {
 	nvmlGis, r := nvml.Device(d).GetGpuInstances((*nvml.GpuInstanceProfileInfo)(info))
@@ -98,6 +108,12 @@ func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance
 	return gis, Return(r)
 }
 
+// CreateGpuInstanceWithPlacement creates a GPU Instance with a specific placement
+func (d nvmlDevice) CreateGpuInstanceWithPlacement(info *GpuInstanceProfileInfo, placement *GpuInstancePlacement) (GpuInstance, Return) {
+	gi, r := nvml.Device(d).CreateGpuInstanceWithPlacement((*nvml.GpuInstanceProfileInfo)(info), (*nvml.GpuInstancePlacement)(placement))
+	return nvmlGpuInstance(gi), Return(r)
+}
+
 // GetMaxMigDeviceCount returns the maximum number of MIG devices that can be created on a Device
 func (d nvmlDevice) GetMaxMigDeviceCount() (int, Return) {
 	m, r := nvml.Device(d).GetMaxMigDeviceCount()
diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go
index 4379c68..b1c97c0 100644
--- a/pkg/nvml/types.go
+++ b/pkg/nvml/types.go
@@ -39,12 +39,14 @@ type Interface interface {
 //
 //go:generate moq -out device_mock.go . Device
 type Device interface {
+	CreateGpuInstanceWithPlacement(*GpuInstanceProfileInfo, *GpuInstancePlacement) (GpuInstance, Return)
 	GetAttributes() (DeviceAttributes, Return)
 	GetComputeInstanceId() (int, Return)
 	GetCudaComputeCapability() (int, int, Return)
 	GetDeviceHandleFromMigDeviceHandle() (Device, Return)
 	GetGpuInstanceById(ID int) (GpuInstance, Return)
 	GetGpuInstanceId() (int, Return)
+	GetGpuInstancePossiblePlacements(*GpuInstanceProfileInfo) ([]GpuInstancePlacement, Return)
 	GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return)
 	GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return)
 	GetIndex() (int, Return)