diff --git a/pkg/nvcdi/common.go b/pkg/nvcdi/common-nvml.go similarity index 68% rename from pkg/nvcdi/common.go rename to pkg/nvcdi/common-nvml.go index 1d04d420..df81fc29 100644 --- a/pkg/nvcdi/common.go +++ b/pkg/nvcdi/common-nvml.go @@ -20,27 +20,15 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" - "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/sirupsen/logrus" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) -// GetCommonEdits generates a CDI specification that can be used for ANY devices -func (l *nvcdilib) GetCommonEdits() (*cdi.ContainerEdits, error) { - common, err := newCommonDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib) - if err != nil { - return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err) - } - - return edits.FromDiscoverer(common) -} - -// newCommonDiscoverer returns a discoverer for entities that are not associated with a specific CDI device. +// newCommonNVMLDiscoverer returns a discoverer for entities that are not associated with a specific CDI device. // This includes driver libraries and meta devices, for example. -func newCommonDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { +func newCommonNVMLDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { metaDevices := discover.NewDeviceDiscoverer( logger, lookup.NewCharDeviceLocator( diff --git a/pkg/nvcdi/driver.go b/pkg/nvcdi/driver-nvml.go similarity index 100% rename from pkg/nvcdi/driver.go rename to pkg/nvcdi/driver-nvml.go diff --git a/pkg/nvcdi/full-gpu.go b/pkg/nvcdi/full-gpu-nvml.go similarity index 97% rename from pkg/nvcdi/full-gpu.go rename to pkg/nvcdi/full-gpu-nvml.go index 7e61477c..9dc6780e 100644 --- a/pkg/nvcdi/full-gpu.go +++ b/pkg/nvcdi/full-gpu-nvml.go @@ -33,7 +33,7 @@ import ( ) // GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'. -func (l *nvcdilib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) { +func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) { edits, err := l.GetGPUDeviceEdits(d) if err != nil { return nil, fmt.Errorf("failed to get edits for device: %v", err) @@ -53,7 +53,7 @@ func (l *nvcdilib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, err } // GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'. -func (l *nvcdilib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) { +func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) { device, err := newFullGPUDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, d) if err != nil { return nil, fmt.Errorf("failed to create device discoverer: %v", err) diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go new file mode 100644 index 00000000..aaca382e --- /dev/null +++ b/pkg/nvcdi/lib-nvml.go @@ -0,0 +1,93 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/container-orchestrated-devices/container-device-interface/specs-go" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" +) + +type nvmllib nvcdilib + +var _ Interface = (*nvmllib)(nil) + +// GetAllDeviceSpecs returns the device specs for all available devices. +func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) { + var deviceSpecs []specs.Device + + gpuDeviceSpecs, err := l.getGPUDeviceSpecs() + if err != nil { + return nil, err + } + deviceSpecs = append(deviceSpecs, gpuDeviceSpecs...) + + migDeviceSpecs, err := l.getMigDeviceSpecs() + if err != nil { + return nil, err + } + deviceSpecs = append(deviceSpecs, migDeviceSpecs...) + + return deviceSpecs, nil +} + +// GetCommonEdits generates a CDI specification that can be used for ANY devices +func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) { + common, err := newCommonNVMLDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib) + if err != nil { + return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err) + } + + return edits.FromDiscoverer(common) +} + +func (l *nvmllib) getGPUDeviceSpecs() ([]specs.Device, error) { + var deviceSpecs []specs.Device + err := l.devicelib.VisitDevices(func(i int, d device.Device) error { + deviceSpec, err := l.GetGPUDeviceSpecs(i, d) + if err != nil { + return err + } + deviceSpecs = append(deviceSpecs, *deviceSpec) + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err) + } + return deviceSpecs, err +} + +func (l *nvmllib) getMigDeviceSpecs() ([]specs.Device, error) { + var deviceSpecs []specs.Device + err := l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error { + deviceSpec, err := l.GetMIGDeviceSpecs(i, d, j, mig) + if err != nil { + return err + } + deviceSpecs = append(deviceSpecs, *deviceSpec) + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err) + } + return deviceSpecs, err +} diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 592ff186..985e6850 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -17,9 +17,6 @@ package nvcdi import ( - "fmt" - - "github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/sirupsen/logrus" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" @@ -41,12 +38,8 @@ func New(opts ...Option) Interface { for _, opt := range opts { opt(l) } - - if l.nvmllib == nil { - l.nvmllib = nvml.New() - } - if l.devicelib == nil { - l.devicelib = device.New(device.WithNvml(l.nvmllib)) + if l.mode == "" { + l.mode = "nvml" } if l.logger == nil { l.logger = logrus.StandardLogger() @@ -61,58 +54,18 @@ func New(opts ...Option) Interface { l.nvidiaCTKPath = "/usr/bin/nvidia-ctk" } - return l -} - -// GetAllDeviceSpecs returns the device specs for all available devices. -func (l *nvcdilib) GetAllDeviceSpecs() ([]specs.Device, error) { - var deviceSpecs []specs.Device - - gpuDeviceSpecs, err := l.getGPUDeviceSpecs() - if err != nil { - return nil, err - } - deviceSpecs = append(deviceSpecs, gpuDeviceSpecs...) - - migDeviceSpecs, err := l.getMigDeviceSpecs() - if err != nil { - return nil, err - } - deviceSpecs = append(deviceSpecs, migDeviceSpecs...) - - return deviceSpecs, nil -} - -func (l *nvcdilib) getGPUDeviceSpecs() ([]specs.Device, error) { - var deviceSpecs []specs.Device - err := l.devicelib.VisitDevices(func(i int, d device.Device) error { - deviceSpec, err := l.GetGPUDeviceSpecs(i, d) - if err != nil { - return err + switch l.mode { + case "nvml": + if l.nvmllib == nil { + l.nvmllib = nvml.New() } - deviceSpecs = append(deviceSpecs, *deviceSpec) - - return nil - }) - if err != nil { - return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err) - } - return deviceSpecs, err -} - -func (l *nvcdilib) getMigDeviceSpecs() ([]specs.Device, error) { - var deviceSpecs []specs.Device - err := l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error { - deviceSpec, err := l.GetMIGDeviceSpecs(i, d, j, mig) - if err != nil { - return err + if l.devicelib == nil { + l.devicelib = device.New(device.WithNvml(l.nvmllib)) } - deviceSpecs = append(deviceSpecs, *deviceSpec) - return nil - }) - if err != nil { - return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err) + return (*nvmllib)(l) } - return deviceSpecs, err + + // TODO: We want an error here. + return nil } diff --git a/pkg/nvcdi/mig-device.go b/pkg/nvcdi/mig-device-nvml.go similarity index 94% rename from pkg/nvcdi/mig-device.go rename to pkg/nvcdi/mig-device-nvml.go index 3d0a91f2..7864ff91 100644 --- a/pkg/nvcdi/mig-device.go +++ b/pkg/nvcdi/mig-device-nvml.go @@ -30,7 +30,7 @@ import ( ) // GetMIGDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'. -func (l *nvcdilib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.MigDevice) (*specs.Device, error) { +func (l *nvmllib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.MigDevice) (*specs.Device, error) { edits, err := l.GetMIGDeviceEdits(d, mig) if err != nil { return nil, fmt.Errorf("failed to get edits for device: %v", err) @@ -50,7 +50,7 @@ func (l *nvcdilib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.M } // GetMIGDeviceEdits returns the CDI edits for the MIG device represented by 'mig' on 'parent'. -func (l *nvcdilib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice) (*cdi.ContainerEdits, error) { +func (l *nvmllib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice) (*cdi.ContainerEdits, error) { gpu, ret := parent.GetMinorNumber() if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting GPU minor: %v", ret)