mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-01-22 18:47:32 +00:00
Construct nvml-based CDI lib based on mode
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
20d6e9af04
commit
d226925fe7
@ -20,27 +20,15 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||
)
|
||||
|
||||
// GetCommonEdits generates a CDI specification that can be used for ANY devices
|
||||
func (l *nvcdilib) GetCommonEdits() (*cdi.ContainerEdits, error) {
|
||||
common, err := newCommonDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err)
|
||||
}
|
||||
|
||||
return edits.FromDiscoverer(common)
|
||||
}
|
||||
|
||||
// newCommonDiscoverer returns a discoverer for entities that are not associated with a specific CDI device.
|
||||
// newCommonNVMLDiscoverer returns a discoverer for entities that are not associated with a specific CDI device.
|
||||
// This includes driver libraries and meta devices, for example.
|
||||
func newCommonDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) {
|
||||
func newCommonNVMLDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) {
|
||||
metaDevices := discover.NewDeviceDiscoverer(
|
||||
logger,
|
||||
lookup.NewCharDeviceLocator(
|
@ -33,7 +33,7 @@ import (
|
||||
)
|
||||
|
||||
// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
|
||||
func (l *nvcdilib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) {
|
||||
func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) {
|
||||
edits, err := l.GetGPUDeviceEdits(d)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get edits for device: %v", err)
|
||||
@ -53,7 +53,7 @@ func (l *nvcdilib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, err
|
||||
}
|
||||
|
||||
// GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'.
|
||||
func (l *nvcdilib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) {
|
||||
func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) {
|
||||
device, err := newFullGPUDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, d)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
|
93
pkg/nvcdi/lib-nvml.go
Normal file
93
pkg/nvcdi/lib-nvml.go
Normal file
@ -0,0 +1,93 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package nvcdi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
|
||||
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
|
||||
)
|
||||
|
||||
type nvmllib nvcdilib
|
||||
|
||||
var _ Interface = (*nvmllib)(nil)
|
||||
|
||||
// GetAllDeviceSpecs returns the device specs for all available devices.
|
||||
func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
|
||||
gpuDeviceSpecs, err := l.getGPUDeviceSpecs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, gpuDeviceSpecs...)
|
||||
|
||||
migDeviceSpecs, err := l.getMigDeviceSpecs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, migDeviceSpecs...)
|
||||
|
||||
return deviceSpecs, nil
|
||||
}
|
||||
|
||||
// GetCommonEdits generates a CDI specification that can be used for ANY devices
|
||||
func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
|
||||
common, err := newCommonNVMLDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err)
|
||||
}
|
||||
|
||||
return edits.FromDiscoverer(common)
|
||||
}
|
||||
|
||||
func (l *nvmllib) getGPUDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
err := l.devicelib.VisitDevices(func(i int, d device.Device) error {
|
||||
deviceSpec, err := l.GetGPUDeviceSpecs(i, d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, *deviceSpec)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
|
||||
}
|
||||
return deviceSpecs, err
|
||||
}
|
||||
|
||||
func (l *nvmllib) getMigDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
err := l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error {
|
||||
deviceSpec, err := l.GetMIGDeviceSpecs(i, d, j, mig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, *deviceSpec)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
|
||||
}
|
||||
return deviceSpecs, err
|
||||
}
|
@ -17,9 +17,6 @@
|
||||
package nvcdi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||
@ -41,12 +38,8 @@ func New(opts ...Option) Interface {
|
||||
for _, opt := range opts {
|
||||
opt(l)
|
||||
}
|
||||
|
||||
if l.nvmllib == nil {
|
||||
l.nvmllib = nvml.New()
|
||||
}
|
||||
if l.devicelib == nil {
|
||||
l.devicelib = device.New(device.WithNvml(l.nvmllib))
|
||||
if l.mode == "" {
|
||||
l.mode = "nvml"
|
||||
}
|
||||
if l.logger == nil {
|
||||
l.logger = logrus.StandardLogger()
|
||||
@ -61,58 +54,18 @@ func New(opts ...Option) Interface {
|
||||
l.nvidiaCTKPath = "/usr/bin/nvidia-ctk"
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// GetAllDeviceSpecs returns the device specs for all available devices.
|
||||
func (l *nvcdilib) GetAllDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
|
||||
gpuDeviceSpecs, err := l.getGPUDeviceSpecs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, gpuDeviceSpecs...)
|
||||
|
||||
migDeviceSpecs, err := l.getMigDeviceSpecs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, migDeviceSpecs...)
|
||||
|
||||
return deviceSpecs, nil
|
||||
}
|
||||
|
||||
func (l *nvcdilib) getGPUDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
err := l.devicelib.VisitDevices(func(i int, d device.Device) error {
|
||||
deviceSpec, err := l.GetGPUDeviceSpecs(i, d)
|
||||
if err != nil {
|
||||
return err
|
||||
switch l.mode {
|
||||
case "nvml":
|
||||
if l.nvmllib == nil {
|
||||
l.nvmllib = nvml.New()
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, *deviceSpec)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
|
||||
}
|
||||
return deviceSpecs, err
|
||||
}
|
||||
|
||||
func (l *nvcdilib) getMigDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
err := l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error {
|
||||
deviceSpec, err := l.GetMIGDeviceSpecs(i, d, j, mig)
|
||||
if err != nil {
|
||||
return err
|
||||
if l.devicelib == nil {
|
||||
l.devicelib = device.New(device.WithNvml(l.nvmllib))
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, *deviceSpec)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
|
||||
return (*nvmllib)(l)
|
||||
}
|
||||
return deviceSpecs, err
|
||||
|
||||
// TODO: We want an error here.
|
||||
return nil
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ import (
|
||||
)
|
||||
|
||||
// GetMIGDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
|
||||
func (l *nvcdilib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.MigDevice) (*specs.Device, error) {
|
||||
func (l *nvmllib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.MigDevice) (*specs.Device, error) {
|
||||
edits, err := l.GetMIGDeviceEdits(d, mig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get edits for device: %v", err)
|
||||
@ -50,7 +50,7 @@ func (l *nvcdilib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.M
|
||||
}
|
||||
|
||||
// GetMIGDeviceEdits returns the CDI edits for the MIG device represented by 'mig' on 'parent'.
|
||||
func (l *nvcdilib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice) (*cdi.ContainerEdits, error) {
|
||||
func (l *nvmllib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice) (*cdi.ContainerEdits, error) {
|
||||
gpu, ret := parent.GetMinorNumber()
|
||||
if ret != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting GPU minor: %v", ret)
|
Loading…
Reference in New Issue
Block a user