diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index 90c3e67..faaac25 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -88,6 +88,16 @@ func (d nvmlDevice) GetGpuInstanceProfileInfo(profile int) (GpuInstanceProfileIn return GpuInstanceProfileInfo(p), Return(r) } +// GetGpuInstancePossiblePlacements returns the possible placements of a GPU Instance +func (d nvmlDevice) GetGpuInstancePossiblePlacements(info *GpuInstanceProfileInfo) ([]GpuInstancePlacement, Return) { + nvmlPlacements, r := nvml.Device(d).GetGpuInstancePossiblePlacements((*nvml.GpuInstanceProfileInfo)(info)) + var placements []GpuInstancePlacement + for _, p := range nvmlPlacements { + placements = append(placements, GpuInstancePlacement(p)) + } + return placements, Return(r) +} + // GetGpuInstances returns the set of GPU Instances associated with a Device func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance, Return) { nvmlGis, r := nvml.Device(d).GetGpuInstances((*nvml.GpuInstanceProfileInfo)(info)) @@ -98,6 +108,12 @@ func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance return gis, Return(r) } +// CreateGpuInstanceWithPlacement creates a GPU Instance with a specific placement +func (d nvmlDevice) CreateGpuInstanceWithPlacement(info *GpuInstanceProfileInfo, placement *GpuInstancePlacement) (GpuInstance, Return) { + gi, r := nvml.Device(d).CreateGpuInstanceWithPlacement((*nvml.GpuInstanceProfileInfo)(info), (*nvml.GpuInstancePlacement)(placement)) + return nvmlGpuInstance(gi), Return(r) +} + // GetMaxMigDeviceCount returns the maximum number of MIG devices that can be created on a Device func (d nvmlDevice) GetMaxMigDeviceCount() (int, Return) { m, r := nvml.Device(d).GetMaxMigDeviceCount() diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go index 4379c68..b1c97c0 100644 --- a/pkg/nvml/types.go +++ b/pkg/nvml/types.go @@ -39,12 +39,14 @@ type Interface interface { // //go:generate moq -out device_mock.go . Device type Device interface { + CreateGpuInstanceWithPlacement(*GpuInstanceProfileInfo, *GpuInstancePlacement) (GpuInstance, Return) GetAttributes() (DeviceAttributes, Return) GetComputeInstanceId() (int, Return) GetCudaComputeCapability() (int, int, Return) GetDeviceHandleFromMigDeviceHandle() (Device, Return) GetGpuInstanceById(ID int) (GpuInstance, Return) GetGpuInstanceId() (int, Return) + GetGpuInstancePossiblePlacements(*GpuInstanceProfileInfo) ([]GpuInstancePlacement, Return) GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return) GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return) GetIndex() (int, Return)