diff --git a/pkg/nvml/nvml.go b/pkg/nvml/nvml.go index 67c1b23..4fe1f75 100644 --- a/pkg/nvml/nvml.go +++ b/pkg/nvml/nvml.go @@ -92,6 +92,12 @@ func (n *nvmlLib) SystemGetDriverVersion() (string, Return) { return v, Return(r) } +// SystemGetCudaDriverVersion returns the version of CUDA associated with the NVIDIA driver +func (n *nvmlLib) SystemGetCudaDriverVersion() (int, Return) { + v, r := nvml.SystemGetCudaDriverVersion() + return v, Return(r) +} + // ErrorString returns the error string associated with a given return value func (n *nvmlLib) ErrorString(ret Return) string { return nvml.ErrorString(nvml.Return(ret)) diff --git a/pkg/nvml/nvml_mock.go b/pkg/nvml/nvml_mock.go index 0efaf1d..56b50a9 100644 --- a/pkg/nvml/nvml_mock.go +++ b/pkg/nvml/nvml_mock.go @@ -35,6 +35,9 @@ var _ Interface = &InterfaceMock{} // ShutdownFunc: func() Return { // panic("mock out the Shutdown method") // }, +// SystemGetCudaDriverVersionFunc: func() (int, Return) { +// panic("mock out the SystemGetCudaDriverVersion method") +// }, // SystemGetDriverVersionFunc: func() (string, Return) { // panic("mock out the SystemGetDriverVersion method") // }, @@ -63,6 +66,9 @@ type InterfaceMock struct { // ShutdownFunc mocks the Shutdown method. ShutdownFunc func() Return + // SystemGetCudaDriverVersionFunc mocks the SystemGetCudaDriverVersion method. + SystemGetCudaDriverVersionFunc func() (int, Return) + // SystemGetDriverVersionFunc mocks the SystemGetDriverVersion method. SystemGetDriverVersionFunc func() (string, Return) @@ -92,17 +98,21 @@ type InterfaceMock struct { // Shutdown holds details about calls to the Shutdown method. Shutdown []struct { } + // SystemGetCudaDriverVersion holds details about calls to the SystemGetCudaDriverVersion method. + SystemGetCudaDriverVersion []struct { + } // SystemGetDriverVersion holds details about calls to the SystemGetDriverVersion method. SystemGetDriverVersion []struct { } } - lockDeviceGetCount sync.RWMutex - lockDeviceGetHandleByIndex sync.RWMutex - lockDeviceGetHandleByUUID sync.RWMutex - lockErrorString sync.RWMutex - lockInit sync.RWMutex - lockShutdown sync.RWMutex - lockSystemGetDriverVersion sync.RWMutex + lockDeviceGetCount sync.RWMutex + lockDeviceGetHandleByIndex sync.RWMutex + lockDeviceGetHandleByUUID sync.RWMutex + lockErrorString sync.RWMutex + lockInit sync.RWMutex + lockShutdown sync.RWMutex + lockSystemGetCudaDriverVersion sync.RWMutex + lockSystemGetDriverVersion sync.RWMutex } // DeviceGetCount calls DeviceGetCountFunc. @@ -276,6 +286,32 @@ func (mock *InterfaceMock) ShutdownCalls() []struct { return calls } +// SystemGetCudaDriverVersion calls SystemGetCudaDriverVersionFunc. +func (mock *InterfaceMock) SystemGetCudaDriverVersion() (int, Return) { + if mock.SystemGetCudaDriverVersionFunc == nil { + panic("InterfaceMock.SystemGetCudaDriverVersionFunc: method is nil but Interface.SystemGetCudaDriverVersion was just called") + } + callInfo := struct { + }{} + mock.lockSystemGetCudaDriverVersion.Lock() + mock.calls.SystemGetCudaDriverVersion = append(mock.calls.SystemGetCudaDriverVersion, callInfo) + mock.lockSystemGetCudaDriverVersion.Unlock() + return mock.SystemGetCudaDriverVersionFunc() +} + +// SystemGetCudaDriverVersionCalls gets all the calls that were made to SystemGetCudaDriverVersion. +// Check the length with: +// len(mockedInterface.SystemGetCudaDriverVersionCalls()) +func (mock *InterfaceMock) SystemGetCudaDriverVersionCalls() []struct { +} { + var calls []struct { + } + mock.lockSystemGetCudaDriverVersion.RLock() + calls = mock.calls.SystemGetCudaDriverVersion + mock.lockSystemGetCudaDriverVersion.RUnlock() + return calls +} + // SystemGetDriverVersion calls SystemGetDriverVersionFunc. func (mock *InterfaceMock) SystemGetDriverVersion() (string, Return) { if mock.SystemGetDriverVersionFunc == nil { diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go index 31ee7dd..47a7eb3 100644 --- a/pkg/nvml/types.go +++ b/pkg/nvml/types.go @@ -20,8 +20,9 @@ import ( "github.com/NVIDIA/go-nvml/pkg/nvml" ) -//go:generate moq -out nvml_mock.go . Interface // Interface defines the functions implemented by an NVML library +// +//go:generate moq -out nvml_mock.go . Interface type Interface interface { Init() Return Shutdown() Return @@ -29,11 +30,13 @@ type Interface interface { DeviceGetHandleByIndex(Index int) (Device, Return) DeviceGetHandleByUUID(UUID string) (Device, Return) SystemGetDriverVersion() (string, Return) + SystemGetCudaDriverVersion() (int, Return) ErrorString(r Return) string } -//go:generate moq -out device_mock.go . Device // Device defines the functions implemented by an NVML device +// +//go:generate moq -out device_mock.go . Device type Device interface { GetIndex() (int, Return) GetPciInfo() (PciInfo, Return) @@ -52,8 +55,9 @@ type Device interface { GetComputeInstanceId() (int, Return) } -//go:generate moq -out gi_mock.go . GpuInstance // GpuInstance defines the functions implemented by a GpuInstance +// +//go:generate moq -out gi_mock.go . GpuInstance type GpuInstance interface { GetInfo() (GpuInstanceInfo, Return) GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) @@ -62,8 +66,9 @@ type GpuInstance interface { Destroy() Return } -//go:generate moq -out ci_mock.go . ComputeInstance // ComputeInstance defines the functions implemented by a ComputeInstance +// +//go:generate moq -out ci_mock.go . ComputeInstance type ComputeInstance interface { GetInfo() (ComputeInstanceInfo, Return) Destroy() Return @@ -92,7 +97,7 @@ type Return nvml.Return // Memory holds info about GPU device memory type Memory nvml.Memory -//PciInfo holds info about the PCI connections of a GPU dvice +// PciInfo holds info about the PCI connections of a GPU dvice type PciInfo nvml.PciInfo // GpuInstanceProfileInfo holds info about a GPU Instance Profile