diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 855277c7..4a1079ac 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -18,13 +18,9 @@ package modifier import ( "fmt" - "strconv" "strings" - nvdevice "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" - "github.com/NVIDIA/go-nvlib/pkg/nvml" "tags.cncf.io/container-device-interface/pkg/parser" - "tags.cncf.io/container-device-interface/specs-go" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" @@ -193,56 +189,15 @@ func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devic return nil, fmt.Errorf("failed to construct CDI library: %w", err) } - names := []string{} + identifiers := []string{} for _, device := range devices { - _, _, name := parser.ParseDevice(device) - if name == "all" { - return cdilib.GetSpec() - } - names = append(names, name) + _, _, id := parser.ParseDevice(device) + identifiers = append(identifiers, id) } - // Note: The below code currently only supports generating CDI spec modifications - // for full-GPUs, specified either by index or UUID. MIG devices are not - // supported. - nvmlLib := nvml.New() - ret := nvmlLib.Init() - if ret != nvml.SUCCESS { - return nil, fmt.Errorf("failed to initialized NVML: %w", ret) - } - nvdevice := nvdevice.New(nvdevice.WithNvml(nvmlLib)) - - deviceSpecs := []specs.Device{} - for _, name := range names { - logger.Debugf("Getting CDI spec edits for device %q", name) - // Get a device handle by either index or UUID - var nvmlDevice nvml.Device - if idx, err := strconv.Atoi(name); err == nil { - nvmlDevice, err = nvmlLib.DeviceGetHandleByIndex(idx) - if err != nvml.SUCCESS { - return nil, fmt.Errorf("failed to get device handle for index '%v': %w", idx, err) - } - } else { - nvmlDevice, err = nvmlLib.DeviceGetHandleByUUID(name) - if err != nvml.SUCCESS { - return nil, fmt.Errorf("failed to get device handle for UUID '%v': %w", name, err) - } - } - - nvlibDevice, err := nvdevice.NewDevice(nvmlDevice) - if err != nil { - return nil, fmt.Errorf("failed to construct device: %w", err) - } - - gpuEdits, err := cdilib.GetGPUDeviceEdits(nvlibDevice) - if err != nil { - return nil, fmt.Errorf("failed to get CDI spec edits for GPU %q: %w", name, err) - } - gpuDevice := specs.Device{ - Name: name, - ContainerEdits: *gpuEdits.ContainerEdits, - } - deviceSpecs = append(deviceSpecs, gpuDevice) + deviceSpecs, err := cdilib.GetDeviceSpecsByID(identifiers...) + if err != nil { + return nil, fmt.Errorf("failed to get CDI device specs: %w", err) } commonEdits, err := cdilib.GetCommonEdits() diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go index 43aad634..27c264de 100644 --- a/pkg/nvcdi/api.go +++ b/pkg/nvcdi/api.go @@ -51,4 +51,5 @@ type Interface interface { GetGPUDeviceSpecs(int, device.Device) (*specs.Device, error) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) (*specs.Device, error) + GetDeviceSpecsByID(...string) ([]specs.Device, error) } diff --git a/pkg/nvcdi/gds.go b/pkg/nvcdi/gds.go index cb1bf760..74a186c1 100644 --- a/pkg/nvcdi/gds.go +++ b/pkg/nvcdi/gds.go @@ -81,3 +81,10 @@ func (l *gdslib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.Contai func (l *gdslib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) (*specs.Device, error) { return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported") } + +// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by +// the provided identifiers, where an identifier is an index or UUID of a valid +// GPU device. +func (l *gdslib) GetDeviceSpecsByID(...string) ([]specs.Device, error) { + return nil, fmt.Errorf("GetDeviceSpecsByID is not supported") +} diff --git a/pkg/nvcdi/lib-csv.go b/pkg/nvcdi/lib-csv.go index 86d86f93..31604345 100644 --- a/pkg/nvcdi/lib-csv.go +++ b/pkg/nvcdi/lib-csv.go @@ -94,3 +94,10 @@ func (l *csvlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.Contai func (l *csvlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) (*specs.Device, error) { return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported for CSV files") } + +// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by +// the provided identifiers, where an identifier is an index or UUID of a valid +// GPU device. +func (l *csvlib) GetDeviceSpecsByID(...string) ([]specs.Device, error) { + return nil, fmt.Errorf("GetDeviceSpecsByID is not supported for CSV files") +} diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go index 3ce68ec9..41ae23ac 100644 --- a/pkg/nvcdi/lib-nvml.go +++ b/pkg/nvcdi/lib-nvml.go @@ -18,6 +18,7 @@ package nvcdi import ( "fmt" + "strconv" "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvlib/pkg/nvml" @@ -75,6 +76,72 @@ func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) { return edits.FromDiscoverer(common) } +// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by +// the provided identifiers, where an identifier is an index or UUID of a valid +// GPU device. +// TODO: support identifiers that correspond to MIG devices +func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, error) { + for _, id := range identifiers { + if id == "all" { + return l.GetAllDeviceSpecs() + } + } + + var deviceSpecs []specs.Device + + if r := l.nvmllib.Init(); r != nvml.SUCCESS { + return nil, fmt.Errorf("failed to initialize NVML: %w", r) + } + defer func() { + if r := l.nvmllib.Shutdown(); r != nvml.SUCCESS { + l.logger.Warningf("failed to shutdown NVML: %w", r) + } + }() + + nvmlDevices, err := l.getNVMLDevicesByID(identifiers...) + if err != nil { + return nil, fmt.Errorf("failed to get NVML device handles: %w", err) + } + + for i, nvmlDevice := range nvmlDevices { + nvlibDevice, err := l.devicelib.NewDevice(nvmlDevice) + if err != nil { + return nil, fmt.Errorf("failed to construct device: %w", err) + } + deviceEdits, err := l.GetGPUDeviceEdits(nvlibDevice) + if err != nil { + return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err) + } + deviceSpec := specs.Device{ + Name: identifiers[i], + ContainerEdits: *deviceEdits.ContainerEdits, + } + deviceSpecs = append(deviceSpecs, deviceSpec) + } + + return deviceSpecs, nil +} + +// TODO: move this to go-nvlib? +func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, error) { + devices := []nvml.Device{} + for _, id := range identifiers { + if dev, err := l.nvmllib.DeviceGetHandleByUUID(id); err == nvml.SUCCESS { + devices = append(devices, dev) + continue + } + // TODO: check for a MIG device index + if idx, err := strconv.Atoi(id); err == nil { + if dev, err := l.nvmllib.DeviceGetHandleByIndex(idx); err == nvml.SUCCESS { + devices = append(devices, dev) + continue + } + } + return nil, fmt.Errorf("failed to get NVML device handle for identifier %q", id) + } + return devices, nil +} + func (l *nvmllib) getGPUDeviceSpecs() ([]specs.Device, error) { var deviceSpecs []specs.Device err := l.devicelib.VisitDevices(func(i int, d device.Device) error { diff --git a/pkg/nvcdi/lib-wsl.go b/pkg/nvcdi/lib-wsl.go index b01c8268..385007cf 100644 --- a/pkg/nvcdi/lib-wsl.go +++ b/pkg/nvcdi/lib-wsl.go @@ -81,3 +81,10 @@ func (l *wsllib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.Contai func (l *wsllib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) (*specs.Device, error) { return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported on WSL") } + +// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by +// the provided identifiers, where an identifier is an index or UUID of a valid +// GPU device. +func (l *wsllib) GetDeviceSpecsByID(...string) ([]specs.Device, error) { + return nil, fmt.Errorf("GetDeviceSpecsByID is not supported on WSL") +} diff --git a/pkg/nvcdi/management.go b/pkg/nvcdi/management.go index 36b4b27b..8c3d4b32 100644 --- a/pkg/nvcdi/management.go +++ b/pkg/nvcdi/management.go @@ -188,3 +188,10 @@ func (m *managementlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi func (m *managementlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) (*specs.Device, error) { return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported") } + +// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by +// the provided identifiers, where an identifier is an index or UUID of a valid +// GPU device. +func (l *managementlib) GetDeviceSpecsByID(...string) ([]specs.Device, error) { + return nil, fmt.Errorf("GetDeviceSpecsByID is not supported") +} diff --git a/pkg/nvcdi/mofed.go b/pkg/nvcdi/mofed.go index 3f56b2d5..607b7baf 100644 --- a/pkg/nvcdi/mofed.go +++ b/pkg/nvcdi/mofed.go @@ -81,3 +81,10 @@ func (l *mofedlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.Cont func (l *mofedlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) (*specs.Device, error) { return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported") } + +// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by +// the provided identifiers, where an identifier is an index or UUID of a valid +// GPU device. +func (l *mofedlib) GetDeviceSpecsByID(...string) ([]specs.Device, error) { + return nil, fmt.Errorf("GetDeviceSpecsByID is not supported") +}