From 35b23c5a2cd93db97b91174cdbac3fe19f5c1c3a Mon Sep 17 00:00:00 2001 From: Christopher Desiniotis Date: Wed, 13 Dec 2023 10:16:09 -0800 Subject: [PATCH] Accept device.Identifiers for requesting CDI specs This change moves from using strings to useing device.Identifiers as input for requesting CDI specifications for specific devices. Signed-off-by: Christopher Desiniotis Signed-off-by: Evan Lezar --- pkg/nvcdi/lib-nvml.go | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go index 95236651..ab7cb8ba 100644 --- a/pkg/nvcdi/lib-nvml.go +++ b/pkg/nvcdi/lib-nvml.go @@ -80,7 +80,17 @@ func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) { // GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by // the provided identifiers, where an identifier is an index or UUID of a valid // GPU device. -func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, error) { +// Deprecated: Use GetDeviceSpecsBy instead. +func (l *nvmllib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) { + var identifiers []device.Identifier + for _, id := range ids { + identifiers = append(identifiers, device.Identifier(id)) + } + return l.GetDeviceSpecsBy(identifiers...) +} + +// GetDeviceSpecsBy is not supported for the gdslib specs. +func (l *nvmllib) GetDeviceSpecsBy(identifiers ...device.Identifier) ([]specs.Device, error) { for _, id := range identifiers { if id == "all" { return l.GetAllDeviceSpecs() @@ -109,7 +119,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err) } deviceSpec := specs.Device{ - Name: identifiers[i], + Name: string(identifiers[i]), ContainerEdits: *deviceEdits.ContainerEdits, } deviceSpecs = append(deviceSpecs, deviceSpec) @@ -119,7 +129,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err } // TODO: move this to go-nvlib? -func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, error) { +func (l *nvmllib) getNVMLDevicesByID(identifiers ...device.Identifier) ([]nvml.Device, error) { var devices []nvml.Device for _, id := range identifiers { dev, err := l.getNVMLDeviceByID(id) @@ -131,25 +141,24 @@ func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, erro return devices, nil } -func (l *nvmllib) getNVMLDeviceByID(id string) (nvml.Device, error) { +func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) { var err error - devID := device.Identifier(id) - if devID.IsUUID() { - return l.nvmllib.DeviceGetHandleByUUID(id) + if id.IsUUID() { + return l.nvmllib.DeviceGetHandleByUUID(string(id)) } - if devID.IsGpuIndex() { - if idx, err := strconv.Atoi(id); err == nil { + if id.IsGpuIndex() { + if idx, err := strconv.Atoi(string(id)); err == nil { return l.nvmllib.DeviceGetHandleByIndex(idx) } return nil, fmt.Errorf("failed to convert device index to an int: %w", err) } - if devID.IsMigIndex() { + if id.IsMigIndex() { var gpuIdx, migIdx int var parent nvml.Device - split := strings.SplitN(id, ":", 2) + split := strings.SplitN(string(id), ":", 2) if gpuIdx, err = strconv.Atoi(split[0]); err != nil { return nil, fmt.Errorf("failed to convert device index to an int: %w", err) }