From 80d61efe5db954153a632da194f268ffcdea01c3 Mon Sep 17 00:00:00 2001
From: Evan Lezar <elezar@nvidia.com>
Date: Tue, 14 Nov 2023 12:45:59 +0100
Subject: [PATCH] Add functions related to NVLink info

Signed-off-by: Evan Lezar <elezar@nvidia.com>
---
 pkg/nvml/consts.go      |  21 +++++++
 pkg/nvml/device.go      |  24 ++++++++
 pkg/nvml/device_mock.go | 132 ++++++++++++++++++++++++++++++++++++++++
 pkg/nvml/types.go       |   9 +++
 4 files changed, 186 insertions(+)

diff --git a/pkg/nvml/consts.go b/pkg/nvml/consts.go
index c9b85de..d6a7ee3 100644
--- a/pkg/nvml/consts.go
+++ b/pkg/nvml/consts.go
@@ -20,6 +20,11 @@ import (
 	"github.com/NVIDIA/go-nvml/pkg/nvml"
 )
 
+// General untyped constants
+const (
+	NVLINK_MAX_LINKS = nvml.NVLINK_MAX_LINKS
+)
+
 // Return constants
 const (
 	SUCCESS                       = Return(nvml.SUCCESS)
@@ -131,3 +136,19 @@ const (
 	EventTypeSingleBitEccError = nvml.EventTypeSingleBitEccError
 	EventTypeDoubleBitEccError = nvml.EventTypeDoubleBitEccError
 )
+
+// GPU Topology enumeration
+const (
+	TOPOLOGY_INTERNAL   = GpuTopologyLevel(nvml.TOPOLOGY_INTERNAL)
+	TOPOLOGY_SINGLE     = GpuTopologyLevel(nvml.TOPOLOGY_SINGLE)
+	TOPOLOGY_MULTIPLE   = GpuTopologyLevel(nvml.TOPOLOGY_MULTIPLE)
+	TOPOLOGY_HOSTBRIDGE = GpuTopologyLevel(nvml.TOPOLOGY_HOSTBRIDGE)
+	TOPOLOGY_NODE       = GpuTopologyLevel(nvml.TOPOLOGY_NODE)
+	TOPOLOGY_SYSTEM     = GpuTopologyLevel(nvml.TOPOLOGY_SYSTEM)
+)
+
+// Generic enable/disable constants
+const (
+	FEATURE_DISABLED = EnableState(nvml.FEATURE_DISABLED)
+	FEATURE_ENABLED  = EnableState(nvml.FEATURE_ENABLED)
+)
diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go
index 3c318a7..3901bb7 100644
--- a/pkg/nvml/device.go
+++ b/pkg/nvml/device.go
@@ -178,3 +178,27 @@ func (d nvmlDevice) GetSupportedEventTypes() (uint64, Return) {
 	e, r := nvml.Device(d).GetSupportedEventTypes()
 	return e, Return(r)
 }
+
+// GetTopologyCommonAncestor retrieves the common ancestor for two devices.
+func (d nvmlDevice) GetTopologyCommonAncestor(o Device) (GpuTopologyLevel, Return) {
+	other, ok := o.(nvmlDevice)
+	if !ok {
+		return 0, ERROR_INVALID_ARGUMENT
+	}
+
+	l, r := nvml.Device(d).GetTopologyCommonAncestor(nvml.Device(other))
+	return GpuTopologyLevel(l), Return(r)
+}
+
+// GetNvLinkState retrieves the state of the device's NvLink for the link specified.
+func (d nvmlDevice) GetNvLinkState(link int) (EnableState, Return) {
+	s, r := nvml.Device(d).GetNvLinkState(link)
+	return EnableState(s), Return(r)
+}
+
+// GetNvLinkRemotePciInfo retrieves the PCI information for the remote node on a NvLink link.
+// Note: pciSubSystemId is not filled in this function and is indeterminate.
+func (d nvmlDevice) GetNvLinkRemotePciInfo(link int) (PciInfo, Return) {
+	p, r := nvml.Device(d).GetNvLinkRemotePciInfo(link)
+	return PciInfo(p), Return(r)
+}
diff --git a/pkg/nvml/device_mock.go b/pkg/nvml/device_mock.go
index 34e563c..17dfe25 100644
--- a/pkg/nvml/device_mock.go
+++ b/pkg/nvml/device_mock.go
@@ -74,12 +74,21 @@ var _ Device = &DeviceMock{}
 //			GetNameFunc: func() (string, Return) {
 //				panic("mock out the GetName method")
 //			},
+//			GetNvLinkRemotePciInfoFunc: func(n int) (PciInfo, Return) {
+//				panic("mock out the GetNvLinkRemotePciInfo method")
+//			},
+//			GetNvLinkStateFunc: func(n int) (EnableState, Return) {
+//				panic("mock out the GetNvLinkState method")
+//			},
 //			GetPciInfoFunc: func() (PciInfo, Return) {
 //				panic("mock out the GetPciInfo method")
 //			},
 //			GetSupportedEventTypesFunc: func() (uint64, Return) {
 //				panic("mock out the GetSupportedEventTypes method")
 //			},
+//			GetTopologyCommonAncestorFunc: func(device Device) (GpuTopologyLevel, Return) {
+//				panic("mock out the GetTopologyCommonAncestor method")
+//			},
 //			GetUUIDFunc: func() (string, Return) {
 //				panic("mock out the GetUUID method")
 //			},
@@ -156,12 +165,21 @@ type DeviceMock struct {
 	// GetNameFunc mocks the GetName method.
 	GetNameFunc func() (string, Return)
 
+	// GetNvLinkRemotePciInfoFunc mocks the GetNvLinkRemotePciInfo method.
+	GetNvLinkRemotePciInfoFunc func(n int) (PciInfo, Return)
+
+	// GetNvLinkStateFunc mocks the GetNvLinkState method.
+	GetNvLinkStateFunc func(n int) (EnableState, Return)
+
 	// GetPciInfoFunc mocks the GetPciInfo method.
 	GetPciInfoFunc func() (PciInfo, Return)
 
 	// GetSupportedEventTypesFunc mocks the GetSupportedEventTypes method.
 	GetSupportedEventTypesFunc func() (uint64, Return)
 
+	// GetTopologyCommonAncestorFunc mocks the GetTopologyCommonAncestor method.
+	GetTopologyCommonAncestorFunc func(device Device) (GpuTopologyLevel, Return)
+
 	// GetUUIDFunc mocks the GetUUID method.
 	GetUUIDFunc func() (string, Return)
 
@@ -247,12 +265,27 @@ type DeviceMock struct {
 		// GetName holds details about calls to the GetName method.
 		GetName []struct {
 		}
+		// GetNvLinkRemotePciInfo holds details about calls to the GetNvLinkRemotePciInfo method.
+		GetNvLinkRemotePciInfo []struct {
+			// N is the n argument value.
+			N int
+		}
+		// GetNvLinkState holds details about calls to the GetNvLinkState method.
+		GetNvLinkState []struct {
+			// N is the n argument value.
+			N int
+		}
 		// GetPciInfo holds details about calls to the GetPciInfo method.
 		GetPciInfo []struct {
 		}
 		// GetSupportedEventTypes holds details about calls to the GetSupportedEventTypes method.
 		GetSupportedEventTypes []struct {
 		}
+		// GetTopologyCommonAncestor holds details about calls to the GetTopologyCommonAncestor method.
+		GetTopologyCommonAncestor []struct {
+			// Device is the device argument value.
+			Device Device
+		}
 		// GetUUID holds details about calls to the GetUUID method.
 		GetUUID []struct {
 		}
@@ -291,8 +324,11 @@ type DeviceMock struct {
 	lockGetMigMode                         sync.RWMutex
 	lockGetMinorNumber                     sync.RWMutex
 	lockGetName                            sync.RWMutex
+	lockGetNvLinkRemotePciInfo             sync.RWMutex
+	lockGetNvLinkState                     sync.RWMutex
 	lockGetPciInfo                         sync.RWMutex
 	lockGetSupportedEventTypes             sync.RWMutex
+	lockGetTopologyCommonAncestor          sync.RWMutex
 	lockGetUUID                            sync.RWMutex
 	lockIsMigDeviceHandle                  sync.RWMutex
 	lockRegisterEvents                     sync.RWMutex
@@ -846,6 +882,70 @@ func (mock *DeviceMock) GetNameCalls() []struct {
 	return calls
 }
 
+// GetNvLinkRemotePciInfo calls GetNvLinkRemotePciInfoFunc.
+func (mock *DeviceMock) GetNvLinkRemotePciInfo(n int) (PciInfo, Return) {
+	if mock.GetNvLinkRemotePciInfoFunc == nil {
+		panic("DeviceMock.GetNvLinkRemotePciInfoFunc: method is nil but Device.GetNvLinkRemotePciInfo was just called")
+	}
+	callInfo := struct {
+		N int
+	}{
+		N: n,
+	}
+	mock.lockGetNvLinkRemotePciInfo.Lock()
+	mock.calls.GetNvLinkRemotePciInfo = append(mock.calls.GetNvLinkRemotePciInfo, callInfo)
+	mock.lockGetNvLinkRemotePciInfo.Unlock()
+	return mock.GetNvLinkRemotePciInfoFunc(n)
+}
+
+// GetNvLinkRemotePciInfoCalls gets all the calls that were made to GetNvLinkRemotePciInfo.
+// Check the length with:
+//
+//	len(mockedDevice.GetNvLinkRemotePciInfoCalls())
+func (mock *DeviceMock) GetNvLinkRemotePciInfoCalls() []struct {
+	N int
+} {
+	var calls []struct {
+		N int
+	}
+	mock.lockGetNvLinkRemotePciInfo.RLock()
+	calls = mock.calls.GetNvLinkRemotePciInfo
+	mock.lockGetNvLinkRemotePciInfo.RUnlock()
+	return calls
+}
+
+// GetNvLinkState calls GetNvLinkStateFunc.
+func (mock *DeviceMock) GetNvLinkState(n int) (EnableState, Return) {
+	if mock.GetNvLinkStateFunc == nil {
+		panic("DeviceMock.GetNvLinkStateFunc: method is nil but Device.GetNvLinkState was just called")
+	}
+	callInfo := struct {
+		N int
+	}{
+		N: n,
+	}
+	mock.lockGetNvLinkState.Lock()
+	mock.calls.GetNvLinkState = append(mock.calls.GetNvLinkState, callInfo)
+	mock.lockGetNvLinkState.Unlock()
+	return mock.GetNvLinkStateFunc(n)
+}
+
+// GetNvLinkStateCalls gets all the calls that were made to GetNvLinkState.
+// Check the length with:
+//
+//	len(mockedDevice.GetNvLinkStateCalls())
+func (mock *DeviceMock) GetNvLinkStateCalls() []struct {
+	N int
+} {
+	var calls []struct {
+		N int
+	}
+	mock.lockGetNvLinkState.RLock()
+	calls = mock.calls.GetNvLinkState
+	mock.lockGetNvLinkState.RUnlock()
+	return calls
+}
+
 // GetPciInfo calls GetPciInfoFunc.
 func (mock *DeviceMock) GetPciInfo() (PciInfo, Return) {
 	if mock.GetPciInfoFunc == nil {
@@ -900,6 +1000,38 @@ func (mock *DeviceMock) GetSupportedEventTypesCalls() []struct {
 	return calls
 }
 
+// GetTopologyCommonAncestor calls GetTopologyCommonAncestorFunc.
+func (mock *DeviceMock) GetTopologyCommonAncestor(device Device) (GpuTopologyLevel, Return) {
+	if mock.GetTopologyCommonAncestorFunc == nil {
+		panic("DeviceMock.GetTopologyCommonAncestorFunc: method is nil but Device.GetTopologyCommonAncestor was just called")
+	}
+	callInfo := struct {
+		Device Device
+	}{
+		Device: device,
+	}
+	mock.lockGetTopologyCommonAncestor.Lock()
+	mock.calls.GetTopologyCommonAncestor = append(mock.calls.GetTopologyCommonAncestor, callInfo)
+	mock.lockGetTopologyCommonAncestor.Unlock()
+	return mock.GetTopologyCommonAncestorFunc(device)
+}
+
+// GetTopologyCommonAncestorCalls gets all the calls that were made to GetTopologyCommonAncestor.
+// Check the length with:
+//
+//	len(mockedDevice.GetTopologyCommonAncestorCalls())
+func (mock *DeviceMock) GetTopologyCommonAncestorCalls() []struct {
+	Device Device
+} {
+	var calls []struct {
+		Device Device
+	}
+	mock.lockGetTopologyCommonAncestor.RLock()
+	calls = mock.calls.GetTopologyCommonAncestor
+	mock.lockGetTopologyCommonAncestor.RUnlock()
+	return calls
+}
+
 // GetUUID calls GetUUIDFunc.
 func (mock *DeviceMock) GetUUID() (string, Return) {
 	if mock.GetUUIDFunc == nil {
diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go
index 2aaaf08..ecbe4e6 100644
--- a/pkg/nvml/types.go
+++ b/pkg/nvml/types.go
@@ -59,8 +59,11 @@ type Device interface {
 	GetMigMode() (int, int, Return)
 	GetMinorNumber() (int, Return)
 	GetName() (string, Return)
+	GetNvLinkRemotePciInfo(int) (PciInfo, Return)
+	GetNvLinkState(int) (EnableState, Return)
 	GetPciInfo() (PciInfo, Return)
 	GetSupportedEventTypes() (uint64, Return)
+	GetTopologyCommonAncestor(Device) (GpuTopologyLevel, Return)
 	GetUUID() (string, Return)
 	IsMigDeviceHandle() (bool, Return)
 	RegisterEvents(uint64, EventSet) Return
@@ -145,3 +148,9 @@ type DeviceArchitecture nvml.DeviceArchitecture
 
 // BrandType represents the brand of a GPU device
 type BrandType nvml.BrandType
+
+// GpuTopologyLevel represents level relationships within a system between two GPUs
+type GpuTopologyLevel nvml.GpuTopologyLevel
+
+// EnableState represents a generic enable/disable enum
+type EnableState nvml.EnableState