mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-06-26 18:18:24 +00:00
51
internal/discover/binaries.go
Normal file
51
internal/discover/binaries.go
Normal file
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewBinaryMounts creates a discoverer for binaries using the specified root
|
||||
func NewBinaryMounts(root string) Discover {
|
||||
return NewBinaryMountsWithLogger(log.StandardLogger(), root)
|
||||
}
|
||||
|
||||
// NewBinaryMountsWithLogger creates a Mounts discoverer as with NewBinaryMounts
|
||||
// with the specified logger
|
||||
func NewBinaryMountsWithLogger(logger *log.Logger, root string) Discover {
|
||||
d := mounts{
|
||||
logger: logger,
|
||||
lookup: lookup.NewPathLocatorWithLogger(logger, root),
|
||||
required: requiredBinaries,
|
||||
}
|
||||
return &d
|
||||
}
|
||||
|
||||
// requiredBinaries defines a set of binaries and their labels
|
||||
var requiredBinaries = map[string][]string{
|
||||
"utility": {
|
||||
"nvidia-smi", /* System management interface */
|
||||
"nvidia-debugdump", /* GPU coredump utility */
|
||||
"nvidia-persistenced", /* Persistence mode utility */
|
||||
},
|
||||
"compute": {
|
||||
"nvidia-cuda-mps-control", /* Multi process service CLI */
|
||||
"nvidia-cuda-mps-server", /* Multi process service server */
|
||||
},
|
||||
}
|
||||
73
internal/discover/binaries_test.go
Normal file
73
internal/discover/binaries_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBinaries(t *testing.T) {
|
||||
binaryLookup := map[string]string{
|
||||
"nvidia-smi": "/usr/bin/nvidia-smi",
|
||||
"nvidia-persistenced": "/usr/bin/nvidia-persistenced",
|
||||
"nvidia-debugdump": "test-duplicates",
|
||||
"nvidia-cuda-mps-control": "test-duplicates",
|
||||
}
|
||||
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
d := NewBinaryMountsWithLogger(logger, "")
|
||||
|
||||
// Override lookup for testing
|
||||
mockLocator := NewLocatorMockFromMap(binaryLookup)
|
||||
d.(*mounts).lookup = mockLocator
|
||||
|
||||
mounts, err := d.Mounts()
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := []Mount{
|
||||
{
|
||||
Path: "/usr/bin/nvidia-smi",
|
||||
},
|
||||
{
|
||||
Path: "/usr/bin/nvidia-persistenced",
|
||||
},
|
||||
{
|
||||
Path: "test-duplicates",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, len(expected), len(mounts))
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
|
||||
func TestNewBinariesConstructor(t *testing.T) {
|
||||
b := NewBinaryMounts("").(*mounts)
|
||||
|
||||
require.NotNil(t, b.logger)
|
||||
require.NotNil(t, b.lookup)
|
||||
}
|
||||
74
internal/discover/composite.go
Normal file
74
internal/discover/composite.go
Normal file
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import "fmt"
|
||||
|
||||
// composite is a discoverer that contains a list of Discoverers. The output of the
|
||||
// Devices, Mounts, and Hooks functions is the concatenation of the output for each of the
|
||||
// elements in the list.
|
||||
type composite struct {
|
||||
discoverers []Discover
|
||||
}
|
||||
|
||||
var _ Discover = (*composite)(nil)
|
||||
|
||||
func (d composite) Devices() ([]Device, error) {
|
||||
var allDevices []Device
|
||||
|
||||
for i, di := range d.discoverers {
|
||||
devices, err := di.Devices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering devices for discoverer %v: %v", i, err)
|
||||
}
|
||||
allDevices = append(allDevices, devices...)
|
||||
}
|
||||
|
||||
return allDevices, nil
|
||||
}
|
||||
|
||||
func (d composite) Mounts() ([]Mount, error) {
|
||||
var allMounts []Mount
|
||||
|
||||
for i, di := range d.discoverers {
|
||||
mounts, err := di.Mounts()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering mounts for discoverer %v: %v", i, err)
|
||||
}
|
||||
allMounts = append(allMounts, mounts...)
|
||||
}
|
||||
|
||||
return allMounts, nil
|
||||
}
|
||||
|
||||
func (d composite) Hooks() ([]Hook, error) {
|
||||
var allHooks []Hook
|
||||
|
||||
for i, di := range d.discoverers {
|
||||
hooks, err := di.Hooks()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering hooks for discoverer %v: %v", i, err)
|
||||
}
|
||||
allHooks = append(allHooks, hooks...)
|
||||
}
|
||||
|
||||
return allHooks, nil
|
||||
}
|
||||
|
||||
func (d *composite) add(di ...Discover) {
|
||||
d.discoverers = append(d.discoverers, di...)
|
||||
}
|
||||
65
internal/discover/discover.go
Normal file
65
internal/discover/discover.go
Normal file
@@ -0,0 +1,65 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
// DevicePath is a path in /dev associated with a device
|
||||
type DevicePath string
|
||||
|
||||
// ProcPath is a path in /proc associated with a devices
|
||||
type ProcPath string
|
||||
|
||||
// PCIBusID is the ID on the PCI bus of a device
|
||||
type PCIBusID string
|
||||
|
||||
// DeviceNode represents a device on the file system
|
||||
type DeviceNode struct {
|
||||
Path DevicePath
|
||||
Major int
|
||||
Minor int
|
||||
}
|
||||
|
||||
// Device represents a discovered device including identifiers (Index, UUID, PCI bus ID)
|
||||
// for selection and paths in /dev and /proc associated with the device
|
||||
type Device struct {
|
||||
Index string
|
||||
UUID string
|
||||
PCIBusID PCIBusID
|
||||
DeviceNodes []DeviceNode
|
||||
ProcPaths []ProcPath
|
||||
}
|
||||
|
||||
// Mount represents a discovered mount. This includes a set of labels
|
||||
// for selection and the mount path
|
||||
type Mount struct {
|
||||
Path string
|
||||
Labels map[string]string
|
||||
}
|
||||
|
||||
// Hook represents a discovered hook
|
||||
type Hook struct {
|
||||
Path string
|
||||
Args []string
|
||||
HookName string
|
||||
Labels map[string]string
|
||||
}
|
||||
|
||||
// Discover defines an interface for discovering the devices and mounts available on a system
|
||||
type Discover interface {
|
||||
Devices() ([]Device, error)
|
||||
Mounts() ([]Mount, error)
|
||||
Hooks() ([]Hook, error)
|
||||
}
|
||||
424
internal/discover/discover_nvml.go
Normal file
424
internal/discover/discover_nvml.go
Normal file
@@ -0,0 +1,424 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvml"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/proc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// ControlDeviceUUID is used as the UUID for control devices such as nvidiactl or nvidia-modeset
|
||||
ControlDeviceUUID = "CONTROL"
|
||||
|
||||
// MIGConfigDeviceUUID is used to indicate the MIG config control device
|
||||
MIGConfigDeviceUUID = "CONFIG"
|
||||
|
||||
// MIGMonitorDeviceUUID is used to indicate the MIG monitor control device
|
||||
MIGMonitorDeviceUUID = "MONITOR"
|
||||
|
||||
nvidiaGPUDeviceName = "nvidia-frontend"
|
||||
nvidiaCapsDeviceName = "nvidia-caps"
|
||||
nvidiaUVMDeviceName = "nvidia-uvm"
|
||||
)
|
||||
|
||||
type nvmlDiscover struct {
|
||||
None
|
||||
logger *log.Logger
|
||||
nvml nvml.Interface
|
||||
migCaps map[ProcPath]DeviceNode
|
||||
nvidiaDevices proc.NvidiaDevices
|
||||
}
|
||||
|
||||
var _ Discover = (*nvmlDiscover)(nil)
|
||||
|
||||
// NewNVMLDiscover constructs a discoverer that uses NVML to find the devices
|
||||
// available on a system.
|
||||
func NewNVMLDiscover(nvml nvml.Interface) (Discover, error) {
|
||||
return NewNVMLDiscoverWithLogger(log.StandardLogger(), nvml)
|
||||
}
|
||||
|
||||
// NewNVMLDiscoverWithLogger constructs a discovered as with NewNVMLDiscover with the specified
|
||||
// logger
|
||||
func NewNVMLDiscoverWithLogger(logger *log.Logger, nvml nvml.Interface) (Discover, error) {
|
||||
nvidiaDevices, err := proc.GetNvidiaDevices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading NVIDIA devices: %v", err)
|
||||
}
|
||||
|
||||
var migCaps map[ProcPath]DeviceNode
|
||||
nvcapsDevice, exists := nvidiaDevices.Get(nvidiaCapsDeviceName)
|
||||
if !exists {
|
||||
logger.Warnf("%v nvcaps device could not be found", nvidiaCapsDeviceName)
|
||||
} else if migCaps, err = getMigCaps(nvcapsDevice.Major); err != nil {
|
||||
logger.Warnf("Could not load MIG capability devices: %v", err)
|
||||
migCaps = nil
|
||||
}
|
||||
|
||||
discover := &nvmlDiscover{
|
||||
logger: logger,
|
||||
nvml: nvml,
|
||||
migCaps: migCaps,
|
||||
nvidiaDevices: nvidiaDevices,
|
||||
}
|
||||
|
||||
return discover, nil
|
||||
}
|
||||
|
||||
// hasMigSupport checks if MIG device discovery is supported.
|
||||
// Cases where this will be disabled include where no MIG minors file is
|
||||
// present.
|
||||
func (d nvmlDiscover) hasMigSupport() bool {
|
||||
return len(d.migCaps) > 0
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) Devices() ([]Device, error) {
|
||||
ret := d.nvml.Init()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error initalizing NVML: %v", ret.Error())
|
||||
}
|
||||
defer d.tryShutdownNVML()
|
||||
|
||||
c, ret := d.nvml.DeviceGetCount()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting device count: %v", ret.Error())
|
||||
}
|
||||
|
||||
var handles []nvml.Device
|
||||
for i := 0; i < c; i++ {
|
||||
handle, ret := d.nvml.DeviceGetHandleByIndex(i)
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting device handle for device %v: %v", i, ret.Error())
|
||||
}
|
||||
|
||||
if !d.hasMigSupport() {
|
||||
handles = append(handles, handle)
|
||||
continue
|
||||
}
|
||||
|
||||
migHandles, err := getMIGHandlesForDevice(handle)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting MIG handles for device: %v", err)
|
||||
}
|
||||
if len(migHandles) == 0 {
|
||||
handles = append(handles, handle)
|
||||
}
|
||||
handles = append(handles, migHandles...)
|
||||
}
|
||||
|
||||
return d.devicesByHandle(handles)
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) devicesByHandle(handles []nvml.Device) ([]Device, error) {
|
||||
var devices []Device
|
||||
var largestMinorNumber int
|
||||
for _, h := range handles {
|
||||
device, err := d.deviceFromNVMLHandle(h)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing device from handle %v: %v", h, err)
|
||||
}
|
||||
devices = append(devices, device)
|
||||
|
||||
if largestMinorNumber < device.DeviceNodes[0].Minor {
|
||||
largestMinorNumber = device.DeviceNodes[0].Minor
|
||||
}
|
||||
}
|
||||
|
||||
controlDevices, err := d.getControlDevices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting control devices: %v", err)
|
||||
}
|
||||
devices = append(devices, controlDevices...)
|
||||
|
||||
if d.hasMigSupport() {
|
||||
migControlDevices, err := d.getMigControlDevices(largestMinorNumber)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting MIG control devices: %v", err)
|
||||
}
|
||||
devices = append(devices, migControlDevices...)
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) deviceFromNVMLHandle(handle nvml.Device) (Device, error) {
|
||||
if d.hasMigSupport() {
|
||||
isMigDevice, ret := handle.IsMigDeviceHandle()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error checking device handle: %v", ret.Error())
|
||||
}
|
||||
|
||||
if isMigDevice {
|
||||
return d.deviceFromMIGDeviceHandle(handle)
|
||||
}
|
||||
}
|
||||
|
||||
return d.deviceFromFullDeviceHandle(handle)
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) deviceFromFullDeviceHandle(handle nvml.Device) (Device, error) {
|
||||
index, ret := handle.GetIndex()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting device index: %v", ret.Error())
|
||||
}
|
||||
|
||||
uuid, ret := handle.GetUUID()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting device UUID: %v", ret.Error())
|
||||
}
|
||||
|
||||
pciInfo, ret := handle.GetPciInfo()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting PCI info: %v", ret.Error())
|
||||
}
|
||||
busID := NewPCIBusID(pciInfo)
|
||||
|
||||
minor, ret := handle.GetMinorNumber()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting minor number: %v", ret.Error())
|
||||
}
|
||||
|
||||
nvidiaGPUDevice, exists := d.nvidiaDevices.Get(nvidiaGPUDeviceName)
|
||||
if !exists {
|
||||
return Device{}, fmt.Errorf("device for '%v' does not exist", nvidiaGPUDeviceName)
|
||||
}
|
||||
|
||||
deviceNode := DeviceNode{
|
||||
Path: DevicePath(fmt.Sprintf("/dev/nvidia%d", minor)),
|
||||
Major: nvidiaGPUDevice.Major,
|
||||
Minor: minor,
|
||||
}
|
||||
|
||||
device := Device{
|
||||
Index: fmt.Sprintf("%d", index),
|
||||
PCIBusID: busID,
|
||||
UUID: uuid,
|
||||
DeviceNodes: []DeviceNode{deviceNode},
|
||||
ProcPaths: []ProcPath{busID.GetProcPath()},
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) deviceFromMIGDeviceHandle(handle nvml.Device) (Device, error) {
|
||||
parent, ret := handle.GetDeviceHandleFromMigDeviceHandle()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting parent device handle: %v", ret.Error())
|
||||
}
|
||||
|
||||
gpu, ret := parent.GetMinorNumber()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting GPU minor number: %v", ret.Error())
|
||||
}
|
||||
|
||||
parentDevice, err := d.deviceFromFullDeviceHandle(parent)
|
||||
if err != nil {
|
||||
return Device{}, fmt.Errorf("error getting parent device: %v", err)
|
||||
}
|
||||
|
||||
index, ret := handle.GetIndex()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting device index: %v", ret.Error())
|
||||
}
|
||||
|
||||
uuid, ret := handle.GetUUID()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting device UUID: %v", ret.Error())
|
||||
}
|
||||
|
||||
capDeviceNodes := []DeviceNode{}
|
||||
procPaths, err := getProcPathsForMigDevice(gpu, handle)
|
||||
if err != nil {
|
||||
return Device{}, fmt.Errorf("error getting proc paths for MIG device: %v", err)
|
||||
}
|
||||
|
||||
for _, p := range procPaths {
|
||||
capDeviceNode, ok := d.migCaps[p]
|
||||
if !ok {
|
||||
return Device{}, fmt.Errorf("could not determine cap device path for %v", p)
|
||||
}
|
||||
capDeviceNodes = append(capDeviceNodes, capDeviceNode)
|
||||
}
|
||||
|
||||
device := Device{
|
||||
Index: fmt.Sprintf("%s:%d", parentDevice.Index, index),
|
||||
UUID: uuid,
|
||||
DeviceNodes: append(parentDevice.DeviceNodes, capDeviceNodes...),
|
||||
ProcPaths: append(parentDevice.ProcPaths, procPaths...),
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) getControlDevices() ([]Device, error) {
|
||||
devices := []struct {
|
||||
name string
|
||||
path string
|
||||
minor int
|
||||
}{
|
||||
// TODO: Where is the best place to find these device Minors programatically?
|
||||
{nvidiaGPUDeviceName, "/dev/nvidia-modeset", 254},
|
||||
{nvidiaGPUDeviceName, "/dev/nvidiactl", 255},
|
||||
{nvidiaUVMDeviceName, "/dev/nvidia-uvm", 0},
|
||||
{nvidiaUVMDeviceName, "/dev/nvidia-uvm-tools", 1},
|
||||
}
|
||||
|
||||
var controlDevices []Device
|
||||
for _, info := range devices {
|
||||
device, exists := d.nvidiaDevices.Get(info.name)
|
||||
if !exists {
|
||||
d.logger.Warnf("device name %v not defined; skipping control devices %v", info.name, info.path)
|
||||
continue
|
||||
}
|
||||
|
||||
deviceNode := DeviceNode{
|
||||
Path: DevicePath(info.path),
|
||||
Major: device.Major,
|
||||
Minor: info.minor,
|
||||
}
|
||||
|
||||
controlDevices = append(controlDevices, Device{
|
||||
UUID: ControlDeviceUUID,
|
||||
DeviceNodes: []DeviceNode{deviceNode},
|
||||
ProcPaths: []ProcPath{},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return controlDevices, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) getMigControlDevices(numGpus int) ([]Device, error) {
|
||||
targets := map[string]ProcPath{
|
||||
MIGConfigDeviceUUID: ProcPath("/proc/driver/nvidia/capabilities/mig/config"),
|
||||
MIGMonitorDeviceUUID: ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"),
|
||||
}
|
||||
|
||||
var devices []Device
|
||||
for id, procPath := range targets {
|
||||
deviceNode, exists := d.migCaps[procPath]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("device node for '%v' is undefined", procPath)
|
||||
}
|
||||
|
||||
var procPaths []ProcPath
|
||||
for gpu := 0; gpu <= numGpus; gpu++ {
|
||||
procPaths = append(procPaths, ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig", gpu)))
|
||||
}
|
||||
procPaths = append(procPaths, procPath)
|
||||
|
||||
devices = append(devices, Device{
|
||||
UUID: id,
|
||||
DeviceNodes: []DeviceNode{deviceNode},
|
||||
ProcPaths: procPaths,
|
||||
})
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func getProcPathsForMigDevice(gpu int, handle nvml.Device) ([]ProcPath, error) {
|
||||
gi, ret := handle.GetGPUInstanceId()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting GPU instance ID: %v", ret.Error())
|
||||
}
|
||||
|
||||
ci, ret := handle.GetComputeInstanceId()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting comput instance ID: %v", ret.Error())
|
||||
}
|
||||
|
||||
procPaths := []ProcPath{
|
||||
ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig/gi%d/access", gpu, gi)),
|
||||
ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig/gi%d/ci%d/access", gpu, gi, ci)),
|
||||
}
|
||||
|
||||
return procPaths, nil
|
||||
}
|
||||
|
||||
func getMIGHandlesForDevice(handle nvml.Device) ([]nvml.Device, error) {
|
||||
currentMigMode, _, ret := handle.GetMigMode()
|
||||
if ret.Value() == nvml.ERROR_NOT_SUPPORTED {
|
||||
return nil, nil
|
||||
}
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting MIG mode for device: %v", ret.Error())
|
||||
}
|
||||
if currentMigMode == nvml.DEVICE_MIG_DISABLE {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
maxMigDeviceCount, ret := handle.GetMaxMigDeviceCount()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting number of MIG devices: %v", ret.Error())
|
||||
}
|
||||
|
||||
var migHandles []nvml.Device
|
||||
for mi := 0; mi < maxMigDeviceCount; mi++ {
|
||||
migHandle, ret := handle.GetMigDeviceHandleByIndex(mi)
|
||||
if ret.Value() == nvml.ERROR_NOT_FOUND {
|
||||
continue
|
||||
}
|
||||
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting MIG device %v: %v", mi, ret.Error())
|
||||
}
|
||||
|
||||
migHandles = append(migHandles, migHandle)
|
||||
}
|
||||
|
||||
return migHandles, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) tryShutdownNVML() {
|
||||
ret := d.nvml.Shutdown()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
d.logger.Warnf("Could not shut down NVML: %v", ret.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// NewPCIBusID provides a utility function that returns the string representation
|
||||
// of the bus ID.
|
||||
func NewPCIBusID(p nvml.PciInfo) PCIBusID {
|
||||
var bytes []byte
|
||||
for _, b := range p.BusId {
|
||||
if byte(b) == '\x00' {
|
||||
break
|
||||
}
|
||||
bytes = append(bytes, byte(b))
|
||||
}
|
||||
return PCIBusID(string(bytes))
|
||||
}
|
||||
|
||||
// GetProcPath returns the path in /proc associated with the PCI bus ID
|
||||
func (p PCIBusID) GetProcPath() ProcPath {
|
||||
id := strings.ToLower(p.String())
|
||||
|
||||
if strings.HasPrefix(id, "0000") {
|
||||
id = id[4:]
|
||||
}
|
||||
return ProcPath(filepath.Join("/proc/driver/nvidia/gpus", id))
|
||||
}
|
||||
|
||||
func (p PCIBusID) String() string {
|
||||
return string(p)
|
||||
}
|
||||
44
internal/discover/discover_nvml_mig.go
Normal file
44
internal/discover/discover_nvml_mig.go
Normal file
@@ -0,0 +1,44 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
|
||||
)
|
||||
|
||||
// getMigCaps returns a mapping of MIG capability path to device nodes
|
||||
func getMigCaps(capDeviceMajor int) (map[ProcPath]DeviceNode, error) {
|
||||
migCaps, err := nvcaps.LoadMigMinors()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading MIG minors: %v", err)
|
||||
}
|
||||
return getMigCapsFromMigMinors(migCaps, capDeviceMajor), nil
|
||||
}
|
||||
|
||||
func getMigCapsFromMigMinors(migCaps map[nvcaps.MigCap]nvcaps.MigMinor, capDeviceMajor int) map[ProcPath]DeviceNode {
|
||||
capsDevicePaths := make(map[ProcPath]DeviceNode)
|
||||
for cap, minor := range migCaps {
|
||||
capsDevicePaths[ProcPath(cap.ProcPath())] = DeviceNode{
|
||||
Path: DevicePath(minor.DevicePath()),
|
||||
Major: capDeviceMajor,
|
||||
Minor: int(minor),
|
||||
}
|
||||
}
|
||||
return capsDevicePaths
|
||||
}
|
||||
41
internal/discover/discover_nvml_mig_test.go
Normal file
41
internal/discover/discover_nvml_mig_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetMigCaps(t *testing.T) {
|
||||
migMinors := map[nvcaps.MigCap]nvcaps.MigMinor{
|
||||
"config": 1,
|
||||
"monitor": 2,
|
||||
"gpu0/gi0/access": 3,
|
||||
"gpu0/gi0/ci0/access": 4,
|
||||
}
|
||||
|
||||
migCapMajor := 999
|
||||
migCaps := getMigCapsFromMigMinors(migMinors, migCapMajor)
|
||||
|
||||
require.Len(t, migCaps, len(migMinors))
|
||||
for _, c := range migCaps {
|
||||
require.Equal(t, migCapMajor, c.Major)
|
||||
}
|
||||
}
|
||||
219
internal/discover/discover_nvml_test.go
Normal file
219
internal/discover/discover_nvml_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvml"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/proc"
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
nvidiaGPUDeviceMajorDefault = 195
|
||||
nvidiaCapsDeviceMajorDefault = 235
|
||||
)
|
||||
|
||||
func newTestDiscover() nvmlDiscover {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
nvml := nvml.NewMockNVMLServer(nvml.NewMockA100Device(0))
|
||||
|
||||
return nvmlDiscover{
|
||||
logger: logger,
|
||||
nvml: nvml,
|
||||
nvidiaDevices: proc.NewMockNvidiaDevices(
|
||||
proc.Device{Name: nvidiaGPUDeviceName, Major: nvidiaGPUDeviceMajorDefault},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func newTestDiscoverWithMIG() nvmlDiscover {
|
||||
device := &nvml.MockA100Device{
|
||||
Index: 0,
|
||||
MigMode: nvml.DEVICE_MIG_ENABLE,
|
||||
GpuInstances: make(map[*nvml.MockA100GpuInstance]struct{}),
|
||||
GpuInstanceCounter: 0,
|
||||
}
|
||||
// Create a single gi and ci on the device
|
||||
gpuInstanceProfileInfo := &nvml.GpuInstanceProfileInfo{
|
||||
Id: nvml.GPU_INSTANCE_PROFILE_7_SLICE,
|
||||
}
|
||||
gi, _ := device.CreateGpuInstance(gpuInstanceProfileInfo)
|
||||
|
||||
computeInstanceProfileInfo := &nvml.ComputeInstanceProfileInfo{
|
||||
Id: nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE,
|
||||
}
|
||||
_, _ = gi.CreateComputeInstance(computeInstanceProfileInfo)
|
||||
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
return nvmlDiscover{
|
||||
logger: logger,
|
||||
nvml: nvml.NewMockNVMLServer(device),
|
||||
migCaps: map[ProcPath]DeviceNode{
|
||||
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"): {
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap3"),
|
||||
Minor: 3,
|
||||
Major: 235,
|
||||
},
|
||||
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"): {
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap4"),
|
||||
Minor: 4,
|
||||
Major: 235,
|
||||
},
|
||||
ProcPath("/proc/driver/nvidia/capabilities/mig/config"): {
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap1"),
|
||||
Minor: 1,
|
||||
Major: 235,
|
||||
},
|
||||
ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"): {
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap2"),
|
||||
Minor: 2,
|
||||
Major: 235,
|
||||
},
|
||||
},
|
||||
nvidiaDevices: proc.NewMockNvidiaDevices(
|
||||
proc.Device{Name: nvidiaGPUDeviceName, Major: nvidiaGPUDeviceMajorDefault},
|
||||
proc.Device{Name: nvidiaCapsDeviceName, Major: nvidiaCapsDeviceMajorDefault},
|
||||
),
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestDiscoverNvmlDevices(t *testing.T) {
|
||||
d := newTestDiscover()
|
||||
devices, err := d.Devices()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, devices, 3)
|
||||
|
||||
device := devices[0]
|
||||
|
||||
require.Equal(t, "0", device.Index)
|
||||
require.Equal(t, "GPU-0", device.UUID)
|
||||
require.Equal(t, "0000FFFF:FF:FF.F", device.PCIBusID.String())
|
||||
|
||||
expectedDeviceNodes := []DeviceNode{
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia0"),
|
||||
Minor: 0,
|
||||
Major: nvidiaGPUDeviceMajorDefault,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedDeviceNodes, device.DeviceNodes)
|
||||
|
||||
expectedProcPaths := []ProcPath{ProcPath("/proc/driver/nvidia/gpus/ffff:ff:ff.f")}
|
||||
require.Equal(t, expectedProcPaths, device.ProcPaths)
|
||||
}
|
||||
|
||||
func TestDiscoverNvmlMigDevices(t *testing.T) {
|
||||
d := newTestDiscoverWithMIG()
|
||||
|
||||
devices, err := d.Devices()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, devices, 5)
|
||||
|
||||
mig := devices[0]
|
||||
|
||||
require.Equal(t, "0:0", mig.Index)
|
||||
require.Equal(t, "MIG-0", mig.UUID)
|
||||
require.Empty(t, mig.PCIBusID)
|
||||
|
||||
expectedDeviceNodes := []DeviceNode{
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia0"),
|
||||
Minor: 0,
|
||||
Major: nvidiaGPUDeviceMajorDefault,
|
||||
},
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap3"),
|
||||
Minor: 3,
|
||||
Major: nvidiaCapsDeviceMajorDefault,
|
||||
},
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap4"),
|
||||
Minor: 4,
|
||||
Major: nvidiaCapsDeviceMajorDefault,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedDeviceNodes, mig.DeviceNodes)
|
||||
|
||||
expectedProcPaths := []ProcPath{
|
||||
ProcPath("/proc/driver/nvidia/gpus/ffff:ff:ff.f"),
|
||||
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"),
|
||||
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"),
|
||||
}
|
||||
require.Equal(t, expectedProcPaths, mig.ProcPaths)
|
||||
|
||||
var config *Device
|
||||
var monitor *Device
|
||||
for i, d := range devices {
|
||||
if d.UUID == "CONFIG" {
|
||||
config = &devices[i]
|
||||
}
|
||||
if d.UUID == "MONITOR" {
|
||||
monitor = &devices[i]
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, config)
|
||||
require.NotNil(t, monitor)
|
||||
|
||||
require.Equal(t, "CONFIG", config.UUID)
|
||||
|
||||
expectedDeviceNodes = []DeviceNode{
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap1"),
|
||||
Minor: 1,
|
||||
Major: nvidiaCapsDeviceMajorDefault,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedDeviceNodes, config.DeviceNodes)
|
||||
|
||||
require.Contains(t, config.ProcPaths, ProcPath("/proc/driver/nvidia/capabilities/mig/config"))
|
||||
require.Len(t, config.ProcPaths, 2)
|
||||
|
||||
require.Equal(t, "MONITOR", monitor.UUID)
|
||||
|
||||
expectedDeviceNodes = []DeviceNode{
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap2"),
|
||||
Minor: 2,
|
||||
Major: nvidiaCapsDeviceMajorDefault,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedDeviceNodes, monitor.DeviceNodes)
|
||||
|
||||
require.Contains(t, monitor.ProcPaths, ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"))
|
||||
require.Len(t, monitor.ProcPaths, 2)
|
||||
|
||||
}
|
||||
|
||||
func TestPCIBusID(t *testing.T) {
|
||||
testCases := map[string]ProcPath{
|
||||
"0000FFFF:FF:FF.F": "/proc/driver/nvidia/gpus/ffff:ff:ff.f",
|
||||
"FFFFFFFF:FF:FF.F": "/proc/driver/nvidia/gpus/ffffffff:ff:ff.f",
|
||||
}
|
||||
|
||||
for busID, procPath := range testCases {
|
||||
p := PCIBusID(busID)
|
||||
require.Equal(t, busID, p.String())
|
||||
require.Equal(t, procPath, p.GetProcPath())
|
||||
}
|
||||
}
|
||||
71
internal/discover/hooks.go
Normal file
71
internal/discover/hooks.go
Normal file
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type hooks struct {
|
||||
None
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
var _ Discover = (*hooks)(nil)
|
||||
|
||||
// NewHooks creates a discoverer for linux containers
|
||||
func NewHooks() Discover {
|
||||
return NewHooksWithLogger(log.StandardLogger())
|
||||
}
|
||||
|
||||
// NewHooksWithLogger creates a discoverer as with NewHooks with the specified logger
|
||||
func NewHooksWithLogger(logger *log.Logger) Discover {
|
||||
h := hooks{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h hooks) Hooks() ([]Hook, error) {
|
||||
var hooks []Hook
|
||||
|
||||
hooks = append(hooks, newLdconfigHook())
|
||||
|
||||
return hooks, nil
|
||||
}
|
||||
|
||||
func newLdconfigHook() Hook {
|
||||
const rootPattern = "@Root.Path@"
|
||||
|
||||
h := Hook{
|
||||
Path: "/sbin/ldconfig",
|
||||
Args: []string{
|
||||
// TODO: Testing seems to indicate that this is -v flag is required
|
||||
"-v",
|
||||
"-r", rootPattern,
|
||||
},
|
||||
// TODO: CreateContainer hooks were only added to a later OCI spec version
|
||||
// We will have to find a way to deal with OCI versions before 1.0.2
|
||||
HookName: "create-container",
|
||||
Labels: map[string]string{
|
||||
"min-oci-version": "1.0.2",
|
||||
},
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
52
internal/discover/ipc.go
Normal file
52
internal/discover/ipc.go
Normal file
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewIPCMounts creates a discoverer for IPC sockets
|
||||
func NewIPCMounts(root string) Discover {
|
||||
return NewIPCMountsWithLogger(log.StandardLogger(), root)
|
||||
}
|
||||
|
||||
// NewIPCMountsWithLogger creates a discovered as with NewIPCMounts with the
|
||||
// specified logger.
|
||||
func NewIPCMountsWithLogger(logger *log.Logger, root string) Discover {
|
||||
d := mounts{
|
||||
logger: logger,
|
||||
lookup: lookup.NewFileLocatorWithLogger(logger, root),
|
||||
required: requiredIPCs,
|
||||
}
|
||||
|
||||
return &d
|
||||
}
|
||||
|
||||
var requiredIPCs = map[string][]string{
|
||||
"nvidia-persistenced": {
|
||||
"/var/run/nvidia-persistenced/socket",
|
||||
},
|
||||
"nvidia-fabricmanager": {
|
||||
"/var/run/nvidia-fabricmanager/socket",
|
||||
},
|
||||
// TODO: This can be controlled by the NV_MPS_PIPE_DIR envvar
|
||||
"nvidia-mps": {
|
||||
"/tmp/nvidia-mps",
|
||||
},
|
||||
}
|
||||
56
internal/discover/ipc_test.go
Normal file
56
internal/discover/ipc_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIPCDiscover(t *testing.T) {
|
||||
|
||||
ipcLookup := map[string]string{
|
||||
"/var/run/nvidia-persistenced/socket": "/var/run/nvidia-persistenced/socket",
|
||||
"/var/run/nvidia-fabricmanager/socket": "fm-socket",
|
||||
}
|
||||
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
d := NewIPCMountsWithLogger(logger, "")
|
||||
|
||||
// Override lookup for testing
|
||||
mockLocator := NewLocatorMockFromMap(ipcLookup)
|
||||
d.(*mounts).lookup = mockLocator
|
||||
|
||||
mounts, err := d.Mounts()
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, []Mount{
|
||||
{Path: "/var/run/nvidia-persistenced/socket", Labels: map[string]string{}},
|
||||
{Path: "fm-socket", Labels: map[string]string{}}}, mounts)
|
||||
|
||||
require.Len(t, mockLocator.LocateCalls(), 3)
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
100
internal/discover/libraries.go
Normal file
100
internal/discover/libraries.go
Normal file
@@ -0,0 +1,100 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewLibraries constructs discoverer for libraries
|
||||
func NewLibraries(root string) (Discover, error) {
|
||||
return NewLibrariesWithLogger(log.StandardLogger(), root)
|
||||
}
|
||||
|
||||
// NewLibrariesWithLogger constructs discoverer for libraries with the specified logger
|
||||
func NewLibrariesWithLogger(logger *log.Logger, root string) (Discover, error) {
|
||||
lookup, err := lookup.NewLibraryLocatorWithLogger(logger, root)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing locator: %v", err)
|
||||
}
|
||||
|
||||
d := mounts{
|
||||
logger: logger,
|
||||
lookup: lookup,
|
||||
required: requiredLibraries,
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
// requiredLibraries defines a set of libraries and their labels
|
||||
var requiredLibraries = map[string][]string{
|
||||
"utility": {
|
||||
"libnvidia-ml.so", /* Management library */
|
||||
"libnvidia-cfg.so", /* GPU configuration */
|
||||
},
|
||||
"compute": {
|
||||
"libcuda.so", /* CUDA driver library */
|
||||
"libnvidia-opencl.so", /* NVIDIA OpenCL ICD */
|
||||
"libnvidia-ptxjitcompiler.so", /* PTX-SASS JIT compiler (used by libcuda) */
|
||||
"libnvidia-fatbinaryloader.so", /* fatbin loader (used by libcuda) */
|
||||
"libnvidia-allocator.so", /* NVIDIA allocator runtime library */
|
||||
"libnvidia-compiler.so", /* NVVM-PTX compiler for OpenCL (used by libnvidia-opencl) */
|
||||
},
|
||||
"video": {
|
||||
"libvdpau_nvidia.so", /* NVIDIA VDPAU ICD */
|
||||
"libnvidia-encode.so", /* Video encoder */
|
||||
"libnvidia-opticalflow.so", /* NVIDIA Opticalflow library */
|
||||
"libnvcuvid.so", /* Video decoder */
|
||||
},
|
||||
"graphics": {
|
||||
//"libnvidia-egl-wayland.so", /* EGL wayland platform extension (used by libEGL_nvidia) */
|
||||
"libnvidia-eglcore.so", /* EGL core (used by libGLES*[_nvidia] and libEGL_nvidia) */
|
||||
"libnvidia-glcore.so", /* OpenGL core (used by libGL or libGLX_nvidia) */
|
||||
"libnvidia-tls.so", /* Thread local storage (used by libGL or libGLX_nvidia) */
|
||||
"libnvidia-glsi.so", /* OpenGL system interaction (used by libEGL_nvidia) */
|
||||
"libnvidia-fbc.so", /* Framebuffer capture */
|
||||
"libnvidia-ifr.so", /* OpenGL framebuffer capture */
|
||||
"libnvidia-rtcore.so", /* Optix */
|
||||
"libnvoptix.so", /* Optix */
|
||||
},
|
||||
"glvnd": {
|
||||
//"libGLX.so", /* GLX ICD loader */
|
||||
//"libOpenGL.so", /* OpenGL ICD loader */
|
||||
//"libGLdispatch.so", /* OpenGL dispatch (used by libOpenGL, libEGL and libGLES*) */
|
||||
"libGLX_nvidia.so", /* OpenGL/GLX ICD */
|
||||
"libEGL_nvidia.so", /* EGL ICD */
|
||||
"libGLESv2_nvidia.so", /* OpenGL ES v2 ICD */
|
||||
"libGLESv1_CM_nvidia.so", /* OpenGL ES v1 common profile ICD */
|
||||
"libnvidia-glvkspirv.so", /* SPIR-V Lib for Vulkan */
|
||||
"libnvidia-cbl.so", /* VK_NV_ray_tracing */
|
||||
},
|
||||
"compat32": {
|
||||
"libGL.so", /* OpenGL/GLX legacy _or_ compatibility wrapper (GLVND) */
|
||||
"libEGL.so", /* EGL legacy _or_ ICD loader (GLVND) */
|
||||
"libGLESv1_CM.so", /* OpenGL ES v1 common profile legacy _or_ ICD loader (GLVND) */
|
||||
"libGLESv2.so", /* OpenGL ES v2 legacy _or_ ICD loader (GLVND) */
|
||||
},
|
||||
"ngx": {
|
||||
"libnvidia-ngx.so", /* NGX library */
|
||||
},
|
||||
"dxcore": {
|
||||
"libdxcore.so", /* Core library for dxcore support */
|
||||
},
|
||||
}
|
||||
56
internal/discover/libraries_test.go
Normal file
56
internal/discover/libraries_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLibraries(t *testing.T) {
|
||||
|
||||
libraryLookup := map[string]string{
|
||||
"libcuda.so": "/lib/libcuda.so.999.99",
|
||||
"libversion.so": "/lib/libversion.so.111.11",
|
||||
"libother.so": "/lib/libother.so.999.99",
|
||||
}
|
||||
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
d, err := NewLibrariesWithLogger(logger, "")
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Override lookup for testing
|
||||
mockLocator := NewLocatorMockFromMap(libraryLookup)
|
||||
d.(*mounts).lookup = mockLocator
|
||||
|
||||
mounts, err := d.Mounts()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.ElementsMatch(t, []Mount{{Path: "/lib/libcuda.so.999.99", Labels: map[string]string{}}}, mounts)
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
125
internal/discover/mounts.go
Normal file
125
internal/discover/mounts.go
Normal file
@@ -0,0 +1,125 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
capabilityLabel = "capability"
|
||||
versionLabel = "version"
|
||||
)
|
||||
|
||||
// mounts is a generic discoverer for Mounts. It is customized by specifying the
|
||||
// required entities as a key-value pair as well as a Locator that is used to
|
||||
// identify the mounts that are to be included.
|
||||
type mounts struct {
|
||||
None
|
||||
logger *log.Logger
|
||||
lookup lookup.Locator
|
||||
required map[string][]string
|
||||
}
|
||||
|
||||
var _ Discover = (*mounts)(nil)
|
||||
|
||||
func (d mounts) Mounts() ([]Mount, error) {
|
||||
mounts, err := d.uniqueMounts()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering mounts: %v", err)
|
||||
}
|
||||
|
||||
return mounts.Slice(), nil
|
||||
}
|
||||
|
||||
func (d mounts) uniqueMounts() (mountsByPath, error) {
|
||||
if d.lookup == nil {
|
||||
return nil, fmt.Errorf("no lookup defined")
|
||||
}
|
||||
|
||||
mounts := make(mountsByPath)
|
||||
|
||||
for id, keys := range d.required {
|
||||
for _, key := range keys {
|
||||
d.logger.Debugf("Locating %v [%v]", key, id)
|
||||
located, err := d.lookup.Locate(key)
|
||||
if err != nil {
|
||||
d.logger.Warnf("Could not locate %v [%v]: %v", key, id, err)
|
||||
continue
|
||||
}
|
||||
d.logger.Infof("Located %v [%v]: %v", key, id, located)
|
||||
for _, p := range located {
|
||||
// TODO: We need to add labels
|
||||
mount := newMount(p)
|
||||
mounts.Put(mount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mounts, nil
|
||||
}
|
||||
|
||||
type mountsByPath map[string]Mount
|
||||
|
||||
func (m mountsByPath) Slice() []Mount {
|
||||
var mounts []Mount
|
||||
for _, mount := range m {
|
||||
mounts = append(mounts, mount)
|
||||
}
|
||||
|
||||
return mounts
|
||||
}
|
||||
|
||||
func (m *mountsByPath) Put(value Mount) {
|
||||
key := value.Path
|
||||
mount, exists := (*m)[key]
|
||||
if !exists {
|
||||
(*m)[key] = value
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range value.Labels {
|
||||
mount.Labels[k] = v
|
||||
}
|
||||
(*m)[key] = mount
|
||||
}
|
||||
|
||||
// NewMountForCapability creates a mount with the specified capability label
|
||||
func NewMountForCapability(path string, capability string) Mount {
|
||||
return newMount(path, capabilityLabel, capability)
|
||||
}
|
||||
|
||||
// NewMountForVersion creates a mount with the specified version label
|
||||
func NewMountForVersion(path string, version string) Mount {
|
||||
return newMount(path, versionLabel, version)
|
||||
}
|
||||
|
||||
func newMount(path string, labels ...string) Mount {
|
||||
l := make(map[string]string)
|
||||
|
||||
for i := 0; i < len(labels)-1; i += 2 {
|
||||
l[labels[i]] = labels[i+1]
|
||||
}
|
||||
|
||||
return Mount{
|
||||
Path: path,
|
||||
Labels: l,
|
||||
}
|
||||
}
|
||||
53
internal/discover/mounts_test.go
Normal file
53
internal/discover/mounts_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMountsReturnsErrorForNoLookup(t *testing.T) {
|
||||
d := mounts{}
|
||||
mounts, err := d.Mounts()
|
||||
|
||||
require.Error(t, err)
|
||||
require.Len(t, mounts, 0)
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
|
||||
func NewLocatorMockFromMap(lookupMap map[string]string) *lookup.LocatorMock {
|
||||
return &lookup.LocatorMock{
|
||||
LocateFunc: func(key string) ([]string, error) {
|
||||
value, exists := lookupMap[key]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("key %v not found", key)
|
||||
}
|
||||
return []string{value}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
38
internal/discover/none.go
Normal file
38
internal/discover/none.go
Normal file
@@ -0,0 +1,38 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
// None is a null discoverer that returns an empty list of devices and
|
||||
// mounts.
|
||||
type None struct{}
|
||||
|
||||
var _ Discover = (*None)(nil)
|
||||
|
||||
// Devices returns an empty list of devices
|
||||
func (e None) Devices() ([]Device, error) {
|
||||
return []Device{}, nil
|
||||
}
|
||||
|
||||
// Mounts returns an empty list of mounts
|
||||
func (e None) Mounts() ([]Mount, error) {
|
||||
return []Mount{}, nil
|
||||
}
|
||||
|
||||
// Hooks returns an empty list of hooks
|
||||
func (e None) Hooks() ([]Hook, error) {
|
||||
return []Hook{}, nil
|
||||
}
|
||||
39
internal/discover/none_test.go
Normal file
39
internal/discover/none_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNone(t *testing.T) {
|
||||
d := None{}
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
mounts, err := d.Mounts()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, mounts)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
71
internal/discover/nvml_server.go
Normal file
71
internal/discover/nvml_server.go
Normal file
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvml"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type nvmlServer struct {
|
||||
logger *log.Logger
|
||||
composite
|
||||
}
|
||||
|
||||
var _ Discover = (*nvmlServer)(nil)
|
||||
|
||||
// NewNVMLServer constructs a discoverer for server systems using NVML to discover devices
|
||||
func NewNVMLServer(root string) (Discover, error) {
|
||||
return NewNVMLServerWithLogger(log.StandardLogger(), root)
|
||||
}
|
||||
|
||||
// NewNVMLServerWithLogger constructs a discoverer for server systems using NVML to discover devices with
|
||||
// the specified logger
|
||||
func NewNVMLServerWithLogger(logger *log.Logger, root string) (Discover, error) {
|
||||
return createNVMLServer(logger, nvml.New(), root)
|
||||
}
|
||||
|
||||
func createNVMLServer(logger *log.Logger, nvml nvml.Interface, root string) (Discover, error) {
|
||||
d := nvmlServer{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
devices, err := NewNVMLDiscoverWithLogger(logger, nvml)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing NVML device discoverer: %v", err)
|
||||
}
|
||||
|
||||
libraries, err := NewLibrariesWithLogger(logger, root)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing library discoverer: %v", err)
|
||||
}
|
||||
|
||||
d.add(
|
||||
// Device discovery
|
||||
devices,
|
||||
// Mounts discovery
|
||||
libraries,
|
||||
NewBinaryMountsWithLogger(logger, root),
|
||||
NewIPCMountsWithLogger(logger, root),
|
||||
// Hook discovery
|
||||
NewHooksWithLogger(logger),
|
||||
)
|
||||
|
||||
return &d, nil
|
||||
}
|
||||
66
internal/discover/nvml_server_test.go
Normal file
66
internal/discover/nvml_server_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvml"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/proc"
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testMajor = 999
|
||||
)
|
||||
|
||||
func TestNVMLServerConstructor(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
nvml := nvml.NewMockNVMLOnLunaServer()
|
||||
|
||||
d, err := createNVMLServer(logger, nvml, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
instance := d.(*nvmlServer)
|
||||
require.Len(t, instance.discoverers, 1+3+1)
|
||||
|
||||
// We need to override the nvidiaDevices member of the nvmlDiscovery
|
||||
// TODO: Use a mock instead, or allow for injection into a constructor
|
||||
instance.discoverers[0].(*nvmlDiscover).nvidiaDevices = &mockNvidiaDevices{}
|
||||
|
||||
devices, err := d.Devices()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, devices)
|
||||
|
||||
_, err = d.Mounts()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
type mockNvidiaDevices struct{}
|
||||
|
||||
var _ proc.NvidiaDevices = (*mockNvidiaDevices)(nil)
|
||||
|
||||
func (d mockNvidiaDevices) Get(name string) (proc.Device, bool) {
|
||||
return proc.Device{Name: name, Major: testMajor}, true
|
||||
}
|
||||
|
||||
func (d mockNvidiaDevices) Exists(string) bool {
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user