Return nvcdi.spec.Interface from GetSpec

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-02-20 14:29:52 +02:00
parent 2e95e04359
commit 6d6cd56196
4 changed files with 41 additions and 4 deletions

View File

@ -17,6 +17,7 @@
package nvcdi package nvcdi
import ( import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -33,6 +34,7 @@ const (
// Interface defines the API for the nvcdi package // Interface defines the API for the nvcdi package
type Interface interface { type Interface interface {
GetSpec() (spec.Interface, error)
GetCommonEdits() (*cdi.ContainerEdits, error) GetCommonEdits() (*cdi.ContainerEdits, error)
GetAllDeviceSpecs() ([]specs.Device, error) GetAllDeviceSpecs() ([]specs.Device, error)
GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error)

View File

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -29,6 +30,11 @@ type nvmllib nvcdilib
var _ Interface = (*nvmllib)(nil) var _ Interface = (*nvmllib)(nil)
// GetSpec should not be called for nvmllib
func (l *nvmllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("Unexpected call to nvmllib.GetSpec()")
}
// GetAllDeviceSpecs returns the device specs for all available devices. // GetAllDeviceSpecs returns the device specs for all available devices.
func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) { func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
var deviceSpecs []specs.Device var deviceSpecs []specs.Device

View File

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/container-orchestrated-devices/container-device-interface/specs-go"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
@ -29,6 +30,11 @@ type wsllib nvcdilib
var _ Interface = (*wsllib)(nil) var _ Interface = (*wsllib)(nil)
// GetSpec should not be called for wsllib
func (l *wsllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("Unexpected call to wsllib.GetSpec()")
}
// GetAllDeviceSpecs returns the device specs for all available devices. // GetAllDeviceSpecs returns the device specs for all available devices.
func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) { func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) {
device := newDXGDeviceDiscoverer(l.logger, l.driverRoot) device := newDXGDeviceDiscoverer(l.logger, l.driverRoot)

View File

@ -17,12 +17,17 @@
package nvcdi package nvcdi
import ( import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
) )
type wrapper struct {
Interface
}
type nvcdilib struct { type nvcdilib struct {
logger *logrus.Logger logger *logrus.Logger
nvmllib nvml.Interface nvmllib nvml.Interface
@ -60,6 +65,7 @@ func New(opts ...Option) Interface {
l.infolib = info.New() l.infolib = info.New()
} }
var lib Interface
switch l.resolveMode() { switch l.resolveMode() {
case ModeNvml: case ModeNvml:
if l.nvmllib == nil { if l.nvmllib == nil {
@ -69,13 +75,30 @@ func New(opts ...Option) Interface {
l.devicelib = device.New(device.WithNvml(l.nvmllib)) l.devicelib = device.New(device.WithNvml(l.nvmllib))
} }
return (*nvmllib)(l) lib = (*nvmllib)(l)
case ModeWsl: case ModeWsl:
return (*wsllib)(l) lib = (*wsllib)(l)
default:
// TODO: We would like to return an error here instead of panicking
panic("Unknown mode")
} }
// TODO: We want an error here. return &wrapper{Interface: lib}
return nil }
// GetSpec combines the device specs and common edits from the wrapped Interface to a single spec.Interface.
func (l *wrapper) GetSpec() (spec.Interface, error) {
deviceSpecs, err := l.GetAllDeviceSpecs()
if err != nil {
return nil, err
}
edits, err := l.GetCommonEdits()
if err != nil {
return nil, err
}
return spec.New(deviceSpecs, *edits.ContainerEdits)
} }
// resolveMode resolves the mode for CDI spec generation based on the current system. // resolveMode resolves the mode for CDI spec generation based on the current system.