diff --git a/go.mod b/go.mod index 5173b6e7..096c7be9 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.20 require ( github.com/NVIDIA/go-nvlib v0.3.0 - github.com/NVIDIA/go-nvml v0.12.0-5 + github.com/NVIDIA/go-nvml v0.12.0-6 github.com/fsnotify/fsnotify v1.7.0 github.com/opencontainers/runtime-spec v1.2.0 github.com/pelletier/go-toml v1.9.5 diff --git a/go.sum b/go.sum index 80d9b3d5..a5b90c3c 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/NVIDIA/go-nvlib v0.3.0 h1:vd7jSOthJTqzqIWZrv317xDr1+Mnjoy5X4N69W9YwQM= github.com/NVIDIA/go-nvlib v0.3.0/go.mod h1:NasUuId9hYFvwzuOHCu9F2X6oTU2tG0JHTfbJYuDAbA= -github.com/NVIDIA/go-nvml v0.12.0-5 h1:4DYsngBqJEAEj+/RFmBZ43Q3ymoR3tyS0oBuJk12Fag= -github.com/NVIDIA/go-nvml v0.12.0-5/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ= +github.com/NVIDIA/go-nvml v0.12.0-6 h1:FJYc2KrpvX+VOC/8QQvMiQMmZ/nPMRpdJO/Ik4xfcr0= +github.com/NVIDIA/go-nvml v0.12.0-6/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= diff --git a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/device.go b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/device.go index 7ee5e551..7604d39f 100644 --- a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/device.go +++ b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/device.go @@ -15,9 +15,48 @@ package nvml import ( + "fmt" + "reflect" "unsafe" ) +// nvmlDeviceHandle attempts to convert a device d to an nvmlDevice. +// This is required for functions such as GetTopologyCommonAncestor which +// accept Device arguments that need to be passed to internal nvml* functions +// as nvmlDevice parameters. +func nvmlDeviceHandle(d Device) nvmlDevice { + var helper func(val reflect.Value) nvmlDevice + helper = func(val reflect.Value) nvmlDevice { + if val.Kind() == reflect.Interface { + val = val.Elem() + } + + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + if val.Type() == reflect.TypeOf(nvmlDevice{}) { + return val.Interface().(nvmlDevice) + } + + if val.Kind() != reflect.Struct { + panic(fmt.Errorf("unable to convert non-struct type %v to nvmlDevice", val.Kind())) + } + + for i := 0; i < val.Type().NumField(); i++ { + if !val.Type().Field(i).Anonymous { + continue + } + if !val.Field(i).Type().Implements(reflect.TypeOf((*Device)(nil)).Elem()) { + continue + } + return helper(val.Field(i)) + } + panic(fmt.Errorf("unable to convert %T to nvmlDevice", d)) + } + return helper(reflect.ValueOf(d)) +} + // EccBitType type EccBitType = MemoryErrorType @@ -220,10 +259,13 @@ func (l *library) DeviceGetTopologyCommonAncestor(device1 Device, device2 Device func (device1 nvmlDevice) GetTopologyCommonAncestor(device2 Device) (GpuTopologyLevel, Return) { var pathInfo GpuTopologyLevel - ret := nvmlDeviceGetTopologyCommonAncestor(device1, device2.(nvmlDevice), &pathInfo) + ret := nvmlDeviceGetTopologyCommonAncestorStub(device1, nvmlDeviceHandle(device2), &pathInfo) return pathInfo, ret } +// nvmlDeviceGetTopologyCommonAncestorStub allows us to override this for testing. +var nvmlDeviceGetTopologyCommonAncestorStub = nvmlDeviceGetTopologyCommonAncestor + // nvml.DeviceGetTopologyNearestGpus() func (l *library) DeviceGetTopologyNearestGpus(device Device, level GpuTopologyLevel) ([]Device, Return) { return device.GetTopologyNearestGpus(level) @@ -250,7 +292,7 @@ func (l *library) DeviceGetP2PStatus(device1 Device, device2 Device, p2pIndex Gp func (device1 nvmlDevice) GetP2PStatus(device2 Device, p2pIndex GpuP2PCapsIndex) (GpuP2PStatus, Return) { var p2pStatus GpuP2PStatus - ret := nvmlDeviceGetP2PStatus(device1, device2.(nvmlDevice), p2pIndex, &p2pStatus) + ret := nvmlDeviceGetP2PStatus(device1, nvmlDeviceHandle(device2), p2pIndex, &p2pStatus) return p2pStatus, ret } @@ -1182,7 +1224,7 @@ func (l *library) DeviceOnSameBoard(device1 Device, device2 Device) (int, Return func (device1 nvmlDevice) OnSameBoard(device2 Device) (int, Return) { var onSameBoard int32 - ret := nvmlDeviceOnSameBoard(device1, device2.(nvmlDevice), &onSameBoard) + ret := nvmlDeviceOnSameBoard(device1, nvmlDeviceHandle(device2), &onSameBoard) return int(onSameBoard), ret } diff --git a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/gpm.go b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/gpm.go index 783514b1..acdb2e0c 100644 --- a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/gpm.go +++ b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/gpm.go @@ -93,7 +93,7 @@ func (device nvmlDevice) GpmSampleGet(gpmSample GpmSample) Return { } func (gpmSample nvmlGpmSample) Get(device Device) Return { - return nvmlGpmSampleGet(device.(nvmlDevice), gpmSample) + return nvmlGpmSampleGet(nvmlDeviceHandle(device), gpmSample) } // nvml.GpmQueryDeviceSupport() @@ -137,5 +137,5 @@ func (device nvmlDevice) GpmMigSampleGet(gpuInstanceId int, gpmSample GpmSample) } func (gpmSample nvmlGpmSample) MigGet(device Device, gpuInstanceId int) Return { - return nvmlGpmMigSampleGet(device.(nvmlDevice), uint32(gpuInstanceId), gpmSample) + return nvmlGpmMigSampleGet(nvmlDeviceHandle(device), uint32(gpuInstanceId), gpmSample) } diff --git a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/vgpu.go b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/vgpu.go index bd800771..da495242 100644 --- a/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/vgpu.go +++ b/vendor/github.com/NVIDIA/go-nvml/pkg/nvml/vgpu.go @@ -142,7 +142,7 @@ func (device nvmlDevice) VgpuTypeGetMaxInstances(vgpuTypeId VgpuTypeId) (int, Re func (vgpuTypeId nvmlVgpuTypeId) GetMaxInstances(device Device) (int, Return) { var vgpuInstanceCount uint32 - ret := nvmlVgpuTypeGetMaxInstances(device.(nvmlDevice), vgpuTypeId, &vgpuInstanceCount) + ret := nvmlVgpuTypeGetMaxInstances(nvmlDeviceHandle(device), vgpuTypeId, &vgpuInstanceCount) return int(vgpuInstanceCount), ret } diff --git a/vendor/modules.txt b/vendor/modules.txt index 13f52577..eabcaf21 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -6,7 +6,7 @@ github.com/NVIDIA/go-nvlib/pkg/nvpci github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes github.com/NVIDIA/go-nvlib/pkg/nvpci/mmio github.com/NVIDIA/go-nvlib/pkg/pciids -# github.com/NVIDIA/go-nvml v0.12.0-5 +# github.com/NVIDIA/go-nvml v0.12.0-6 ## explicit; go 1.20 github.com/NVIDIA/go-nvml/pkg/dl github.com/NVIDIA/go-nvml/pkg/nvml