Accept device.Identifiers for requesting CDI specs

This change moves from using strings to useing device.Identifiers
as input for requesting CDI specifications for specific
devices.

Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Christopher Desiniotis 2023-12-13 10:16:09 -08:00 committed by Evan Lezar
parent a442a5ed1f
commit 35b23c5a2c

View File

@ -80,7 +80,17 @@ func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by // GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
// the provided identifiers, where an identifier is an index or UUID of a valid // the provided identifiers, where an identifier is an index or UUID of a valid
// GPU device. // GPU device.
func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, error) { // Deprecated: Use GetDeviceSpecsBy instead.
func (l *nvmllib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
var identifiers []device.Identifier
for _, id := range ids {
identifiers = append(identifiers, device.Identifier(id))
}
return l.GetDeviceSpecsBy(identifiers...)
}
// GetDeviceSpecsBy is not supported for the gdslib specs.
func (l *nvmllib) GetDeviceSpecsBy(identifiers ...device.Identifier) ([]specs.Device, error) {
for _, id := range identifiers { for _, id := range identifiers {
if id == "all" { if id == "all" {
return l.GetAllDeviceSpecs() return l.GetAllDeviceSpecs()
@ -109,7 +119,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err) return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err)
} }
deviceSpec := specs.Device{ deviceSpec := specs.Device{
Name: identifiers[i], Name: string(identifiers[i]),
ContainerEdits: *deviceEdits.ContainerEdits, ContainerEdits: *deviceEdits.ContainerEdits,
} }
deviceSpecs = append(deviceSpecs, deviceSpec) deviceSpecs = append(deviceSpecs, deviceSpec)
@ -119,7 +129,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
} }
// TODO: move this to go-nvlib? // TODO: move this to go-nvlib?
func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, error) { func (l *nvmllib) getNVMLDevicesByID(identifiers ...device.Identifier) ([]nvml.Device, error) {
var devices []nvml.Device var devices []nvml.Device
for _, id := range identifiers { for _, id := range identifiers {
dev, err := l.getNVMLDeviceByID(id) dev, err := l.getNVMLDeviceByID(id)
@ -131,25 +141,24 @@ func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, erro
return devices, nil return devices, nil
} }
func (l *nvmllib) getNVMLDeviceByID(id string) (nvml.Device, error) { func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) {
var err error var err error
devID := device.Identifier(id)
if devID.IsUUID() { if id.IsUUID() {
return l.nvmllib.DeviceGetHandleByUUID(id) return l.nvmllib.DeviceGetHandleByUUID(string(id))
} }
if devID.IsGpuIndex() { if id.IsGpuIndex() {
if idx, err := strconv.Atoi(id); err == nil { if idx, err := strconv.Atoi(string(id)); err == nil {
return l.nvmllib.DeviceGetHandleByIndex(idx) return l.nvmllib.DeviceGetHandleByIndex(idx)
} }
return nil, fmt.Errorf("failed to convert device index to an int: %w", err) return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
} }
if devID.IsMigIndex() { if id.IsMigIndex() {
var gpuIdx, migIdx int var gpuIdx, migIdx int
var parent nvml.Device var parent nvml.Device
split := strings.SplitN(id, ":", 2) split := strings.SplitN(string(id), ":", 2)
if gpuIdx, err = strconv.Atoi(split[0]); err != nil { if gpuIdx, err = strconv.Atoi(split[0]); err != nil {
return nil, fmt.Errorf("failed to convert device index to an int: %w", err) return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
} }