From 008aa70bc688f2e86dbbdc470c321cc9613145c0 Mon Sep 17 00:00:00 2001 From: Kevin Klues Date: Wed, 10 Aug 2022 06:57:54 +0000 Subject: [PATCH] Add an interface based wrapper around go-nvml for better mocking Signed-off-by: Kevin Klues --- pkg/nvml/ci.go | 44 +++ pkg/nvml/ci_mock.go | 102 +++++++ pkg/nvml/consts.go | 87 ++++++ pkg/nvml/device.go | 117 ++++++++ pkg/nvml/device_mock.go | 598 ++++++++++++++++++++++++++++++++++++++++ pkg/nvml/gi.go | 65 +++++ pkg/nvml/gi_mock.go | 237 ++++++++++++++++ pkg/nvml/nvml.go | 69 +++++ pkg/nvml/nvml_mock.go | 303 ++++++++++++++++++++ pkg/nvml/types.go | 108 ++++++++ 10 files changed, 1730 insertions(+) create mode 100644 pkg/nvml/ci.go create mode 100644 pkg/nvml/ci_mock.go create mode 100644 pkg/nvml/consts.go create mode 100644 pkg/nvml/device.go create mode 100644 pkg/nvml/device_mock.go create mode 100644 pkg/nvml/gi.go create mode 100644 pkg/nvml/gi_mock.go create mode 100644 pkg/nvml/nvml.go create mode 100644 pkg/nvml/nvml_mock.go create mode 100644 pkg/nvml/types.go diff --git a/pkg/nvml/ci.go b/pkg/nvml/ci.go new file mode 100644 index 0000000..6a9d798 --- /dev/null +++ b/pkg/nvml/ci.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nvml + +import ( + "github.com/NVIDIA/go-nvml/pkg/nvml" +) + +type nvmlComputeInstance nvml.ComputeInstance + +var _ ComputeInstance = (*nvmlComputeInstance)(nil) + +// GetInfo() returns info about a Compute Instance +func (ci nvmlComputeInstance) GetInfo() (ComputeInstanceInfo, Return) { + i, r := nvml.ComputeInstance(ci).GetInfo() + info := ComputeInstanceInfo{ + Device: nvmlDevice(i.Device), + GpuInstance: nvmlGpuInstance(i.GpuInstance), + Id: i.Id, + ProfileId: i.ProfileId, + Placement: ComputeInstancePlacement(i.Placement), + } + return info, Return(r) +} + +// Destroy() destroys a Compute Instance +func (ci nvmlComputeInstance) Destroy() Return { + r := nvml.ComputeInstance(ci).Destroy() + return Return(r) +} diff --git a/pkg/nvml/ci_mock.go b/pkg/nvml/ci_mock.go new file mode 100644 index 0000000..af323b3 --- /dev/null +++ b/pkg/nvml/ci_mock.go @@ -0,0 +1,102 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package nvml + +import ( + "sync" +) + +// Ensure, that ComputeInstanceMock does implement ComputeInstance. +// If this is not the case, regenerate this file with moq. +var _ ComputeInstance = &ComputeInstanceMock{} + +// ComputeInstanceMock is a mock implementation of ComputeInstance. +// +// func TestSomethingThatUsesComputeInstance(t *testing.T) { +// +// // make and configure a mocked ComputeInstance +// mockedComputeInstance := &ComputeInstanceMock{ +// DestroyFunc: func() Return { +// panic("mock out the Destroy method") +// }, +// GetInfoFunc: func() (ComputeInstanceInfo, Return) { +// panic("mock out the GetInfo method") +// }, +// } +// +// // use mockedComputeInstance in code that requires ComputeInstance +// // and then make assertions. +// +// } +type ComputeInstanceMock struct { + // DestroyFunc mocks the Destroy method. + DestroyFunc func() Return + + // GetInfoFunc mocks the GetInfo method. + GetInfoFunc func() (ComputeInstanceInfo, Return) + + // calls tracks calls to the methods. + calls struct { + // Destroy holds details about calls to the Destroy method. + Destroy []struct { + } + // GetInfo holds details about calls to the GetInfo method. + GetInfo []struct { + } + } + lockDestroy sync.RWMutex + lockGetInfo sync.RWMutex +} + +// Destroy calls DestroyFunc. +func (mock *ComputeInstanceMock) Destroy() Return { + if mock.DestroyFunc == nil { + panic("ComputeInstanceMock.DestroyFunc: method is nil but ComputeInstance.Destroy was just called") + } + callInfo := struct { + }{} + mock.lockDestroy.Lock() + mock.calls.Destroy = append(mock.calls.Destroy, callInfo) + mock.lockDestroy.Unlock() + return mock.DestroyFunc() +} + +// DestroyCalls gets all the calls that were made to Destroy. +// Check the length with: +// len(mockedComputeInstance.DestroyCalls()) +func (mock *ComputeInstanceMock) DestroyCalls() []struct { +} { + var calls []struct { + } + mock.lockDestroy.RLock() + calls = mock.calls.Destroy + mock.lockDestroy.RUnlock() + return calls +} + +// GetInfo calls GetInfoFunc. +func (mock *ComputeInstanceMock) GetInfo() (ComputeInstanceInfo, Return) { + if mock.GetInfoFunc == nil { + panic("ComputeInstanceMock.GetInfoFunc: method is nil but ComputeInstance.GetInfo was just called") + } + callInfo := struct { + }{} + mock.lockGetInfo.Lock() + mock.calls.GetInfo = append(mock.calls.GetInfo, callInfo) + mock.lockGetInfo.Unlock() + return mock.GetInfoFunc() +} + +// GetInfoCalls gets all the calls that were made to GetInfo. +// Check the length with: +// len(mockedComputeInstance.GetInfoCalls()) +func (mock *ComputeInstanceMock) GetInfoCalls() []struct { +} { + var calls []struct { + } + mock.lockGetInfo.RLock() + calls = mock.calls.GetInfo + mock.lockGetInfo.RUnlock() + return calls +} diff --git a/pkg/nvml/consts.go b/pkg/nvml/consts.go new file mode 100644 index 0000000..2ded48b --- /dev/null +++ b/pkg/nvml/consts.go @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nvml + +import ( + "github.com/NVIDIA/go-nvml/pkg/nvml" +) + +// Return constants +const ( + SUCCESS = Return(nvml.SUCCESS) + ERROR_UNINITIALIZED = Return(nvml.ERROR_UNINITIALIZED) + ERROR_INVALID_ARGUMENT = Return(nvml.ERROR_INVALID_ARGUMENT) + ERROR_NOT_SUPPORTED = Return(nvml.ERROR_NOT_SUPPORTED) + ERROR_NO_PERMISSION = Return(nvml.ERROR_NO_PERMISSION) + ERROR_ALREADY_INITIALIZED = Return(nvml.ERROR_ALREADY_INITIALIZED) + ERROR_NOT_FOUND = Return(nvml.ERROR_NOT_FOUND) + ERROR_INSUFFICIENT_SIZE = Return(nvml.ERROR_INSUFFICIENT_SIZE) + ERROR_INSUFFICIENT_POWER = Return(nvml.ERROR_INSUFFICIENT_POWER) + ERROR_DRIVER_NOT_LOADED = Return(nvml.ERROR_DRIVER_NOT_LOADED) + ERROR_TIMEOUT = Return(nvml.ERROR_TIMEOUT) + ERROR_IRQ_ISSUE = Return(nvml.ERROR_IRQ_ISSUE) + ERROR_LIBRARY_NOT_FOUND = Return(nvml.ERROR_LIBRARY_NOT_FOUND) + ERROR_FUNCTION_NOT_FOUND = Return(nvml.ERROR_FUNCTION_NOT_FOUND) + ERROR_CORRUPTED_INFOROM = Return(nvml.ERROR_CORRUPTED_INFOROM) + ERROR_GPU_IS_LOST = Return(nvml.ERROR_GPU_IS_LOST) + ERROR_RESET_REQUIRED = Return(nvml.ERROR_RESET_REQUIRED) + ERROR_OPERATING_SYSTEM = Return(nvml.ERROR_OPERATING_SYSTEM) + ERROR_LIB_RM_VERSION_MISMATCH = Return(nvml.ERROR_LIB_RM_VERSION_MISMATCH) + ERROR_IN_USE = Return(nvml.ERROR_IN_USE) + ERROR_MEMORY = Return(nvml.ERROR_MEMORY) + ERROR_NO_DATA = Return(nvml.ERROR_NO_DATA) + ERROR_VGPU_ECC_NOT_SUPPORTED = Return(nvml.ERROR_VGPU_ECC_NOT_SUPPORTED) + ERROR_INSUFFICIENT_RESOURCES = Return(nvml.ERROR_INSUFFICIENT_RESOURCES) + ERROR_UNKNOWN = Return(nvml.ERROR_UNKNOWN) +) + +// MIG Mode constants +const ( + DEVICE_MIG_ENABLE = nvml.DEVICE_MIG_ENABLE + DEVICE_MIG_DISABLE = nvml.DEVICE_MIG_DISABLE +) + +// GPU Instance Profiles +const ( + GPU_INSTANCE_PROFILE_1_SLICE = nvml.GPU_INSTANCE_PROFILE_1_SLICE + GPU_INSTANCE_PROFILE_2_SLICE = nvml.GPU_INSTANCE_PROFILE_2_SLICE + GPU_INSTANCE_PROFILE_3_SLICE = nvml.GPU_INSTANCE_PROFILE_3_SLICE + GPU_INSTANCE_PROFILE_4_SLICE = nvml.GPU_INSTANCE_PROFILE_4_SLICE + GPU_INSTANCE_PROFILE_6_SLICE = nvml.GPU_INSTANCE_PROFILE_6_SLICE + GPU_INSTANCE_PROFILE_7_SLICE = nvml.GPU_INSTANCE_PROFILE_7_SLICE + GPU_INSTANCE_PROFILE_8_SLICE = nvml.GPU_INSTANCE_PROFILE_8_SLICE + GPU_INSTANCE_PROFILE_1_SLICE_REV1 = nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1 + GPU_INSTANCE_PROFILE_COUNT = nvml.GPU_INSTANCE_PROFILE_COUNT +) + +// Compute Instance Profiles +const ( + COMPUTE_INSTANCE_PROFILE_1_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE + COMPUTE_INSTANCE_PROFILE_2_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_2_SLICE + COMPUTE_INSTANCE_PROFILE_3_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_3_SLICE + COMPUTE_INSTANCE_PROFILE_4_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_4_SLICE + COMPUTE_INSTANCE_PROFILE_6_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_6_SLICE + COMPUTE_INSTANCE_PROFILE_7_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_7_SLICE + COMPUTE_INSTANCE_PROFILE_8_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_8_SLICE + COMPUTE_INSTANCE_PROFILE_COUNT = nvml.COMPUTE_INSTANCE_PROFILE_COUNT +) + +// Compute Instance Engine Profiles +const ( + COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED + COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT +) diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go new file mode 100644 index 0000000..9b0caf8 --- /dev/null +++ b/pkg/nvml/device.go @@ -0,0 +1,117 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvml + +import "github.com/NVIDIA/go-nvml/pkg/nvml" + +type nvmlDevice nvml.Device + +var _ Device = (*nvmlDevice)(nil) + +// GetIndex returns the index of a Device +func (d nvmlDevice) GetIndex() (int, Return) { + i, r := nvml.Device(d).GetIndex() + return i, Return(r) +} + +// GetPciInfo returns the PCI info of a Device +func (d nvmlDevice) GetPciInfo() (PciInfo, Return) { + p, r := nvml.Device(d).GetPciInfo() + return PciInfo(p), Return(r) +} + +// GetMemoryInfo returns the memory info of a Device +func (d nvmlDevice) GetMemoryInfo() (Memory, Return) { + p, r := nvml.Device(d).GetMemoryInfo() + return Memory(p), Return(r) +} + +// GetUUID returns the UUID of a Device +func (d nvmlDevice) GetUUID() (string, Return) { + u, r := nvml.Device(d).GetUUID() + return u, Return(r) +} + +// GetMinorNumber returns the minor number of a Device +func (d nvmlDevice) GetMinorNumber() (int, Return) { + m, r := nvml.Device(d).GetMinorNumber() + return m, Return(r) +} + +// IsMigDeviceHandle returns whether a Device is a MIG device or not +func (d nvmlDevice) IsMigDeviceHandle() (bool, Return) { + b, r := nvml.Device(d).IsMigDeviceHandle() + return b, Return(r) +} + +// GetDeviceHandleFromMigDeviceHandle returns the parent Device of a MIG device +func (d nvmlDevice) GetDeviceHandleFromMigDeviceHandle() (Device, Return) { + p, r := nvml.Device(d).GetDeviceHandleFromMigDeviceHandle() + return nvmlDevice(p), Return(r) +} + +// SetMigMode sets the MIG mode of a Device +func (d nvmlDevice) SetMigMode(mode int) (Return, Return) { + r1, r2 := nvml.Device(d).SetMigMode(mode) + return Return(r1), Return(r2) +} + +// GetMigMode returns the MIG mode of a Device +func (d nvmlDevice) GetMigMode() (int, int, Return) { + s1, s2, r := nvml.Device(d).GetMigMode() + return s1, s2, Return(r) +} + +// GetGpuInstanceProfileInfo returns the profile info of a GPU Instance +func (d nvmlDevice) GetGpuInstanceProfileInfo(profile int) (GpuInstanceProfileInfo, Return) { + p, r := nvml.Device(d).GetGpuInstanceProfileInfo(profile) + return GpuInstanceProfileInfo(p), Return(r) +} + +// GetGpuInstances returns the set of GPU Instances associated with a Device +func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance, Return) { + nvmlGis, r := nvml.Device(d).GetGpuInstances((*nvml.GpuInstanceProfileInfo)(info)) + var gis []GpuInstance + for _, gi := range nvmlGis { + gis = append(gis, nvmlGpuInstance(gi)) + } + return gis, Return(r) +} + +// GetMaxMigDeviceCount returns the maximum number of MIG devices that can be created on a Device +func (d nvmlDevice) GetMaxMigDeviceCount() (int, Return) { + m, r := nvml.Device(d).GetMaxMigDeviceCount() + return m, Return(r) +} + +// GetMigDeviceHandleByIndex returns the handle to a MIG device given its index +func (d nvmlDevice) GetMigDeviceHandleByIndex(Index int) (Device, Return) { + h, r := nvml.Device(d).GetMigDeviceHandleByIndex(Index) + return nvmlDevice(h), Return(r) +} + +// GetGpuInstanceId returns the GPU Instance ID of a MIG device +func (d nvmlDevice) GetGpuInstanceId() (int, Return) { + gi, r := nvml.Device(d).GetGpuInstanceId() + return gi, Return(r) +} + +// GetComputeInstanceId returns the Compute Instance ID of a MIG device +func (d nvmlDevice) GetComputeInstanceId() (int, Return) { + ci, r := nvml.Device(d).GetComputeInstanceId() + return ci, Return(r) +} diff --git a/pkg/nvml/device_mock.go b/pkg/nvml/device_mock.go new file mode 100644 index 0000000..a25898e --- /dev/null +++ b/pkg/nvml/device_mock.go @@ -0,0 +1,598 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package nvml + +import ( + "sync" +) + +// Ensure, that DeviceMock does implement Device. +// If this is not the case, regenerate this file with moq. +var _ Device = &DeviceMock{} + +// DeviceMock is a mock implementation of Device. +// +// func TestSomethingThatUsesDevice(t *testing.T) { +// +// // make and configure a mocked Device +// mockedDevice := &DeviceMock{ +// GetComputeInstanceIdFunc: func() (int, Return) { +// panic("mock out the GetComputeInstanceId method") +// }, +// GetDeviceHandleFromMigDeviceHandleFunc: func() (Device, Return) { +// panic("mock out the GetDeviceHandleFromMigDeviceHandle method") +// }, +// GetGpuInstanceIdFunc: func() (int, Return) { +// panic("mock out the GetGpuInstanceId method") +// }, +// GetGpuInstanceProfileInfoFunc: func(Profile int) (GpuInstanceProfileInfo, Return) { +// panic("mock out the GetGpuInstanceProfileInfo method") +// }, +// GetGpuInstancesFunc: func(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return) { +// panic("mock out the GetGpuInstances method") +// }, +// GetIndexFunc: func() (int, Return) { +// panic("mock out the GetIndex method") +// }, +// GetMaxMigDeviceCountFunc: func() (int, Return) { +// panic("mock out the GetMaxMigDeviceCount method") +// }, +// GetMemoryInfoFunc: func() (Memory, Return) { +// panic("mock out the GetMemoryInfo method") +// }, +// GetMigDeviceHandleByIndexFunc: func(Index int) (Device, Return) { +// panic("mock out the GetMigDeviceHandleByIndex method") +// }, +// GetMigModeFunc: func() (int, int, Return) { +// panic("mock out the GetMigMode method") +// }, +// GetMinorNumberFunc: func() (int, Return) { +// panic("mock out the GetMinorNumber method") +// }, +// GetPciInfoFunc: func() (PciInfo, Return) { +// panic("mock out the GetPciInfo method") +// }, +// GetUUIDFunc: func() (string, Return) { +// panic("mock out the GetUUID method") +// }, +// IsMigDeviceHandleFunc: func() (bool, Return) { +// panic("mock out the IsMigDeviceHandle method") +// }, +// SetMigModeFunc: func(Mode int) (Return, Return) { +// panic("mock out the SetMigMode method") +// }, +// } +// +// // use mockedDevice in code that requires Device +// // and then make assertions. +// +// } +type DeviceMock struct { + // GetComputeInstanceIdFunc mocks the GetComputeInstanceId method. + GetComputeInstanceIdFunc func() (int, Return) + + // GetDeviceHandleFromMigDeviceHandleFunc mocks the GetDeviceHandleFromMigDeviceHandle method. + GetDeviceHandleFromMigDeviceHandleFunc func() (Device, Return) + + // GetGpuInstanceIdFunc mocks the GetGpuInstanceId method. + GetGpuInstanceIdFunc func() (int, Return) + + // GetGpuInstanceProfileInfoFunc mocks the GetGpuInstanceProfileInfo method. + GetGpuInstanceProfileInfoFunc func(Profile int) (GpuInstanceProfileInfo, Return) + + // GetGpuInstancesFunc mocks the GetGpuInstances method. + GetGpuInstancesFunc func(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return) + + // GetIndexFunc mocks the GetIndex method. + GetIndexFunc func() (int, Return) + + // GetMaxMigDeviceCountFunc mocks the GetMaxMigDeviceCount method. + GetMaxMigDeviceCountFunc func() (int, Return) + + // GetMemoryInfoFunc mocks the GetMemoryInfo method. + GetMemoryInfoFunc func() (Memory, Return) + + // GetMigDeviceHandleByIndexFunc mocks the GetMigDeviceHandleByIndex method. + GetMigDeviceHandleByIndexFunc func(Index int) (Device, Return) + + // GetMigModeFunc mocks the GetMigMode method. + GetMigModeFunc func() (int, int, Return) + + // GetMinorNumberFunc mocks the GetMinorNumber method. + GetMinorNumberFunc func() (int, Return) + + // GetPciInfoFunc mocks the GetPciInfo method. + GetPciInfoFunc func() (PciInfo, Return) + + // GetUUIDFunc mocks the GetUUID method. + GetUUIDFunc func() (string, Return) + + // IsMigDeviceHandleFunc mocks the IsMigDeviceHandle method. + IsMigDeviceHandleFunc func() (bool, Return) + + // SetMigModeFunc mocks the SetMigMode method. + SetMigModeFunc func(Mode int) (Return, Return) + + // calls tracks calls to the methods. + calls struct { + // GetComputeInstanceId holds details about calls to the GetComputeInstanceId method. + GetComputeInstanceId []struct { + } + // GetDeviceHandleFromMigDeviceHandle holds details about calls to the GetDeviceHandleFromMigDeviceHandle method. + GetDeviceHandleFromMigDeviceHandle []struct { + } + // GetGpuInstanceId holds details about calls to the GetGpuInstanceId method. + GetGpuInstanceId []struct { + } + // GetGpuInstanceProfileInfo holds details about calls to the GetGpuInstanceProfileInfo method. + GetGpuInstanceProfileInfo []struct { + // Profile is the Profile argument value. + Profile int + } + // GetGpuInstances holds details about calls to the GetGpuInstances method. + GetGpuInstances []struct { + // Info is the Info argument value. + Info *GpuInstanceProfileInfo + } + // GetIndex holds details about calls to the GetIndex method. + GetIndex []struct { + } + // GetMaxMigDeviceCount holds details about calls to the GetMaxMigDeviceCount method. + GetMaxMigDeviceCount []struct { + } + // GetMemoryInfo holds details about calls to the GetMemoryInfo method. + GetMemoryInfo []struct { + } + // GetMigDeviceHandleByIndex holds details about calls to the GetMigDeviceHandleByIndex method. + GetMigDeviceHandleByIndex []struct { + // Index is the Index argument value. + Index int + } + // GetMigMode holds details about calls to the GetMigMode method. + GetMigMode []struct { + } + // GetMinorNumber holds details about calls to the GetMinorNumber method. + GetMinorNumber []struct { + } + // GetPciInfo holds details about calls to the GetPciInfo method. + GetPciInfo []struct { + } + // GetUUID holds details about calls to the GetUUID method. + GetUUID []struct { + } + // IsMigDeviceHandle holds details about calls to the IsMigDeviceHandle method. + IsMigDeviceHandle []struct { + } + // SetMigMode holds details about calls to the SetMigMode method. + SetMigMode []struct { + // Mode is the Mode argument value. + Mode int + } + } + lockGetComputeInstanceId sync.RWMutex + lockGetDeviceHandleFromMigDeviceHandle sync.RWMutex + lockGetGpuInstanceId sync.RWMutex + lockGetGpuInstanceProfileInfo sync.RWMutex + lockGetGpuInstances sync.RWMutex + lockGetIndex sync.RWMutex + lockGetMaxMigDeviceCount sync.RWMutex + lockGetMemoryInfo sync.RWMutex + lockGetMigDeviceHandleByIndex sync.RWMutex + lockGetMigMode sync.RWMutex + lockGetMinorNumber sync.RWMutex + lockGetPciInfo sync.RWMutex + lockGetUUID sync.RWMutex + lockIsMigDeviceHandle sync.RWMutex + lockSetMigMode sync.RWMutex +} + +// GetComputeInstanceId calls GetComputeInstanceIdFunc. +func (mock *DeviceMock) GetComputeInstanceId() (int, Return) { + if mock.GetComputeInstanceIdFunc == nil { + panic("DeviceMock.GetComputeInstanceIdFunc: method is nil but Device.GetComputeInstanceId was just called") + } + callInfo := struct { + }{} + mock.lockGetComputeInstanceId.Lock() + mock.calls.GetComputeInstanceId = append(mock.calls.GetComputeInstanceId, callInfo) + mock.lockGetComputeInstanceId.Unlock() + return mock.GetComputeInstanceIdFunc() +} + +// GetComputeInstanceIdCalls gets all the calls that were made to GetComputeInstanceId. +// Check the length with: +// len(mockedDevice.GetComputeInstanceIdCalls()) +func (mock *DeviceMock) GetComputeInstanceIdCalls() []struct { +} { + var calls []struct { + } + mock.lockGetComputeInstanceId.RLock() + calls = mock.calls.GetComputeInstanceId + mock.lockGetComputeInstanceId.RUnlock() + return calls +} + +// GetDeviceHandleFromMigDeviceHandle calls GetDeviceHandleFromMigDeviceHandleFunc. +func (mock *DeviceMock) GetDeviceHandleFromMigDeviceHandle() (Device, Return) { + if mock.GetDeviceHandleFromMigDeviceHandleFunc == nil { + panic("DeviceMock.GetDeviceHandleFromMigDeviceHandleFunc: method is nil but Device.GetDeviceHandleFromMigDeviceHandle was just called") + } + callInfo := struct { + }{} + mock.lockGetDeviceHandleFromMigDeviceHandle.Lock() + mock.calls.GetDeviceHandleFromMigDeviceHandle = append(mock.calls.GetDeviceHandleFromMigDeviceHandle, callInfo) + mock.lockGetDeviceHandleFromMigDeviceHandle.Unlock() + return mock.GetDeviceHandleFromMigDeviceHandleFunc() +} + +// GetDeviceHandleFromMigDeviceHandleCalls gets all the calls that were made to GetDeviceHandleFromMigDeviceHandle. +// Check the length with: +// len(mockedDevice.GetDeviceHandleFromMigDeviceHandleCalls()) +func (mock *DeviceMock) GetDeviceHandleFromMigDeviceHandleCalls() []struct { +} { + var calls []struct { + } + mock.lockGetDeviceHandleFromMigDeviceHandle.RLock() + calls = mock.calls.GetDeviceHandleFromMigDeviceHandle + mock.lockGetDeviceHandleFromMigDeviceHandle.RUnlock() + return calls +} + +// GetGpuInstanceId calls GetGpuInstanceIdFunc. +func (mock *DeviceMock) GetGpuInstanceId() (int, Return) { + if mock.GetGpuInstanceIdFunc == nil { + panic("DeviceMock.GetGpuInstanceIdFunc: method is nil but Device.GetGpuInstanceId was just called") + } + callInfo := struct { + }{} + mock.lockGetGpuInstanceId.Lock() + mock.calls.GetGpuInstanceId = append(mock.calls.GetGpuInstanceId, callInfo) + mock.lockGetGpuInstanceId.Unlock() + return mock.GetGpuInstanceIdFunc() +} + +// GetGpuInstanceIdCalls gets all the calls that were made to GetGpuInstanceId. +// Check the length with: +// len(mockedDevice.GetGpuInstanceIdCalls()) +func (mock *DeviceMock) GetGpuInstanceIdCalls() []struct { +} { + var calls []struct { + } + mock.lockGetGpuInstanceId.RLock() + calls = mock.calls.GetGpuInstanceId + mock.lockGetGpuInstanceId.RUnlock() + return calls +} + +// GetGpuInstanceProfileInfo calls GetGpuInstanceProfileInfoFunc. +func (mock *DeviceMock) GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return) { + if mock.GetGpuInstanceProfileInfoFunc == nil { + panic("DeviceMock.GetGpuInstanceProfileInfoFunc: method is nil but Device.GetGpuInstanceProfileInfo was just called") + } + callInfo := struct { + Profile int + }{ + Profile: Profile, + } + mock.lockGetGpuInstanceProfileInfo.Lock() + mock.calls.GetGpuInstanceProfileInfo = append(mock.calls.GetGpuInstanceProfileInfo, callInfo) + mock.lockGetGpuInstanceProfileInfo.Unlock() + return mock.GetGpuInstanceProfileInfoFunc(Profile) +} + +// GetGpuInstanceProfileInfoCalls gets all the calls that were made to GetGpuInstanceProfileInfo. +// Check the length with: +// len(mockedDevice.GetGpuInstanceProfileInfoCalls()) +func (mock *DeviceMock) GetGpuInstanceProfileInfoCalls() []struct { + Profile int +} { + var calls []struct { + Profile int + } + mock.lockGetGpuInstanceProfileInfo.RLock() + calls = mock.calls.GetGpuInstanceProfileInfo + mock.lockGetGpuInstanceProfileInfo.RUnlock() + return calls +} + +// GetGpuInstances calls GetGpuInstancesFunc. +func (mock *DeviceMock) GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return) { + if mock.GetGpuInstancesFunc == nil { + panic("DeviceMock.GetGpuInstancesFunc: method is nil but Device.GetGpuInstances was just called") + } + callInfo := struct { + Info *GpuInstanceProfileInfo + }{ + Info: Info, + } + mock.lockGetGpuInstances.Lock() + mock.calls.GetGpuInstances = append(mock.calls.GetGpuInstances, callInfo) + mock.lockGetGpuInstances.Unlock() + return mock.GetGpuInstancesFunc(Info) +} + +// GetGpuInstancesCalls gets all the calls that were made to GetGpuInstances. +// Check the length with: +// len(mockedDevice.GetGpuInstancesCalls()) +func (mock *DeviceMock) GetGpuInstancesCalls() []struct { + Info *GpuInstanceProfileInfo +} { + var calls []struct { + Info *GpuInstanceProfileInfo + } + mock.lockGetGpuInstances.RLock() + calls = mock.calls.GetGpuInstances + mock.lockGetGpuInstances.RUnlock() + return calls +} + +// GetIndex calls GetIndexFunc. +func (mock *DeviceMock) GetIndex() (int, Return) { + if mock.GetIndexFunc == nil { + panic("DeviceMock.GetIndexFunc: method is nil but Device.GetIndex was just called") + } + callInfo := struct { + }{} + mock.lockGetIndex.Lock() + mock.calls.GetIndex = append(mock.calls.GetIndex, callInfo) + mock.lockGetIndex.Unlock() + return mock.GetIndexFunc() +} + +// GetIndexCalls gets all the calls that were made to GetIndex. +// Check the length with: +// len(mockedDevice.GetIndexCalls()) +func (mock *DeviceMock) GetIndexCalls() []struct { +} { + var calls []struct { + } + mock.lockGetIndex.RLock() + calls = mock.calls.GetIndex + mock.lockGetIndex.RUnlock() + return calls +} + +// GetMaxMigDeviceCount calls GetMaxMigDeviceCountFunc. +func (mock *DeviceMock) GetMaxMigDeviceCount() (int, Return) { + if mock.GetMaxMigDeviceCountFunc == nil { + panic("DeviceMock.GetMaxMigDeviceCountFunc: method is nil but Device.GetMaxMigDeviceCount was just called") + } + callInfo := struct { + }{} + mock.lockGetMaxMigDeviceCount.Lock() + mock.calls.GetMaxMigDeviceCount = append(mock.calls.GetMaxMigDeviceCount, callInfo) + mock.lockGetMaxMigDeviceCount.Unlock() + return mock.GetMaxMigDeviceCountFunc() +} + +// GetMaxMigDeviceCountCalls gets all the calls that were made to GetMaxMigDeviceCount. +// Check the length with: +// len(mockedDevice.GetMaxMigDeviceCountCalls()) +func (mock *DeviceMock) GetMaxMigDeviceCountCalls() []struct { +} { + var calls []struct { + } + mock.lockGetMaxMigDeviceCount.RLock() + calls = mock.calls.GetMaxMigDeviceCount + mock.lockGetMaxMigDeviceCount.RUnlock() + return calls +} + +// GetMemoryInfo calls GetMemoryInfoFunc. +func (mock *DeviceMock) GetMemoryInfo() (Memory, Return) { + if mock.GetMemoryInfoFunc == nil { + panic("DeviceMock.GetMemoryInfoFunc: method is nil but Device.GetMemoryInfo was just called") + } + callInfo := struct { + }{} + mock.lockGetMemoryInfo.Lock() + mock.calls.GetMemoryInfo = append(mock.calls.GetMemoryInfo, callInfo) + mock.lockGetMemoryInfo.Unlock() + return mock.GetMemoryInfoFunc() +} + +// GetMemoryInfoCalls gets all the calls that were made to GetMemoryInfo. +// Check the length with: +// len(mockedDevice.GetMemoryInfoCalls()) +func (mock *DeviceMock) GetMemoryInfoCalls() []struct { +} { + var calls []struct { + } + mock.lockGetMemoryInfo.RLock() + calls = mock.calls.GetMemoryInfo + mock.lockGetMemoryInfo.RUnlock() + return calls +} + +// GetMigDeviceHandleByIndex calls GetMigDeviceHandleByIndexFunc. +func (mock *DeviceMock) GetMigDeviceHandleByIndex(Index int) (Device, Return) { + if mock.GetMigDeviceHandleByIndexFunc == nil { + panic("DeviceMock.GetMigDeviceHandleByIndexFunc: method is nil but Device.GetMigDeviceHandleByIndex was just called") + } + callInfo := struct { + Index int + }{ + Index: Index, + } + mock.lockGetMigDeviceHandleByIndex.Lock() + mock.calls.GetMigDeviceHandleByIndex = append(mock.calls.GetMigDeviceHandleByIndex, callInfo) + mock.lockGetMigDeviceHandleByIndex.Unlock() + return mock.GetMigDeviceHandleByIndexFunc(Index) +} + +// GetMigDeviceHandleByIndexCalls gets all the calls that were made to GetMigDeviceHandleByIndex. +// Check the length with: +// len(mockedDevice.GetMigDeviceHandleByIndexCalls()) +func (mock *DeviceMock) GetMigDeviceHandleByIndexCalls() []struct { + Index int +} { + var calls []struct { + Index int + } + mock.lockGetMigDeviceHandleByIndex.RLock() + calls = mock.calls.GetMigDeviceHandleByIndex + mock.lockGetMigDeviceHandleByIndex.RUnlock() + return calls +} + +// GetMigMode calls GetMigModeFunc. +func (mock *DeviceMock) GetMigMode() (int, int, Return) { + if mock.GetMigModeFunc == nil { + panic("DeviceMock.GetMigModeFunc: method is nil but Device.GetMigMode was just called") + } + callInfo := struct { + }{} + mock.lockGetMigMode.Lock() + mock.calls.GetMigMode = append(mock.calls.GetMigMode, callInfo) + mock.lockGetMigMode.Unlock() + return mock.GetMigModeFunc() +} + +// GetMigModeCalls gets all the calls that were made to GetMigMode. +// Check the length with: +// len(mockedDevice.GetMigModeCalls()) +func (mock *DeviceMock) GetMigModeCalls() []struct { +} { + var calls []struct { + } + mock.lockGetMigMode.RLock() + calls = mock.calls.GetMigMode + mock.lockGetMigMode.RUnlock() + return calls +} + +// GetMinorNumber calls GetMinorNumberFunc. +func (mock *DeviceMock) GetMinorNumber() (int, Return) { + if mock.GetMinorNumberFunc == nil { + panic("DeviceMock.GetMinorNumberFunc: method is nil but Device.GetMinorNumber was just called") + } + callInfo := struct { + }{} + mock.lockGetMinorNumber.Lock() + mock.calls.GetMinorNumber = append(mock.calls.GetMinorNumber, callInfo) + mock.lockGetMinorNumber.Unlock() + return mock.GetMinorNumberFunc() +} + +// GetMinorNumberCalls gets all the calls that were made to GetMinorNumber. +// Check the length with: +// len(mockedDevice.GetMinorNumberCalls()) +func (mock *DeviceMock) GetMinorNumberCalls() []struct { +} { + var calls []struct { + } + mock.lockGetMinorNumber.RLock() + calls = mock.calls.GetMinorNumber + mock.lockGetMinorNumber.RUnlock() + return calls +} + +// GetPciInfo calls GetPciInfoFunc. +func (mock *DeviceMock) GetPciInfo() (PciInfo, Return) { + if mock.GetPciInfoFunc == nil { + panic("DeviceMock.GetPciInfoFunc: method is nil but Device.GetPciInfo was just called") + } + callInfo := struct { + }{} + mock.lockGetPciInfo.Lock() + mock.calls.GetPciInfo = append(mock.calls.GetPciInfo, callInfo) + mock.lockGetPciInfo.Unlock() + return mock.GetPciInfoFunc() +} + +// GetPciInfoCalls gets all the calls that were made to GetPciInfo. +// Check the length with: +// len(mockedDevice.GetPciInfoCalls()) +func (mock *DeviceMock) GetPciInfoCalls() []struct { +} { + var calls []struct { + } + mock.lockGetPciInfo.RLock() + calls = mock.calls.GetPciInfo + mock.lockGetPciInfo.RUnlock() + return calls +} + +// GetUUID calls GetUUIDFunc. +func (mock *DeviceMock) GetUUID() (string, Return) { + if mock.GetUUIDFunc == nil { + panic("DeviceMock.GetUUIDFunc: method is nil but Device.GetUUID was just called") + } + callInfo := struct { + }{} + mock.lockGetUUID.Lock() + mock.calls.GetUUID = append(mock.calls.GetUUID, callInfo) + mock.lockGetUUID.Unlock() + return mock.GetUUIDFunc() +} + +// GetUUIDCalls gets all the calls that were made to GetUUID. +// Check the length with: +// len(mockedDevice.GetUUIDCalls()) +func (mock *DeviceMock) GetUUIDCalls() []struct { +} { + var calls []struct { + } + mock.lockGetUUID.RLock() + calls = mock.calls.GetUUID + mock.lockGetUUID.RUnlock() + return calls +} + +// IsMigDeviceHandle calls IsMigDeviceHandleFunc. +func (mock *DeviceMock) IsMigDeviceHandle() (bool, Return) { + if mock.IsMigDeviceHandleFunc == nil { + panic("DeviceMock.IsMigDeviceHandleFunc: method is nil but Device.IsMigDeviceHandle was just called") + } + callInfo := struct { + }{} + mock.lockIsMigDeviceHandle.Lock() + mock.calls.IsMigDeviceHandle = append(mock.calls.IsMigDeviceHandle, callInfo) + mock.lockIsMigDeviceHandle.Unlock() + return mock.IsMigDeviceHandleFunc() +} + +// IsMigDeviceHandleCalls gets all the calls that were made to IsMigDeviceHandle. +// Check the length with: +// len(mockedDevice.IsMigDeviceHandleCalls()) +func (mock *DeviceMock) IsMigDeviceHandleCalls() []struct { +} { + var calls []struct { + } + mock.lockIsMigDeviceHandle.RLock() + calls = mock.calls.IsMigDeviceHandle + mock.lockIsMigDeviceHandle.RUnlock() + return calls +} + +// SetMigMode calls SetMigModeFunc. +func (mock *DeviceMock) SetMigMode(Mode int) (Return, Return) { + if mock.SetMigModeFunc == nil { + panic("DeviceMock.SetMigModeFunc: method is nil but Device.SetMigMode was just called") + } + callInfo := struct { + Mode int + }{ + Mode: Mode, + } + mock.lockSetMigMode.Lock() + mock.calls.SetMigMode = append(mock.calls.SetMigMode, callInfo) + mock.lockSetMigMode.Unlock() + return mock.SetMigModeFunc(Mode) +} + +// SetMigModeCalls gets all the calls that were made to SetMigMode. +// Check the length with: +// len(mockedDevice.SetMigModeCalls()) +func (mock *DeviceMock) SetMigModeCalls() []struct { + Mode int +} { + var calls []struct { + Mode int + } + mock.lockSetMigMode.RLock() + calls = mock.calls.SetMigMode + mock.lockSetMigMode.RUnlock() + return calls +} diff --git a/pkg/nvml/gi.go b/pkg/nvml/gi.go new file mode 100644 index 0000000..cd775bc --- /dev/null +++ b/pkg/nvml/gi.go @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nvml + +import ( + "github.com/NVIDIA/go-nvml/pkg/nvml" +) + +type nvmlGpuInstance nvml.GpuInstance + +var _ GpuInstance = (*nvmlGpuInstance)(nil) + +// GetInfo returns info about a GPU Intsance +func (gi nvmlGpuInstance) GetInfo() (GpuInstanceInfo, Return) { + i, r := nvml.GpuInstance(gi).GetInfo() + info := GpuInstanceInfo{ + Device: nvmlDevice(i.Device), + Id: i.Id, + ProfileId: i.ProfileId, + Placement: GpuInstancePlacement(i.Placement), + } + return info, Return(r) +} + +// GetComputeInstanceProfileInfo returns info about a given Compute Instance profile +func (gi nvmlGpuInstance) GetComputeInstanceProfileInfo(profile int, engProfile int) (ComputeInstanceProfileInfo, Return) { + p, r := nvml.GpuInstance(gi).GetComputeInstanceProfileInfo(profile, engProfile) + return ComputeInstanceProfileInfo(p), Return(r) +} + +// CreateComputeInstance creates a Compute Instance within the GPU Instance +func (gi nvmlGpuInstance) CreateComputeInstance(info *ComputeInstanceProfileInfo) (ComputeInstance, Return) { + ci, r := nvml.GpuInstance(gi).CreateComputeInstance((*nvml.ComputeInstanceProfileInfo)(info)) + return nvmlComputeInstance(ci), Return(r) +} + +// GetComputeInstances returns the set of Compute Instances associated with a GPU Instance +func (gi nvmlGpuInstance) GetComputeInstances(info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) { + nvmlCis, r := nvml.GpuInstance(gi).GetComputeInstances((*nvml.ComputeInstanceProfileInfo)(info)) + var cis []ComputeInstance + for _, ci := range nvmlCis { + cis = append(cis, nvmlComputeInstance(ci)) + } + return cis, Return(r) +} + +// Destroy destroys a GPU Instance +func (gi nvmlGpuInstance) Destroy() Return { + r := nvml.GpuInstance(gi).Destroy() + return Return(r) +} diff --git a/pkg/nvml/gi_mock.go b/pkg/nvml/gi_mock.go new file mode 100644 index 0000000..7393a04 --- /dev/null +++ b/pkg/nvml/gi_mock.go @@ -0,0 +1,237 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package nvml + +import ( + "sync" +) + +// Ensure, that GpuInstanceMock does implement GpuInstance. +// If this is not the case, regenerate this file with moq. +var _ GpuInstance = &GpuInstanceMock{} + +// GpuInstanceMock is a mock implementation of GpuInstance. +// +// func TestSomethingThatUsesGpuInstance(t *testing.T) { +// +// // make and configure a mocked GpuInstance +// mockedGpuInstance := &GpuInstanceMock{ +// CreateComputeInstanceFunc: func(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return) { +// panic("mock out the CreateComputeInstance method") +// }, +// DestroyFunc: func() Return { +// panic("mock out the Destroy method") +// }, +// GetComputeInstanceProfileInfoFunc: func(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) { +// panic("mock out the GetComputeInstanceProfileInfo method") +// }, +// GetComputeInstancesFunc: func(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) { +// panic("mock out the GetComputeInstances method") +// }, +// GetInfoFunc: func() (GpuInstanceInfo, Return) { +// panic("mock out the GetInfo method") +// }, +// } +// +// // use mockedGpuInstance in code that requires GpuInstance +// // and then make assertions. +// +// } +type GpuInstanceMock struct { + // CreateComputeInstanceFunc mocks the CreateComputeInstance method. + CreateComputeInstanceFunc func(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return) + + // DestroyFunc mocks the Destroy method. + DestroyFunc func() Return + + // GetComputeInstanceProfileInfoFunc mocks the GetComputeInstanceProfileInfo method. + GetComputeInstanceProfileInfoFunc func(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) + + // GetComputeInstancesFunc mocks the GetComputeInstances method. + GetComputeInstancesFunc func(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) + + // GetInfoFunc mocks the GetInfo method. + GetInfoFunc func() (GpuInstanceInfo, Return) + + // calls tracks calls to the methods. + calls struct { + // CreateComputeInstance holds details about calls to the CreateComputeInstance method. + CreateComputeInstance []struct { + // Info is the Info argument value. + Info *ComputeInstanceProfileInfo + } + // Destroy holds details about calls to the Destroy method. + Destroy []struct { + } + // GetComputeInstanceProfileInfo holds details about calls to the GetComputeInstanceProfileInfo method. + GetComputeInstanceProfileInfo []struct { + // Profile is the Profile argument value. + Profile int + // EngProfile is the EngProfile argument value. + EngProfile int + } + // GetComputeInstances holds details about calls to the GetComputeInstances method. + GetComputeInstances []struct { + // Info is the Info argument value. + Info *ComputeInstanceProfileInfo + } + // GetInfo holds details about calls to the GetInfo method. + GetInfo []struct { + } + } + lockCreateComputeInstance sync.RWMutex + lockDestroy sync.RWMutex + lockGetComputeInstanceProfileInfo sync.RWMutex + lockGetComputeInstances sync.RWMutex + lockGetInfo sync.RWMutex +} + +// CreateComputeInstance calls CreateComputeInstanceFunc. +func (mock *GpuInstanceMock) CreateComputeInstance(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return) { + if mock.CreateComputeInstanceFunc == nil { + panic("GpuInstanceMock.CreateComputeInstanceFunc: method is nil but GpuInstance.CreateComputeInstance was just called") + } + callInfo := struct { + Info *ComputeInstanceProfileInfo + }{ + Info: Info, + } + mock.lockCreateComputeInstance.Lock() + mock.calls.CreateComputeInstance = append(mock.calls.CreateComputeInstance, callInfo) + mock.lockCreateComputeInstance.Unlock() + return mock.CreateComputeInstanceFunc(Info) +} + +// CreateComputeInstanceCalls gets all the calls that were made to CreateComputeInstance. +// Check the length with: +// len(mockedGpuInstance.CreateComputeInstanceCalls()) +func (mock *GpuInstanceMock) CreateComputeInstanceCalls() []struct { + Info *ComputeInstanceProfileInfo +} { + var calls []struct { + Info *ComputeInstanceProfileInfo + } + mock.lockCreateComputeInstance.RLock() + calls = mock.calls.CreateComputeInstance + mock.lockCreateComputeInstance.RUnlock() + return calls +} + +// Destroy calls DestroyFunc. +func (mock *GpuInstanceMock) Destroy() Return { + if mock.DestroyFunc == nil { + panic("GpuInstanceMock.DestroyFunc: method is nil but GpuInstance.Destroy was just called") + } + callInfo := struct { + }{} + mock.lockDestroy.Lock() + mock.calls.Destroy = append(mock.calls.Destroy, callInfo) + mock.lockDestroy.Unlock() + return mock.DestroyFunc() +} + +// DestroyCalls gets all the calls that were made to Destroy. +// Check the length with: +// len(mockedGpuInstance.DestroyCalls()) +func (mock *GpuInstanceMock) DestroyCalls() []struct { +} { + var calls []struct { + } + mock.lockDestroy.RLock() + calls = mock.calls.Destroy + mock.lockDestroy.RUnlock() + return calls +} + +// GetComputeInstanceProfileInfo calls GetComputeInstanceProfileInfoFunc. +func (mock *GpuInstanceMock) GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) { + if mock.GetComputeInstanceProfileInfoFunc == nil { + panic("GpuInstanceMock.GetComputeInstanceProfileInfoFunc: method is nil but GpuInstance.GetComputeInstanceProfileInfo was just called") + } + callInfo := struct { + Profile int + EngProfile int + }{ + Profile: Profile, + EngProfile: EngProfile, + } + mock.lockGetComputeInstanceProfileInfo.Lock() + mock.calls.GetComputeInstanceProfileInfo = append(mock.calls.GetComputeInstanceProfileInfo, callInfo) + mock.lockGetComputeInstanceProfileInfo.Unlock() + return mock.GetComputeInstanceProfileInfoFunc(Profile, EngProfile) +} + +// GetComputeInstanceProfileInfoCalls gets all the calls that were made to GetComputeInstanceProfileInfo. +// Check the length with: +// len(mockedGpuInstance.GetComputeInstanceProfileInfoCalls()) +func (mock *GpuInstanceMock) GetComputeInstanceProfileInfoCalls() []struct { + Profile int + EngProfile int +} { + var calls []struct { + Profile int + EngProfile int + } + mock.lockGetComputeInstanceProfileInfo.RLock() + calls = mock.calls.GetComputeInstanceProfileInfo + mock.lockGetComputeInstanceProfileInfo.RUnlock() + return calls +} + +// GetComputeInstances calls GetComputeInstancesFunc. +func (mock *GpuInstanceMock) GetComputeInstances(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) { + if mock.GetComputeInstancesFunc == nil { + panic("GpuInstanceMock.GetComputeInstancesFunc: method is nil but GpuInstance.GetComputeInstances was just called") + } + callInfo := struct { + Info *ComputeInstanceProfileInfo + }{ + Info: Info, + } + mock.lockGetComputeInstances.Lock() + mock.calls.GetComputeInstances = append(mock.calls.GetComputeInstances, callInfo) + mock.lockGetComputeInstances.Unlock() + return mock.GetComputeInstancesFunc(Info) +} + +// GetComputeInstancesCalls gets all the calls that were made to GetComputeInstances. +// Check the length with: +// len(mockedGpuInstance.GetComputeInstancesCalls()) +func (mock *GpuInstanceMock) GetComputeInstancesCalls() []struct { + Info *ComputeInstanceProfileInfo +} { + var calls []struct { + Info *ComputeInstanceProfileInfo + } + mock.lockGetComputeInstances.RLock() + calls = mock.calls.GetComputeInstances + mock.lockGetComputeInstances.RUnlock() + return calls +} + +// GetInfo calls GetInfoFunc. +func (mock *GpuInstanceMock) GetInfo() (GpuInstanceInfo, Return) { + if mock.GetInfoFunc == nil { + panic("GpuInstanceMock.GetInfoFunc: method is nil but GpuInstance.GetInfo was just called") + } + callInfo := struct { + }{} + mock.lockGetInfo.Lock() + mock.calls.GetInfo = append(mock.calls.GetInfo, callInfo) + mock.lockGetInfo.Unlock() + return mock.GetInfoFunc() +} + +// GetInfoCalls gets all the calls that were made to GetInfo. +// Check the length with: +// len(mockedGpuInstance.GetInfoCalls()) +func (mock *GpuInstanceMock) GetInfoCalls() []struct { +} { + var calls []struct { + } + mock.lockGetInfo.RLock() + calls = mock.calls.GetInfo + mock.lockGetInfo.RUnlock() + return calls +} diff --git a/pkg/nvml/nvml.go b/pkg/nvml/nvml.go new file mode 100644 index 0000000..3ff684f --- /dev/null +++ b/pkg/nvml/nvml.go @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nvml + +import ( + "github.com/NVIDIA/go-nvml/pkg/nvml" +) + +type nvmlLib struct{} + +var _ Interface = (*nvmlLib)(nil) + +// New creates a new instance of the NVML Interface +func New() Interface { + return &nvmlLib{} +} + +// Init initializes an NVML Interface +func (n *nvmlLib) Init() Return { + return Return(nvml.Init()) +} + +// Shutdown shuts down an NVML Interface +func (n *nvmlLib) Shutdown() Return { + return Return(nvml.Shutdown()) +} + +// DeviceGetCount returns the total number of GPU Devices +func (n *nvmlLib) DeviceGetCount() (int, Return) { + c, r := nvml.DeviceGetCount() + return c, Return(r) +} + +// DeviceGetHandleByIndex returns a Device handle given its index +func (n *nvmlLib) DeviceGetHandleByIndex(index int) (Device, Return) { + d, r := nvml.DeviceGetHandleByIndex(index) + return nvmlDevice(d), Return(r) +} + +// DeviceGetHandleByUUID returns a Device handle given its UUID +func (n *nvmlLib) DeviceGetHandleByUUID(uuid string) (Device, Return) { + d, r := nvml.DeviceGetHandleByUUID(uuid) + return nvmlDevice(d), Return(r) +} + +// SystemGetDriverVersion returns the version of the installed NVIDIA driver +func (n *nvmlLib) SystemGetDriverVersion() (string, Return) { + v, r := nvml.SystemGetDriverVersion() + return v, Return(r) +} + +// ErrorString returns the error string associated with a given return value +func (n *nvmlLib) ErrorString(ret Return) string { + return nvml.ErrorString(nvml.Return(ret)) +} diff --git a/pkg/nvml/nvml_mock.go b/pkg/nvml/nvml_mock.go new file mode 100644 index 0000000..0efaf1d --- /dev/null +++ b/pkg/nvml/nvml_mock.go @@ -0,0 +1,303 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package nvml + +import ( + "sync" +) + +// Ensure, that InterfaceMock does implement Interface. +// If this is not the case, regenerate this file with moq. +var _ Interface = &InterfaceMock{} + +// InterfaceMock is a mock implementation of Interface. +// +// func TestSomethingThatUsesInterface(t *testing.T) { +// +// // make and configure a mocked Interface +// mockedInterface := &InterfaceMock{ +// DeviceGetCountFunc: func() (int, Return) { +// panic("mock out the DeviceGetCount method") +// }, +// DeviceGetHandleByIndexFunc: func(Index int) (Device, Return) { +// panic("mock out the DeviceGetHandleByIndex method") +// }, +// DeviceGetHandleByUUIDFunc: func(UUID string) (Device, Return) { +// panic("mock out the DeviceGetHandleByUUID method") +// }, +// ErrorStringFunc: func(r Return) string { +// panic("mock out the ErrorString method") +// }, +// InitFunc: func() Return { +// panic("mock out the Init method") +// }, +// ShutdownFunc: func() Return { +// panic("mock out the Shutdown method") +// }, +// SystemGetDriverVersionFunc: func() (string, Return) { +// panic("mock out the SystemGetDriverVersion method") +// }, +// } +// +// // use mockedInterface in code that requires Interface +// // and then make assertions. +// +// } +type InterfaceMock struct { + // DeviceGetCountFunc mocks the DeviceGetCount method. + DeviceGetCountFunc func() (int, Return) + + // DeviceGetHandleByIndexFunc mocks the DeviceGetHandleByIndex method. + DeviceGetHandleByIndexFunc func(Index int) (Device, Return) + + // DeviceGetHandleByUUIDFunc mocks the DeviceGetHandleByUUID method. + DeviceGetHandleByUUIDFunc func(UUID string) (Device, Return) + + // ErrorStringFunc mocks the ErrorString method. + ErrorStringFunc func(r Return) string + + // InitFunc mocks the Init method. + InitFunc func() Return + + // ShutdownFunc mocks the Shutdown method. + ShutdownFunc func() Return + + // SystemGetDriverVersionFunc mocks the SystemGetDriverVersion method. + SystemGetDriverVersionFunc func() (string, Return) + + // calls tracks calls to the methods. + calls struct { + // DeviceGetCount holds details about calls to the DeviceGetCount method. + DeviceGetCount []struct { + } + // DeviceGetHandleByIndex holds details about calls to the DeviceGetHandleByIndex method. + DeviceGetHandleByIndex []struct { + // Index is the Index argument value. + Index int + } + // DeviceGetHandleByUUID holds details about calls to the DeviceGetHandleByUUID method. + DeviceGetHandleByUUID []struct { + // UUID is the UUID argument value. + UUID string + } + // ErrorString holds details about calls to the ErrorString method. + ErrorString []struct { + // R is the r argument value. + R Return + } + // Init holds details about calls to the Init method. + Init []struct { + } + // Shutdown holds details about calls to the Shutdown method. + Shutdown []struct { + } + // SystemGetDriverVersion holds details about calls to the SystemGetDriverVersion method. + SystemGetDriverVersion []struct { + } + } + lockDeviceGetCount sync.RWMutex + lockDeviceGetHandleByIndex sync.RWMutex + lockDeviceGetHandleByUUID sync.RWMutex + lockErrorString sync.RWMutex + lockInit sync.RWMutex + lockShutdown sync.RWMutex + lockSystemGetDriverVersion sync.RWMutex +} + +// DeviceGetCount calls DeviceGetCountFunc. +func (mock *InterfaceMock) DeviceGetCount() (int, Return) { + if mock.DeviceGetCountFunc == nil { + panic("InterfaceMock.DeviceGetCountFunc: method is nil but Interface.DeviceGetCount was just called") + } + callInfo := struct { + }{} + mock.lockDeviceGetCount.Lock() + mock.calls.DeviceGetCount = append(mock.calls.DeviceGetCount, callInfo) + mock.lockDeviceGetCount.Unlock() + return mock.DeviceGetCountFunc() +} + +// DeviceGetCountCalls gets all the calls that were made to DeviceGetCount. +// Check the length with: +// len(mockedInterface.DeviceGetCountCalls()) +func (mock *InterfaceMock) DeviceGetCountCalls() []struct { +} { + var calls []struct { + } + mock.lockDeviceGetCount.RLock() + calls = mock.calls.DeviceGetCount + mock.lockDeviceGetCount.RUnlock() + return calls +} + +// DeviceGetHandleByIndex calls DeviceGetHandleByIndexFunc. +func (mock *InterfaceMock) DeviceGetHandleByIndex(Index int) (Device, Return) { + if mock.DeviceGetHandleByIndexFunc == nil { + panic("InterfaceMock.DeviceGetHandleByIndexFunc: method is nil but Interface.DeviceGetHandleByIndex was just called") + } + callInfo := struct { + Index int + }{ + Index: Index, + } + mock.lockDeviceGetHandleByIndex.Lock() + mock.calls.DeviceGetHandleByIndex = append(mock.calls.DeviceGetHandleByIndex, callInfo) + mock.lockDeviceGetHandleByIndex.Unlock() + return mock.DeviceGetHandleByIndexFunc(Index) +} + +// DeviceGetHandleByIndexCalls gets all the calls that were made to DeviceGetHandleByIndex. +// Check the length with: +// len(mockedInterface.DeviceGetHandleByIndexCalls()) +func (mock *InterfaceMock) DeviceGetHandleByIndexCalls() []struct { + Index int +} { + var calls []struct { + Index int + } + mock.lockDeviceGetHandleByIndex.RLock() + calls = mock.calls.DeviceGetHandleByIndex + mock.lockDeviceGetHandleByIndex.RUnlock() + return calls +} + +// DeviceGetHandleByUUID calls DeviceGetHandleByUUIDFunc. +func (mock *InterfaceMock) DeviceGetHandleByUUID(UUID string) (Device, Return) { + if mock.DeviceGetHandleByUUIDFunc == nil { + panic("InterfaceMock.DeviceGetHandleByUUIDFunc: method is nil but Interface.DeviceGetHandleByUUID was just called") + } + callInfo := struct { + UUID string + }{ + UUID: UUID, + } + mock.lockDeviceGetHandleByUUID.Lock() + mock.calls.DeviceGetHandleByUUID = append(mock.calls.DeviceGetHandleByUUID, callInfo) + mock.lockDeviceGetHandleByUUID.Unlock() + return mock.DeviceGetHandleByUUIDFunc(UUID) +} + +// DeviceGetHandleByUUIDCalls gets all the calls that were made to DeviceGetHandleByUUID. +// Check the length with: +// len(mockedInterface.DeviceGetHandleByUUIDCalls()) +func (mock *InterfaceMock) DeviceGetHandleByUUIDCalls() []struct { + UUID string +} { + var calls []struct { + UUID string + } + mock.lockDeviceGetHandleByUUID.RLock() + calls = mock.calls.DeviceGetHandleByUUID + mock.lockDeviceGetHandleByUUID.RUnlock() + return calls +} + +// ErrorString calls ErrorStringFunc. +func (mock *InterfaceMock) ErrorString(r Return) string { + if mock.ErrorStringFunc == nil { + panic("InterfaceMock.ErrorStringFunc: method is nil but Interface.ErrorString was just called") + } + callInfo := struct { + R Return + }{ + R: r, + } + mock.lockErrorString.Lock() + mock.calls.ErrorString = append(mock.calls.ErrorString, callInfo) + mock.lockErrorString.Unlock() + return mock.ErrorStringFunc(r) +} + +// ErrorStringCalls gets all the calls that were made to ErrorString. +// Check the length with: +// len(mockedInterface.ErrorStringCalls()) +func (mock *InterfaceMock) ErrorStringCalls() []struct { + R Return +} { + var calls []struct { + R Return + } + mock.lockErrorString.RLock() + calls = mock.calls.ErrorString + mock.lockErrorString.RUnlock() + return calls +} + +// Init calls InitFunc. +func (mock *InterfaceMock) Init() Return { + if mock.InitFunc == nil { + panic("InterfaceMock.InitFunc: method is nil but Interface.Init was just called") + } + callInfo := struct { + }{} + mock.lockInit.Lock() + mock.calls.Init = append(mock.calls.Init, callInfo) + mock.lockInit.Unlock() + return mock.InitFunc() +} + +// InitCalls gets all the calls that were made to Init. +// Check the length with: +// len(mockedInterface.InitCalls()) +func (mock *InterfaceMock) InitCalls() []struct { +} { + var calls []struct { + } + mock.lockInit.RLock() + calls = mock.calls.Init + mock.lockInit.RUnlock() + return calls +} + +// Shutdown calls ShutdownFunc. +func (mock *InterfaceMock) Shutdown() Return { + if mock.ShutdownFunc == nil { + panic("InterfaceMock.ShutdownFunc: method is nil but Interface.Shutdown was just called") + } + callInfo := struct { + }{} + mock.lockShutdown.Lock() + mock.calls.Shutdown = append(mock.calls.Shutdown, callInfo) + mock.lockShutdown.Unlock() + return mock.ShutdownFunc() +} + +// ShutdownCalls gets all the calls that were made to Shutdown. +// Check the length with: +// len(mockedInterface.ShutdownCalls()) +func (mock *InterfaceMock) ShutdownCalls() []struct { +} { + var calls []struct { + } + mock.lockShutdown.RLock() + calls = mock.calls.Shutdown + mock.lockShutdown.RUnlock() + return calls +} + +// SystemGetDriverVersion calls SystemGetDriverVersionFunc. +func (mock *InterfaceMock) SystemGetDriverVersion() (string, Return) { + if mock.SystemGetDriverVersionFunc == nil { + panic("InterfaceMock.SystemGetDriverVersionFunc: method is nil but Interface.SystemGetDriverVersion was just called") + } + callInfo := struct { + }{} + mock.lockSystemGetDriverVersion.Lock() + mock.calls.SystemGetDriverVersion = append(mock.calls.SystemGetDriverVersion, callInfo) + mock.lockSystemGetDriverVersion.Unlock() + return mock.SystemGetDriverVersionFunc() +} + +// SystemGetDriverVersionCalls gets all the calls that were made to SystemGetDriverVersion. +// Check the length with: +// len(mockedInterface.SystemGetDriverVersionCalls()) +func (mock *InterfaceMock) SystemGetDriverVersionCalls() []struct { +} { + var calls []struct { + } + mock.lockSystemGetDriverVersion.RLock() + calls = mock.calls.SystemGetDriverVersion + mock.lockSystemGetDriverVersion.RUnlock() + return calls +} diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go new file mode 100644 index 0000000..adbe242 --- /dev/null +++ b/pkg/nvml/types.go @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nvml + +import ( + "github.com/NVIDIA/go-nvml/pkg/nvml" +) + +// Return defines an NVML return type +type Return nvml.Return + +//go:generate moq -out nvml_mock.go . Interface +// Interface defines the functions implemented by an NVML library +type Interface interface { + Init() Return + Shutdown() Return + DeviceGetCount() (int, Return) + DeviceGetHandleByIndex(Index int) (Device, Return) + DeviceGetHandleByUUID(UUID string) (Device, Return) + SystemGetDriverVersion() (string, Return) + ErrorString(r Return) string +} + +//go:generate moq -out device_mock.go . Device +// Device defines the functions implemented by an NVML device +type Device interface { + GetIndex() (int, Return) + GetPciInfo() (PciInfo, Return) + GetMemoryInfo() (Memory, Return) + GetUUID() (string, Return) + GetMinorNumber() (int, Return) + IsMigDeviceHandle() (bool, Return) + GetDeviceHandleFromMigDeviceHandle() (Device, Return) + SetMigMode(Mode int) (Return, Return) + GetMigMode() (int, int, Return) + GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return) + GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return) + GetMaxMigDeviceCount() (int, Return) + GetMigDeviceHandleByIndex(Index int) (Device, Return) + GetGpuInstanceId() (int, Return) + GetComputeInstanceId() (int, Return) +} + +//go:generate moq -out gi_mock.go . GpuInstance +// GpuInstance defines the functions implemented by a GpuInstance +type GpuInstance interface { + GetInfo() (GpuInstanceInfo, Return) + GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) + CreateComputeInstance(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return) + GetComputeInstances(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) + Destroy() Return +} + +//go:generate moq -out ci_mock.go . ComputeInstance +// ComputeInstance defines the functions implemented by a ComputeInstance +type ComputeInstance interface { + GetInfo() (ComputeInstanceInfo, Return) + Destroy() Return +} + +// GpuInstanceInfo holds info about a GPU Instance +type GpuInstanceInfo struct { + Device Device + Id uint32 + ProfileId uint32 + Placement GpuInstancePlacement +} + +// ComputeInstanceInfo holds info about a Compute Instance +type ComputeInstanceInfo struct { + Device Device + GpuInstance GpuInstance + Id uint32 + ProfileId uint32 + Placement ComputeInstancePlacement +} + +// Memory holds info about GPU device memory +type Memory nvml.Memory + +//PciInfo holds info about the PCI connections of a GPU dvice +type PciInfo nvml.PciInfo + +// GpuInstanceProfileInfo holds info about a GPU Instance Profile +type GpuInstanceProfileInfo nvml.GpuInstanceProfileInfo + +// GpuInstancePlacement holds placement info about a GPU Instance +type GpuInstancePlacement nvml.GpuInstancePlacement + +// ComputeInstanceProfileInfo holds info about a Compute Instance Profile +type ComputeInstanceProfileInfo nvml.ComputeInstanceProfileInfo + +// ComputeInstancePlacement holds placement info about a Compute Instance +type ComputeInstancePlacement nvml.ComputeInstancePlacement