diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go index 267010e9..85bace99 100644 --- a/pkg/nvcdi/api.go +++ b/pkg/nvcdi/api.go @@ -17,6 +17,7 @@ package nvcdi import ( + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/specs-go" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" @@ -33,6 +34,7 @@ const ( // Interface defines the API for the nvcdi package type Interface interface { + GetSpec() (spec.Interface, error) GetCommonEdits() (*cdi.ContainerEdits, error) GetAllDeviceSpecs() ([]specs.Device, error) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go index aaca382e..8fa29c11 100644 --- a/pkg/nvcdi/lib-nvml.go +++ b/pkg/nvcdi/lib-nvml.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/specs-go" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" @@ -29,6 +30,11 @@ type nvmllib nvcdilib var _ Interface = (*nvmllib)(nil) +// GetSpec should not be called for nvmllib +func (l *nvmllib) GetSpec() (spec.Interface, error) { + return nil, fmt.Errorf("Unexpected call to nvmllib.GetSpec()") +} + // GetAllDeviceSpecs returns the device specs for all available devices. func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) { var deviceSpecs []specs.Device diff --git a/pkg/nvcdi/lib-wsl.go b/pkg/nvcdi/lib-wsl.go index d901995c..937d1d33 100644 --- a/pkg/nvcdi/lib-wsl.go +++ b/pkg/nvcdi/lib-wsl.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/specs-go" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" @@ -29,6 +30,11 @@ type wsllib nvcdilib var _ Interface = (*wsllib)(nil) +// GetSpec should not be called for wsllib +func (l *wsllib) GetSpec() (spec.Interface, error) { + return nil, fmt.Errorf("Unexpected call to wsllib.GetSpec()") +} + // GetAllDeviceSpecs returns the device specs for all available devices. func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) { device := newDXGDeviceDiscoverer(l.logger, l.driverRoot) diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index be5554eb..aa90e396 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -17,12 +17,17 @@ package nvcdi import ( + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/sirupsen/logrus" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) +type wrapper struct { + Interface +} + type nvcdilib struct { logger *logrus.Logger nvmllib nvml.Interface @@ -60,6 +65,7 @@ func New(opts ...Option) Interface { l.infolib = info.New() } + var lib Interface switch l.resolveMode() { case ModeNvml: if l.nvmllib == nil { @@ -69,13 +75,30 @@ func New(opts ...Option) Interface { l.devicelib = device.New(device.WithNvml(l.nvmllib)) } - return (*nvmllib)(l) + lib = (*nvmllib)(l) case ModeWsl: - return (*wsllib)(l) + lib = (*wsllib)(l) + default: + // TODO: We would like to return an error here instead of panicking + panic("Unknown mode") } - // TODO: We want an error here. - return nil + return &wrapper{Interface: lib} +} + +// GetSpec combines the device specs and common edits from the wrapped Interface to a single spec.Interface. +func (l *wrapper) GetSpec() (spec.Interface, error) { + deviceSpecs, err := l.GetAllDeviceSpecs() + if err != nil { + return nil, err + } + + edits, err := l.GetCommonEdits() + if err != nil { + return nil, err + } + + return spec.New(deviceSpecs, *edits.ContainerEdits) } // resolveMode resolves the mode for CDI spec generation based on the current system.