From 31581469468d2d4e631a5984bf561ec0d9f3569e Mon Sep 17 00:00:00 2001 From: Christopher Desiniotis Date: Tue, 5 Dec 2023 18:38:00 -0800 Subject: [PATCH] Extend the 'runtime.nvidia.com/gpu' CDI device kind to support MIG devices specified by index or UUID Signed-off-by: Christopher Desiniotis --- pkg/nvcdi/lib-nvml.go | 62 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 7 deletions(-) diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go index 0946fc04..492e90cc 100644 --- a/pkg/nvcdi/lib-nvml.go +++ b/pkg/nvcdi/lib-nvml.go @@ -19,6 +19,7 @@ package nvcdi import ( "fmt" "strconv" + "strings" "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvlib/pkg/nvml" @@ -79,7 +80,6 @@ func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) { // GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by // the provided identifiers, where an identifier is an index or UUID of a valid // GPU device. -// TODO: support identifiers that correspond to MIG devices func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, error) { for _, id := range identifiers { if id == "all" { @@ -104,11 +104,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err } for i, nvmlDevice := range nvmlDevices { - nvlibDevice, err := l.devicelib.NewDevice(nvmlDevice) - if err != nil { - return nil, fmt.Errorf("failed to construct device: %w", err) - } - deviceEdits, err := l.GetGPUDeviceEdits(nvlibDevice) + deviceEdits, err := l.getEditsForDevice(nvmlDevice) if err != nil { return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err) } @@ -151,12 +147,64 @@ func (l *nvmllib) getNVMLDeviceByID(id string) (nvml.Device, error) { } if devID.isMigIndex() { - return nil, fmt.Errorf("MIG index is not supported") + var gpuIdx, migIdx int + var parent nvml.Device + split := strings.SplitN(id, ":", 2) + if gpuIdx, err = strconv.Atoi(split[0]); err != nil { + return nil, fmt.Errorf("failed to convert device index to an int: %w", err) + } + if migIdx, err = strconv.Atoi(split[1]); err != nil { + return nil, fmt.Errorf("failed to convert device index to an int: %w", err) + } + if parent, err = l.nvmllib.DeviceGetHandleByIndex(gpuIdx); err != nvml.SUCCESS { + return nil, fmt.Errorf("failed to get parent device handle: %w", err) + } + return parent.GetMigDeviceHandleByIndex(migIdx) } return nil, fmt.Errorf("identifier is not a valid UUID or index: %q", id) } +func (l *nvmllib) getEditsForDevice(nvmlDevice nvml.Device) (*cdi.ContainerEdits, error) { + mig, err := nvmlDevice.IsMigDeviceHandle() + if err != nvml.SUCCESS { + return nil, fmt.Errorf("failed to determine if device handle is a MIG device: %w", err) + } + if mig { + return l.getEditsForMIGDevice(nvmlDevice) + } + return l.getEditsForGPUDevice(nvmlDevice) +} + +func (l *nvmllib) getEditsForGPUDevice(nvmlDevice nvml.Device) (*cdi.ContainerEdits, error) { + nvlibDevice, err := l.devicelib.NewDevice(nvmlDevice) + if err != nil { + return nil, fmt.Errorf("failed to construct device: %w", err) + } + deviceEdits, err := l.GetGPUDeviceEdits(nvlibDevice) + if err != nil { + return nil, fmt.Errorf("failed to get GPU device edits: %w", err) + } + + return deviceEdits, nil +} + +func (l *nvmllib) getEditsForMIGDevice(nvmlDevice nvml.Device) (*cdi.ContainerEdits, error) { + nvmlParentDevice, ret := nvmlDevice.GetDeviceHandleFromMigDeviceHandle() + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("failed to get parent device handle: %w", ret) + } + nvlibMigDevice, err := l.devicelib.NewMigDevice(nvmlDevice) + if err != nil { + return nil, fmt.Errorf("failed to construct device: %w", err) + } + nvlibParentDevice, err := l.devicelib.NewDevice(nvmlParentDevice) + if err != nil { + return nil, fmt.Errorf("failed to construct parent device: %w", err) + } + return l.GetMIGDeviceEdits(nvlibParentDevice, nvlibMigDevice) +} + func (l *nvmllib) getGPUDeviceSpecs() ([]specs.Device, error) { var deviceSpecs []specs.Device err := l.devicelib.VisitDevices(func(i int, d device.Device) error {