mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-06-26 18:18:24 +00:00
Update to github.com/NVIDIA/go-nvlib@f3264c8a6a7a
Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
This commit is contained in:
@@ -1,76 +0,0 @@
|
||||
package nvcdi
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type identifier string
|
||||
|
||||
// isGPUIndex checks if an identifier is a full GPU index
|
||||
func (i identifier) isGpuIndex() bool {
|
||||
if _, err := strconv.ParseUint(string(i), 10, 0); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isMigIndex checks if an identifier is a MIG index
|
||||
func (i identifier) isMigIndex() bool {
|
||||
split := strings.SplitN(string(i), ":", 2)
|
||||
if len(split) != 2 {
|
||||
return false
|
||||
}
|
||||
for _, s := range split {
|
||||
if _, err := strconv.ParseUint(s, 10, 0); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isUUID checks if an identifier is a UUID
|
||||
func (i identifier) isUUID() bool {
|
||||
return i.isGpuUUID() || i.isMigUUID()
|
||||
}
|
||||
|
||||
// isGpuUUID checks if an identifier is a GPU UUID
|
||||
// A GPU UUID must be of the form GPU-b1028956-cfa2-0990-bf4a-5da9abb51763
|
||||
func (i identifier) isGpuUUID() bool {
|
||||
if !strings.HasPrefix(string(i), "GPU-") {
|
||||
return false
|
||||
}
|
||||
_, err := uuid.Parse(strings.TrimPrefix(string(i), "GPU-"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// isMigUUID checks if an identifier is a MIG UUID
|
||||
// A MIG UUID can be of one of two forms:
|
||||
// - MIG-b1028956-cfa2-0990-bf4a-5da9abb51763
|
||||
// - MIG-GPU-b1028956-cfa2-0990-bf4a-5da9abb51763/3/0
|
||||
func (i identifier) isMigUUID() bool {
|
||||
if !strings.HasPrefix(string(i), "MIG-") {
|
||||
return false
|
||||
}
|
||||
suffix := strings.TrimPrefix(string(i), "MIG-")
|
||||
_, err := uuid.Parse(suffix)
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
split := strings.SplitN(suffix, "/", 3)
|
||||
if len(split) != 3 {
|
||||
return false
|
||||
}
|
||||
if !identifier(split[0]).isGpuUUID() {
|
||||
return false
|
||||
}
|
||||
for _, s := range split[1:] {
|
||||
_, err := strconv.ParseUint(s, 10, 0)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
package nvcdi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsGpuIndex(t *testing.T) {
|
||||
testCases := []struct {
|
||||
id string
|
||||
expected bool
|
||||
}{
|
||||
{"", false},
|
||||
{"0", true},
|
||||
{"1", true},
|
||||
{"not an integer", false},
|
||||
}
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||
actual := identifier(tc.id).isGpuIndex()
|
||||
require.Equal(t, tc.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsMigIndex(t *testing.T) {
|
||||
testCases := []struct {
|
||||
id string
|
||||
expected bool
|
||||
}{
|
||||
{"", false},
|
||||
{"0", false},
|
||||
{"not an integer", false},
|
||||
{"0:0", true},
|
||||
{"0:0:0", false},
|
||||
{"0:0.0", false},
|
||||
{"0:foo", false},
|
||||
{"foo:0", false},
|
||||
}
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||
actual := identifier(tc.id).isMigIndex()
|
||||
require.Equal(t, tc.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGpuUUID(t *testing.T) {
|
||||
testCases := []struct {
|
||||
id string
|
||||
expected bool
|
||||
}{
|
||||
{"", false},
|
||||
{"0", false},
|
||||
{"not an integer", false},
|
||||
{"GPU-foo", false},
|
||||
{"GPU-ebd34bdf-1083-eaac-2aff-4b71a022f9bd", true},
|
||||
{"MIG-ebd34bdf-1083-eaac-2aff-4b71a022f9bd", false},
|
||||
{"ebd34bdf-1083-eaac-2aff-4b71a022f9bd", false},
|
||||
}
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||
actual := identifier(tc.id).isGpuUUID()
|
||||
require.Equal(t, tc.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsMigUUID(t *testing.T) {
|
||||
testCases := []struct {
|
||||
id string
|
||||
expected bool
|
||||
}{
|
||||
{"", false},
|
||||
{"0", false},
|
||||
{"not an integer", false},
|
||||
{"MIG-foo", false},
|
||||
{"MIG-ebd34bdf-1083-eaac-2aff-4b71a022f9bd", true},
|
||||
{"GPU-ebd34bdf-1083-eaac-2aff-4b71a022f9bd", false},
|
||||
{"ebd34bdf-1083-eaac-2aff-4b71a022f9bd", false},
|
||||
}
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||
actual := identifier(tc.id).isMigUUID()
|
||||
require.Equal(t, tc.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -133,20 +133,20 @@ func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, erro
|
||||
|
||||
func (l *nvmllib) getNVMLDeviceByID(id string) (nvml.Device, error) {
|
||||
var err error
|
||||
devID := identifier(id)
|
||||
devID := device.Identifier(id)
|
||||
|
||||
if devID.isUUID() {
|
||||
if devID.IsUUID() {
|
||||
return l.nvmllib.DeviceGetHandleByUUID(id)
|
||||
}
|
||||
|
||||
if devID.isGpuIndex() {
|
||||
if devID.IsGpuIndex() {
|
||||
if idx, err := strconv.Atoi(id); err == nil {
|
||||
return l.nvmllib.DeviceGetHandleByIndex(idx)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
|
||||
}
|
||||
|
||||
if devID.isMigIndex() {
|
||||
if devID.IsMigIndex() {
|
||||
var gpuIdx, migIdx int
|
||||
var parent nvml.Device
|
||||
split := strings.SplitN(id, ":", 2)
|
||||
|
||||
Reference in New Issue
Block a user