mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-22 08:18:32 +00:00
Accept device.Identifiers for requesting CDI specs
This change moves from using strings to useing device.Identifiers as input for requesting CDI specifications for specific devices. Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com> Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
a442a5ed1f
commit
35b23c5a2c
@ -80,7 +80,17 @@ func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
|
|||||||
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
|
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
|
||||||
// the provided identifiers, where an identifier is an index or UUID of a valid
|
// the provided identifiers, where an identifier is an index or UUID of a valid
|
||||||
// GPU device.
|
// GPU device.
|
||||||
func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, error) {
|
// Deprecated: Use GetDeviceSpecsBy instead.
|
||||||
|
func (l *nvmllib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
|
||||||
|
var identifiers []device.Identifier
|
||||||
|
for _, id := range ids {
|
||||||
|
identifiers = append(identifiers, device.Identifier(id))
|
||||||
|
}
|
||||||
|
return l.GetDeviceSpecsBy(identifiers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceSpecsBy is not supported for the gdslib specs.
|
||||||
|
func (l *nvmllib) GetDeviceSpecsBy(identifiers ...device.Identifier) ([]specs.Device, error) {
|
||||||
for _, id := range identifiers {
|
for _, id := range identifiers {
|
||||||
if id == "all" {
|
if id == "all" {
|
||||||
return l.GetAllDeviceSpecs()
|
return l.GetAllDeviceSpecs()
|
||||||
@ -109,7 +119,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
|
|||||||
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err)
|
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err)
|
||||||
}
|
}
|
||||||
deviceSpec := specs.Device{
|
deviceSpec := specs.Device{
|
||||||
Name: identifiers[i],
|
Name: string(identifiers[i]),
|
||||||
ContainerEdits: *deviceEdits.ContainerEdits,
|
ContainerEdits: *deviceEdits.ContainerEdits,
|
||||||
}
|
}
|
||||||
deviceSpecs = append(deviceSpecs, deviceSpec)
|
deviceSpecs = append(deviceSpecs, deviceSpec)
|
||||||
@ -119,7 +129,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: move this to go-nvlib?
|
// TODO: move this to go-nvlib?
|
||||||
func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, error) {
|
func (l *nvmllib) getNVMLDevicesByID(identifiers ...device.Identifier) ([]nvml.Device, error) {
|
||||||
var devices []nvml.Device
|
var devices []nvml.Device
|
||||||
for _, id := range identifiers {
|
for _, id := range identifiers {
|
||||||
dev, err := l.getNVMLDeviceByID(id)
|
dev, err := l.getNVMLDeviceByID(id)
|
||||||
@ -131,25 +141,24 @@ func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, erro
|
|||||||
return devices, nil
|
return devices, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *nvmllib) getNVMLDeviceByID(id string) (nvml.Device, error) {
|
func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) {
|
||||||
var err error
|
var err error
|
||||||
devID := device.Identifier(id)
|
|
||||||
|
|
||||||
if devID.IsUUID() {
|
if id.IsUUID() {
|
||||||
return l.nvmllib.DeviceGetHandleByUUID(id)
|
return l.nvmllib.DeviceGetHandleByUUID(string(id))
|
||||||
}
|
}
|
||||||
|
|
||||||
if devID.IsGpuIndex() {
|
if id.IsGpuIndex() {
|
||||||
if idx, err := strconv.Atoi(id); err == nil {
|
if idx, err := strconv.Atoi(string(id)); err == nil {
|
||||||
return l.nvmllib.DeviceGetHandleByIndex(idx)
|
return l.nvmllib.DeviceGetHandleByIndex(idx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
|
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if devID.IsMigIndex() {
|
if id.IsMigIndex() {
|
||||||
var gpuIdx, migIdx int
|
var gpuIdx, migIdx int
|
||||||
var parent nvml.Device
|
var parent nvml.Device
|
||||||
split := strings.SplitN(id, ":", 2)
|
split := strings.SplitN(string(id), ":", 2)
|
||||||
if gpuIdx, err = strconv.Atoi(split[0]); err != nil {
|
if gpuIdx, err = strconv.Atoi(split[0]); err != nil {
|
||||||
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
|
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user