Return nvcdi.spec.Interface from GetSpec

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-02-20 14:29:52 +02:00
parent 2e95e04359
commit 6d6cd56196
4 changed files with 41 additions and 4 deletions

View File

@ -17,6 +17,7 @@
package nvcdi
import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -33,6 +34,7 @@ const (
// Interface defines the API for the nvcdi package
type Interface interface {
GetSpec() (spec.Interface, error)
GetCommonEdits() (*cdi.ContainerEdits, error)
GetAllDeviceSpecs() ([]specs.Device, error)
GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error)

View File

@ -20,6 +20,7 @@ import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -29,6 +30,11 @@ type nvmllib nvcdilib
var _ Interface = (*nvmllib)(nil)
// GetSpec should not be called for nvmllib
func (l *nvmllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("Unexpected call to nvmllib.GetSpec()")
}
// GetAllDeviceSpecs returns the device specs for all available devices.
func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
var deviceSpecs []specs.Device

View File

@ -20,6 +20,7 @@ import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -29,6 +30,11 @@ type wsllib nvcdilib
var _ Interface = (*wsllib)(nil)
// GetSpec should not be called for wsllib
func (l *wsllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("Unexpected call to wsllib.GetSpec()")
}
// GetAllDeviceSpecs returns the device specs for all available devices.
func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) {
device := newDXGDeviceDiscoverer(l.logger, l.driverRoot)

View File

@ -17,12 +17,17 @@
package nvcdi
import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
type wrapper struct {
Interface
}
type nvcdilib struct {
logger *logrus.Logger
nvmllib nvml.Interface
@ -60,6 +65,7 @@ func New(opts ...Option) Interface {
l.infolib = info.New()
}
var lib Interface
switch l.resolveMode() {
case ModeNvml:
if l.nvmllib == nil {
@ -69,13 +75,30 @@ func New(opts ...Option) Interface {
l.devicelib = device.New(device.WithNvml(l.nvmllib))
}
return (*nvmllib)(l)
lib = (*nvmllib)(l)
case ModeWsl:
return (*wsllib)(l)
lib = (*wsllib)(l)
default:
// TODO: We would like to return an error here instead of panicking
panic("Unknown mode")
}
// TODO: We want an error here.
return nil
return &wrapper{Interface: lib}
}
// GetSpec combines the device specs and common edits from the wrapped Interface to a single spec.Interface.
func (l *wrapper) GetSpec() (spec.Interface, error) {
deviceSpecs, err := l.GetAllDeviceSpecs()
if err != nil {
return nil, err
}
edits, err := l.GetCommonEdits()
if err != nil {
return nil, err
}
return spec.New(deviceSpecs, *edits.ContainerEdits)
}
// resolveMode resolves the mode for CDI spec generation based on the current system.