mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-06-26 18:18:24 +00:00
Copy files from nvidia-container-toolkit
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
51
pkg/discover/binaries.go
Normal file
51
pkg/discover/binaries.go
Normal file
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
|
||||
)
|
||||
|
||||
// NewBinaryMounts creates a discoverer for binaries using the specified root
|
||||
func NewBinaryMounts(root string) Discover {
|
||||
return NewBinaryMountsWithLogger(log.StandardLogger(), root)
|
||||
}
|
||||
|
||||
// NewBinaryMountsWithLogger creates a Mounts discoverer as with NewBinaryMounts
|
||||
// with the specified logger
|
||||
func NewBinaryMountsWithLogger(logger *log.Logger, root string) Discover {
|
||||
d := mounts{
|
||||
logger: logger,
|
||||
lookup: lookup.NewPathLocatorWithLogger(logger, root),
|
||||
required: requiredBinaries,
|
||||
}
|
||||
return &d
|
||||
}
|
||||
|
||||
// requiredBinaries defines a set of binaries and their labels
|
||||
var requiredBinaries = map[string][]string{
|
||||
"utility": {
|
||||
"nvidia-smi", /* System management interface */
|
||||
"nvidia-debugdump", /* GPU coredump utility */
|
||||
"nvidia-persistenced", /* Persistence mode utility */
|
||||
},
|
||||
"compute": {
|
||||
"nvidia-cuda-mps-control", /* Multi process service CLI */
|
||||
"nvidia-cuda-mps-server", /* Multi process service server */
|
||||
},
|
||||
}
|
||||
73
pkg/discover/binaries_test.go
Normal file
73
pkg/discover/binaries_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBinaries(t *testing.T) {
|
||||
binaryLookup := map[string]string{
|
||||
"nvidia-smi": "/usr/bin/nvidia-smi",
|
||||
"nvidia-persistenced": "/usr/bin/nvidia-persistenced",
|
||||
"nvidia-debugdump": "test-duplicates",
|
||||
"nvidia-cuda-mps-control": "test-duplicates",
|
||||
}
|
||||
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
d := NewBinaryMountsWithLogger(logger, "")
|
||||
|
||||
// Override lookup for testing
|
||||
mockLocator := NewLocatorMockFromMap(binaryLookup)
|
||||
d.(*mounts).lookup = mockLocator
|
||||
|
||||
mounts, err := d.Mounts()
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := []Mount{
|
||||
{
|
||||
Path: "/usr/bin/nvidia-smi",
|
||||
},
|
||||
{
|
||||
Path: "/usr/bin/nvidia-persistenced",
|
||||
},
|
||||
{
|
||||
Path: "test-duplicates",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, len(expected), len(mounts))
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
|
||||
func TestNewBinariesConstructor(t *testing.T) {
|
||||
b := NewBinaryMounts("").(*mounts)
|
||||
|
||||
require.NotNil(t, b.logger)
|
||||
require.NotNil(t, b.lookup)
|
||||
}
|
||||
74
pkg/discover/composite.go
Normal file
74
pkg/discover/composite.go
Normal file
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import "fmt"
|
||||
|
||||
// composite is a discoverer that contains a list of Discoverers. The output of the
|
||||
// Devices, Mounts, and Hooks functions is the concatenation of the output for each of the
|
||||
// elements in the list.
|
||||
type composite struct {
|
||||
discoverers []Discover
|
||||
}
|
||||
|
||||
var _ Discover = (*composite)(nil)
|
||||
|
||||
func (d composite) Devices() ([]Device, error) {
|
||||
var allDevices []Device
|
||||
|
||||
for i, di := range d.discoverers {
|
||||
devices, err := di.Devices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering devices for discoverer %v: %v", i, err)
|
||||
}
|
||||
allDevices = append(allDevices, devices...)
|
||||
}
|
||||
|
||||
return allDevices, nil
|
||||
}
|
||||
|
||||
func (d composite) Mounts() ([]Mount, error) {
|
||||
var allMounts []Mount
|
||||
|
||||
for i, di := range d.discoverers {
|
||||
mounts, err := di.Mounts()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering mounts for discoverer %v: %v", i, err)
|
||||
}
|
||||
allMounts = append(allMounts, mounts...)
|
||||
}
|
||||
|
||||
return allMounts, nil
|
||||
}
|
||||
|
||||
func (d composite) Hooks() ([]Hook, error) {
|
||||
var allHooks []Hook
|
||||
|
||||
for i, di := range d.discoverers {
|
||||
hooks, err := di.Hooks()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering hooks for discoverer %v: %v", i, err)
|
||||
}
|
||||
allHooks = append(allHooks, hooks...)
|
||||
}
|
||||
|
||||
return allHooks, nil
|
||||
}
|
||||
|
||||
func (d *composite) add(di ...Discover) {
|
||||
d.discoverers = append(d.discoverers, di...)
|
||||
}
|
||||
65
pkg/discover/discover.go
Normal file
65
pkg/discover/discover.go
Normal file
@@ -0,0 +1,65 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
// DevicePath is a path in /dev associated with a device
|
||||
type DevicePath string
|
||||
|
||||
// ProcPath is a path in /proc associated with a devices
|
||||
type ProcPath string
|
||||
|
||||
// PCIBusID is the ID on the PCI bus of a device
|
||||
type PCIBusID string
|
||||
|
||||
// DeviceNode represents a device on the file system
|
||||
type DeviceNode struct {
|
||||
Path DevicePath
|
||||
Major int
|
||||
Minor int
|
||||
}
|
||||
|
||||
// Device represents a discovered device including identifiers (Index, UUID, PCI bus ID)
|
||||
// for selection and paths in /dev and /proc associated with the device
|
||||
type Device struct {
|
||||
Index string
|
||||
UUID string
|
||||
PCIBusID PCIBusID
|
||||
DeviceNodes []DeviceNode
|
||||
ProcPaths []ProcPath
|
||||
}
|
||||
|
||||
// Mount represents a discovered mount. This includes a set of labels
|
||||
// for selection and the mount path
|
||||
type Mount struct {
|
||||
Path string
|
||||
Labels map[string]string
|
||||
}
|
||||
|
||||
// Hook represents a discovered hook
|
||||
type Hook struct {
|
||||
Path string
|
||||
Args []string
|
||||
HookName string
|
||||
Labels map[string]string
|
||||
}
|
||||
|
||||
// Discover defines an interface for discovering the devices and mounts available on a system
|
||||
type Discover interface {
|
||||
Devices() ([]Device, error)
|
||||
Mounts() ([]Mount, error)
|
||||
Hooks() ([]Hook, error)
|
||||
}
|
||||
424
pkg/discover/discover_nvml.go
Normal file
424
pkg/discover/discover_nvml.go
Normal file
@@ -0,0 +1,424 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvml"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/proc"
|
||||
)
|
||||
|
||||
const (
|
||||
// ControlDeviceUUID is used as the UUID for control devices such as nvidiactl or nvidia-modeset
|
||||
ControlDeviceUUID = "CONTROL"
|
||||
|
||||
// MIGConfigDeviceUUID is used to indicate the MIG config control device
|
||||
MIGConfigDeviceUUID = "CONFIG"
|
||||
|
||||
// MIGMonitorDeviceUUID is used to indicate the MIG monitor control device
|
||||
MIGMonitorDeviceUUID = "MONITOR"
|
||||
|
||||
nvidiaGPUDeviceName = "nvidia-frontend"
|
||||
nvidiaCapsDeviceName = "nvidia-caps"
|
||||
nvidiaUVMDeviceName = "nvidia-uvm"
|
||||
)
|
||||
|
||||
type nvmlDiscover struct {
|
||||
None
|
||||
logger *log.Logger
|
||||
nvml nvml.Interface
|
||||
migCaps map[ProcPath]DeviceNode
|
||||
nvidiaDevices proc.NvidiaDevices
|
||||
}
|
||||
|
||||
var _ Discover = (*nvmlDiscover)(nil)
|
||||
|
||||
// NewNVMLDiscover constructs a discoverer that uses NVML to find the devices
|
||||
// available on a system.
|
||||
func NewNVMLDiscover(nvml nvml.Interface) (Discover, error) {
|
||||
return NewNVMLDiscoverWithLogger(log.StandardLogger(), nvml)
|
||||
}
|
||||
|
||||
// NewNVMLDiscoverWithLogger constructs a discovered as with NewNVMLDiscover with the specified
|
||||
// logger
|
||||
func NewNVMLDiscoverWithLogger(logger *log.Logger, nvml nvml.Interface) (Discover, error) {
|
||||
nvidiaDevices, err := proc.GetNvidiaDevices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading NVIDIA devices: %v", err)
|
||||
}
|
||||
|
||||
var migCaps map[ProcPath]DeviceNode
|
||||
nvcapsDevice, exists := nvidiaDevices.Get(nvidiaCapsDeviceName)
|
||||
if !exists {
|
||||
logger.Warnf("%v nvcaps device could not be found", nvidiaCapsDeviceName)
|
||||
} else if migCaps, err = getMigCaps(nvcapsDevice.Major); err != nil {
|
||||
logger.Warnf("Could not load MIG capability devices: %v", err)
|
||||
migCaps = nil
|
||||
}
|
||||
|
||||
discover := &nvmlDiscover{
|
||||
logger: logger,
|
||||
nvml: nvml,
|
||||
migCaps: migCaps,
|
||||
nvidiaDevices: nvidiaDevices,
|
||||
}
|
||||
|
||||
return discover, nil
|
||||
}
|
||||
|
||||
// hasMigSupport checks if MIG device discovery is supported.
|
||||
// Cases where this will be disabled include where no MIG minors file is
|
||||
// present.
|
||||
func (d nvmlDiscover) hasMigSupport() bool {
|
||||
return len(d.migCaps) > 0
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) Devices() ([]Device, error) {
|
||||
ret := d.nvml.Init()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error initalizing NVML: %v", ret.Error())
|
||||
}
|
||||
defer d.tryShutdownNVML()
|
||||
|
||||
c, ret := d.nvml.DeviceGetCount()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting device count: %v", ret.Error())
|
||||
}
|
||||
|
||||
var handles []nvml.Device
|
||||
for i := 0; i < c; i++ {
|
||||
handle, ret := d.nvml.DeviceGetHandleByIndex(i)
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting device handle for device %v: %v", i, ret.Error())
|
||||
}
|
||||
|
||||
if !d.hasMigSupport() {
|
||||
handles = append(handles, handle)
|
||||
continue
|
||||
}
|
||||
|
||||
migHandles, err := getMIGHandlesForDevice(handle)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting MIG handles for device: %v", err)
|
||||
}
|
||||
if len(migHandles) == 0 {
|
||||
handles = append(handles, handle)
|
||||
}
|
||||
handles = append(handles, migHandles...)
|
||||
}
|
||||
|
||||
return d.devicesByHandle(handles)
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) devicesByHandle(handles []nvml.Device) ([]Device, error) {
|
||||
var devices []Device
|
||||
var largestMinorNumber int
|
||||
for _, h := range handles {
|
||||
device, err := d.deviceFromNVMLHandle(h)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing device from handle %v: %v", h, err)
|
||||
}
|
||||
devices = append(devices, device)
|
||||
|
||||
if largestMinorNumber < device.DeviceNodes[0].Minor {
|
||||
largestMinorNumber = device.DeviceNodes[0].Minor
|
||||
}
|
||||
}
|
||||
|
||||
controlDevices, err := d.getControlDevices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting control devices: %v", err)
|
||||
}
|
||||
devices = append(devices, controlDevices...)
|
||||
|
||||
if d.hasMigSupport() {
|
||||
migControlDevices, err := d.getMigControlDevices(largestMinorNumber)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting MIG control devices: %v", err)
|
||||
}
|
||||
devices = append(devices, migControlDevices...)
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) deviceFromNVMLHandle(handle nvml.Device) (Device, error) {
|
||||
if d.hasMigSupport() {
|
||||
isMigDevice, ret := handle.IsMigDeviceHandle()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error checking device handle: %v", ret.Error())
|
||||
}
|
||||
|
||||
if isMigDevice {
|
||||
return d.deviceFromMIGDeviceHandle(handle)
|
||||
}
|
||||
}
|
||||
|
||||
return d.deviceFromFullDeviceHandle(handle)
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) deviceFromFullDeviceHandle(handle nvml.Device) (Device, error) {
|
||||
index, ret := handle.GetIndex()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting device index: %v", ret.Error())
|
||||
}
|
||||
|
||||
uuid, ret := handle.GetUUID()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting device UUID: %v", ret.Error())
|
||||
}
|
||||
|
||||
pciInfo, ret := handle.GetPciInfo()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting PCI info: %v", ret.Error())
|
||||
}
|
||||
busID := NewPCIBusID(pciInfo)
|
||||
|
||||
minor, ret := handle.GetMinorNumber()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting minor number: %v", ret.Error())
|
||||
}
|
||||
|
||||
nvidiaGPUDevice, exists := d.nvidiaDevices.Get(nvidiaGPUDeviceName)
|
||||
if !exists {
|
||||
return Device{}, fmt.Errorf("device for '%v' does not exist", nvidiaGPUDeviceName)
|
||||
}
|
||||
|
||||
deviceNode := DeviceNode{
|
||||
Path: DevicePath(fmt.Sprintf("/dev/nvidia%d", minor)),
|
||||
Major: nvidiaGPUDevice.Major,
|
||||
Minor: minor,
|
||||
}
|
||||
|
||||
device := Device{
|
||||
Index: fmt.Sprintf("%d", index),
|
||||
PCIBusID: busID,
|
||||
UUID: uuid,
|
||||
DeviceNodes: []DeviceNode{deviceNode},
|
||||
ProcPaths: []ProcPath{busID.GetProcPath()},
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) deviceFromMIGDeviceHandle(handle nvml.Device) (Device, error) {
|
||||
parent, ret := handle.GetDeviceHandleFromMigDeviceHandle()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting parent device handle: %v", ret.Error())
|
||||
}
|
||||
|
||||
gpu, ret := parent.GetMinorNumber()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting GPU minor number: %v", ret.Error())
|
||||
}
|
||||
|
||||
parentDevice, err := d.deviceFromFullDeviceHandle(parent)
|
||||
if err != nil {
|
||||
return Device{}, fmt.Errorf("error getting parent device: %v", err)
|
||||
}
|
||||
|
||||
index, ret := handle.GetIndex()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting device index: %v", ret.Error())
|
||||
}
|
||||
|
||||
uuid, ret := handle.GetUUID()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return Device{}, fmt.Errorf("error getting device UUID: %v", ret.Error())
|
||||
}
|
||||
|
||||
capDeviceNodes := []DeviceNode{}
|
||||
procPaths, err := getProcPathsForMigDevice(gpu, handle)
|
||||
if err != nil {
|
||||
return Device{}, fmt.Errorf("error getting proc paths for MIG device: %v", err)
|
||||
}
|
||||
|
||||
for _, p := range procPaths {
|
||||
capDeviceNode, ok := d.migCaps[p]
|
||||
if !ok {
|
||||
return Device{}, fmt.Errorf("could not determine cap device path for %v", p)
|
||||
}
|
||||
capDeviceNodes = append(capDeviceNodes, capDeviceNode)
|
||||
}
|
||||
|
||||
device := Device{
|
||||
Index: fmt.Sprintf("%s:%d", parentDevice.Index, index),
|
||||
UUID: uuid,
|
||||
DeviceNodes: append(parentDevice.DeviceNodes, capDeviceNodes...),
|
||||
ProcPaths: append(parentDevice.ProcPaths, procPaths...),
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) getControlDevices() ([]Device, error) {
|
||||
devices := []struct {
|
||||
name string
|
||||
path string
|
||||
minor int
|
||||
}{
|
||||
// TODO: Where is the best place to find these device Minors programatically?
|
||||
{nvidiaGPUDeviceName, "/dev/nvidia-modeset", 254},
|
||||
{nvidiaGPUDeviceName, "/dev/nvidiactl", 255},
|
||||
{nvidiaUVMDeviceName, "/dev/nvidia-uvm", 0},
|
||||
{nvidiaUVMDeviceName, "/dev/nvidia-uvm-tools", 1},
|
||||
}
|
||||
|
||||
var controlDevices []Device
|
||||
for _, info := range devices {
|
||||
device, exists := d.nvidiaDevices.Get(info.name)
|
||||
if !exists {
|
||||
d.logger.Warnf("device name %v not defined; skipping control devices %v", info.name, info.path)
|
||||
continue
|
||||
}
|
||||
|
||||
deviceNode := DeviceNode{
|
||||
Path: DevicePath(info.path),
|
||||
Major: device.Major,
|
||||
Minor: info.minor,
|
||||
}
|
||||
|
||||
controlDevices = append(controlDevices, Device{
|
||||
UUID: ControlDeviceUUID,
|
||||
DeviceNodes: []DeviceNode{deviceNode},
|
||||
ProcPaths: []ProcPath{},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return controlDevices, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) getMigControlDevices(numGpus int) ([]Device, error) {
|
||||
targets := map[string]ProcPath{
|
||||
MIGConfigDeviceUUID: ProcPath("/proc/driver/nvidia/capabilities/mig/config"),
|
||||
MIGMonitorDeviceUUID: ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"),
|
||||
}
|
||||
|
||||
var devices []Device
|
||||
for id, procPath := range targets {
|
||||
deviceNode, exists := d.migCaps[procPath]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("device node for '%v' is undefined", procPath)
|
||||
}
|
||||
|
||||
var procPaths []ProcPath
|
||||
for gpu := 0; gpu <= numGpus; gpu++ {
|
||||
procPaths = append(procPaths, ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig", gpu)))
|
||||
}
|
||||
procPaths = append(procPaths, procPath)
|
||||
|
||||
devices = append(devices, Device{
|
||||
UUID: id,
|
||||
DeviceNodes: []DeviceNode{deviceNode},
|
||||
ProcPaths: procPaths,
|
||||
})
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func getProcPathsForMigDevice(gpu int, handle nvml.Device) ([]ProcPath, error) {
|
||||
gi, ret := handle.GetGPUInstanceId()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting GPU instance ID: %v", ret.Error())
|
||||
}
|
||||
|
||||
ci, ret := handle.GetComputeInstanceId()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting comput instance ID: %v", ret.Error())
|
||||
}
|
||||
|
||||
procPaths := []ProcPath{
|
||||
ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig/gi%d/access", gpu, gi)),
|
||||
ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig/gi%d/ci%d/access", gpu, gi, ci)),
|
||||
}
|
||||
|
||||
return procPaths, nil
|
||||
}
|
||||
|
||||
func getMIGHandlesForDevice(handle nvml.Device) ([]nvml.Device, error) {
|
||||
currentMigMode, _, ret := handle.GetMigMode()
|
||||
if ret.Value() == nvml.ERROR_NOT_SUPPORTED {
|
||||
return nil, nil
|
||||
}
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting MIG mode for device: %v", ret.Error())
|
||||
}
|
||||
if currentMigMode == nvml.DEVICE_MIG_DISABLE {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
maxMigDeviceCount, ret := handle.GetMaxMigDeviceCount()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting number of MIG devices: %v", ret.Error())
|
||||
}
|
||||
|
||||
var migHandles []nvml.Device
|
||||
for mi := 0; mi < maxMigDeviceCount; mi++ {
|
||||
migHandle, ret := handle.GetMigDeviceHandleByIndex(mi)
|
||||
if ret.Value() == nvml.ERROR_NOT_FOUND {
|
||||
continue
|
||||
}
|
||||
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting MIG device %v: %v", mi, ret.Error())
|
||||
}
|
||||
|
||||
migHandles = append(migHandles, migHandle)
|
||||
}
|
||||
|
||||
return migHandles, nil
|
||||
}
|
||||
|
||||
func (d *nvmlDiscover) tryShutdownNVML() {
|
||||
ret := d.nvml.Shutdown()
|
||||
if ret.Value() != nvml.SUCCESS {
|
||||
d.logger.Warnf("Could not shut down NVML: %v", ret.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// NewPCIBusID provides a utility function that returns the string representation
|
||||
// of the bus ID.
|
||||
func NewPCIBusID(p nvml.PciInfo) PCIBusID {
|
||||
var bytes []byte
|
||||
for _, b := range p.BusId {
|
||||
if byte(b) == '\x00' {
|
||||
break
|
||||
}
|
||||
bytes = append(bytes, byte(b))
|
||||
}
|
||||
return PCIBusID(string(bytes))
|
||||
}
|
||||
|
||||
// GetProcPath returns the path in /proc associated with the PCI bus ID
|
||||
func (p PCIBusID) GetProcPath() ProcPath {
|
||||
id := strings.ToLower(p.String())
|
||||
|
||||
if strings.HasPrefix(id, "0000") {
|
||||
id = id[4:]
|
||||
}
|
||||
return ProcPath(filepath.Join("/proc/driver/nvidia/gpus", id))
|
||||
}
|
||||
|
||||
func (p PCIBusID) String() string {
|
||||
return string(p)
|
||||
}
|
||||
44
pkg/discover/discover_nvml_mig.go
Normal file
44
pkg/discover/discover_nvml_mig.go
Normal file
@@ -0,0 +1,44 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvcaps"
|
||||
)
|
||||
|
||||
// getMigCaps returns a mapping of MIG capability path to device nodes
|
||||
func getMigCaps(capDeviceMajor int) (map[ProcPath]DeviceNode, error) {
|
||||
migCaps, err := nvcaps.LoadMigMinors()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading MIG minors: %v", err)
|
||||
}
|
||||
return getMigCapsFromMigMinors(migCaps, capDeviceMajor), nil
|
||||
}
|
||||
|
||||
func getMigCapsFromMigMinors(migCaps map[nvcaps.MigCap]nvcaps.MigMinor, capDeviceMajor int) map[ProcPath]DeviceNode {
|
||||
capsDevicePaths := make(map[ProcPath]DeviceNode)
|
||||
for cap, minor := range migCaps {
|
||||
capsDevicePaths[ProcPath(cap.ProcPath())] = DeviceNode{
|
||||
Path: DevicePath(minor.DevicePath()),
|
||||
Major: capDeviceMajor,
|
||||
Minor: int(minor),
|
||||
}
|
||||
}
|
||||
return capsDevicePaths
|
||||
}
|
||||
41
pkg/discover/discover_nvml_mig_test.go
Normal file
41
pkg/discover/discover_nvml_mig_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvcaps"
|
||||
)
|
||||
|
||||
func TestGetMigCaps(t *testing.T) {
|
||||
migMinors := map[nvcaps.MigCap]nvcaps.MigMinor{
|
||||
"config": 1,
|
||||
"monitor": 2,
|
||||
"gpu0/gi0/access": 3,
|
||||
"gpu0/gi0/ci0/access": 4,
|
||||
}
|
||||
|
||||
migCapMajor := 999
|
||||
migCaps := getMigCapsFromMigMinors(migMinors, migCapMajor)
|
||||
|
||||
require.Len(t, migCaps, len(migMinors))
|
||||
for _, c := range migCaps {
|
||||
require.Equal(t, migCapMajor, c.Major)
|
||||
}
|
||||
}
|
||||
219
pkg/discover/discover_nvml_test.go
Normal file
219
pkg/discover/discover_nvml_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvml"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/proc"
|
||||
)
|
||||
|
||||
const (
|
||||
nvidiaGPUDeviceMajorDefault = 195
|
||||
nvidiaCapsDeviceMajorDefault = 235
|
||||
)
|
||||
|
||||
func newTestDiscover() nvmlDiscover {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
nvml := nvml.NewMockNVMLServer(nvml.NewMockA100Device(0))
|
||||
|
||||
return nvmlDiscover{
|
||||
logger: logger,
|
||||
nvml: nvml,
|
||||
nvidiaDevices: proc.NewMockNvidiaDevices(
|
||||
proc.Device{Name: nvidiaGPUDeviceName, Major: nvidiaGPUDeviceMajorDefault},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func newTestDiscoverWithMIG() nvmlDiscover {
|
||||
device := &nvml.MockA100Device{
|
||||
Index: 0,
|
||||
MigMode: nvml.DEVICE_MIG_ENABLE,
|
||||
GpuInstances: make(map[*nvml.MockA100GpuInstance]struct{}),
|
||||
GpuInstanceCounter: 0,
|
||||
}
|
||||
// Create a single gi and ci on the device
|
||||
gpuInstanceProfileInfo := &nvml.GpuInstanceProfileInfo{
|
||||
Id: nvml.GPU_INSTANCE_PROFILE_7_SLICE,
|
||||
}
|
||||
gi, _ := device.CreateGpuInstance(gpuInstanceProfileInfo)
|
||||
|
||||
computeInstanceProfileInfo := &nvml.ComputeInstanceProfileInfo{
|
||||
Id: nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE,
|
||||
}
|
||||
_, _ = gi.CreateComputeInstance(computeInstanceProfileInfo)
|
||||
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
return nvmlDiscover{
|
||||
logger: logger,
|
||||
nvml: nvml.NewMockNVMLServer(device),
|
||||
migCaps: map[ProcPath]DeviceNode{
|
||||
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"): {
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap3"),
|
||||
Minor: 3,
|
||||
Major: 235,
|
||||
},
|
||||
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"): {
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap4"),
|
||||
Minor: 4,
|
||||
Major: 235,
|
||||
},
|
||||
ProcPath("/proc/driver/nvidia/capabilities/mig/config"): {
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap1"),
|
||||
Minor: 1,
|
||||
Major: 235,
|
||||
},
|
||||
ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"): {
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap2"),
|
||||
Minor: 2,
|
||||
Major: 235,
|
||||
},
|
||||
},
|
||||
nvidiaDevices: proc.NewMockNvidiaDevices(
|
||||
proc.Device{Name: nvidiaGPUDeviceName, Major: nvidiaGPUDeviceMajorDefault},
|
||||
proc.Device{Name: nvidiaCapsDeviceName, Major: nvidiaCapsDeviceMajorDefault},
|
||||
),
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestDiscoverNvmlDevices(t *testing.T) {
|
||||
d := newTestDiscover()
|
||||
devices, err := d.Devices()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, devices, 3)
|
||||
|
||||
device := devices[0]
|
||||
|
||||
require.Equal(t, "0", device.Index)
|
||||
require.Equal(t, "GPU-0", device.UUID)
|
||||
require.Equal(t, "0000FFFF:FF:FF.F", device.PCIBusID.String())
|
||||
|
||||
expectedDeviceNodes := []DeviceNode{
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia0"),
|
||||
Minor: 0,
|
||||
Major: nvidiaGPUDeviceMajorDefault,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedDeviceNodes, device.DeviceNodes)
|
||||
|
||||
expectedProcPaths := []ProcPath{ProcPath("/proc/driver/nvidia/gpus/ffff:ff:ff.f")}
|
||||
require.Equal(t, expectedProcPaths, device.ProcPaths)
|
||||
}
|
||||
|
||||
func TestDiscoverNvmlMigDevices(t *testing.T) {
|
||||
d := newTestDiscoverWithMIG()
|
||||
|
||||
devices, err := d.Devices()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, devices, 5)
|
||||
|
||||
mig := devices[0]
|
||||
|
||||
require.Equal(t, "0:0", mig.Index)
|
||||
require.Equal(t, "MIG-0", mig.UUID)
|
||||
require.Empty(t, mig.PCIBusID)
|
||||
|
||||
expectedDeviceNodes := []DeviceNode{
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia0"),
|
||||
Minor: 0,
|
||||
Major: nvidiaGPUDeviceMajorDefault,
|
||||
},
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap3"),
|
||||
Minor: 3,
|
||||
Major: nvidiaCapsDeviceMajorDefault,
|
||||
},
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap4"),
|
||||
Minor: 4,
|
||||
Major: nvidiaCapsDeviceMajorDefault,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedDeviceNodes, mig.DeviceNodes)
|
||||
|
||||
expectedProcPaths := []ProcPath{
|
||||
ProcPath("/proc/driver/nvidia/gpus/ffff:ff:ff.f"),
|
||||
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"),
|
||||
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"),
|
||||
}
|
||||
require.Equal(t, expectedProcPaths, mig.ProcPaths)
|
||||
|
||||
var config *Device
|
||||
var monitor *Device
|
||||
for i, d := range devices {
|
||||
if d.UUID == "CONFIG" {
|
||||
config = &devices[i]
|
||||
}
|
||||
if d.UUID == "MONITOR" {
|
||||
monitor = &devices[i]
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, config)
|
||||
require.NotNil(t, monitor)
|
||||
|
||||
require.Equal(t, "CONFIG", config.UUID)
|
||||
|
||||
expectedDeviceNodes = []DeviceNode{
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap1"),
|
||||
Minor: 1,
|
||||
Major: nvidiaCapsDeviceMajorDefault,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedDeviceNodes, config.DeviceNodes)
|
||||
|
||||
require.Contains(t, config.ProcPaths, ProcPath("/proc/driver/nvidia/capabilities/mig/config"))
|
||||
require.Len(t, config.ProcPaths, 2)
|
||||
|
||||
require.Equal(t, "MONITOR", monitor.UUID)
|
||||
|
||||
expectedDeviceNodes = []DeviceNode{
|
||||
{
|
||||
Path: DevicePath("/dev/nvidia-caps/nvidia-cap2"),
|
||||
Minor: 2,
|
||||
Major: nvidiaCapsDeviceMajorDefault,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedDeviceNodes, monitor.DeviceNodes)
|
||||
|
||||
require.Contains(t, monitor.ProcPaths, ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"))
|
||||
require.Len(t, monitor.ProcPaths, 2)
|
||||
|
||||
}
|
||||
|
||||
func TestPCIBusID(t *testing.T) {
|
||||
testCases := map[string]ProcPath{
|
||||
"0000FFFF:FF:FF.F": "/proc/driver/nvidia/gpus/ffff:ff:ff.f",
|
||||
"FFFFFFFF:FF:FF.F": "/proc/driver/nvidia/gpus/ffffffff:ff:ff.f",
|
||||
}
|
||||
|
||||
for busID, procPath := range testCases {
|
||||
p := PCIBusID(busID)
|
||||
require.Equal(t, busID, p.String())
|
||||
require.Equal(t, procPath, p.GetProcPath())
|
||||
}
|
||||
}
|
||||
71
pkg/discover/hooks.go
Normal file
71
pkg/discover/hooks.go
Normal file
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type hooks struct {
|
||||
None
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
var _ Discover = (*hooks)(nil)
|
||||
|
||||
// NewHooks creates a discoverer for linux containers
|
||||
func NewHooks() Discover {
|
||||
return NewHooksWithLogger(log.StandardLogger())
|
||||
}
|
||||
|
||||
// NewHooksWithLogger creates a discoverer as with NewHooks with the specified logger
|
||||
func NewHooksWithLogger(logger *log.Logger) Discover {
|
||||
h := hooks{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h hooks) Hooks() ([]Hook, error) {
|
||||
var hooks []Hook
|
||||
|
||||
hooks = append(hooks, newLdconfigHook())
|
||||
|
||||
return hooks, nil
|
||||
}
|
||||
|
||||
func newLdconfigHook() Hook {
|
||||
const rootPattern = "@Root.Path@"
|
||||
|
||||
h := Hook{
|
||||
Path: "/sbin/ldconfig",
|
||||
Args: []string{
|
||||
// TODO: Testing seems to indicate that this is -v flag is required
|
||||
"-v",
|
||||
"-r", rootPattern,
|
||||
},
|
||||
// TODO: CreateContainer hooks were only added to a later OCI spec version
|
||||
// We will have to find a way to deal with OCI versions before 1.0.2
|
||||
HookName: "create-container",
|
||||
Labels: map[string]string{
|
||||
"min-oci-version": "1.0.2",
|
||||
},
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
52
pkg/discover/ipc.go
Normal file
52
pkg/discover/ipc.go
Normal file
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
|
||||
)
|
||||
|
||||
// NewIPCMounts creates a discoverer for IPC sockets
|
||||
func NewIPCMounts(root string) Discover {
|
||||
return NewIPCMountsWithLogger(log.StandardLogger(), root)
|
||||
}
|
||||
|
||||
// NewIPCMountsWithLogger creates a discovered as with NewIPCMounts with the
|
||||
// specified logger.
|
||||
func NewIPCMountsWithLogger(logger *log.Logger, root string) Discover {
|
||||
d := mounts{
|
||||
logger: logger,
|
||||
lookup: lookup.NewFileLocatorWithLogger(logger, root),
|
||||
required: requiredIPCs,
|
||||
}
|
||||
|
||||
return &d
|
||||
}
|
||||
|
||||
var requiredIPCs = map[string][]string{
|
||||
"nvidia-persistenced": {
|
||||
"/var/run/nvidia-persistenced/socket",
|
||||
},
|
||||
"nvidia-fabricmanager": {
|
||||
"/var/run/nvidia-fabricmanager/socket",
|
||||
},
|
||||
// TODO: This can be controlled by the NV_MPS_PIPE_DIR envvar
|
||||
"nvidia-mps": {
|
||||
"/tmp/nvidia-mps",
|
||||
},
|
||||
}
|
||||
56
pkg/discover/ipc_test.go
Normal file
56
pkg/discover/ipc_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIPCDiscover(t *testing.T) {
|
||||
|
||||
ipcLookup := map[string]string{
|
||||
"/var/run/nvidia-persistenced/socket": "/var/run/nvidia-persistenced/socket",
|
||||
"/var/run/nvidia-fabricmanager/socket": "fm-socket",
|
||||
}
|
||||
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
d := NewIPCMountsWithLogger(logger, "")
|
||||
|
||||
// Override lookup for testing
|
||||
mockLocator := NewLocatorMockFromMap(ipcLookup)
|
||||
d.(*mounts).lookup = mockLocator
|
||||
|
||||
mounts, err := d.Mounts()
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, []Mount{
|
||||
{Path: "/var/run/nvidia-persistenced/socket", Labels: map[string]string{}},
|
||||
{Path: "fm-socket", Labels: map[string]string{}}}, mounts)
|
||||
|
||||
require.Len(t, mockLocator.LocateCalls(), 3)
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
100
pkg/discover/libraries.go
Normal file
100
pkg/discover/libraries.go
Normal file
@@ -0,0 +1,100 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
|
||||
)
|
||||
|
||||
// NewLibraries constructs discoverer for libraries
|
||||
func NewLibraries(root string) (Discover, error) {
|
||||
return NewLibrariesWithLogger(log.StandardLogger(), root)
|
||||
}
|
||||
|
||||
// NewLibrariesWithLogger constructs discoverer for libraries with the specified logger
|
||||
func NewLibrariesWithLogger(logger *log.Logger, root string) (Discover, error) {
|
||||
lookup, err := lookup.NewLibraryLocatorWithLogger(logger, root)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing locator: %v", err)
|
||||
}
|
||||
|
||||
d := mounts{
|
||||
logger: logger,
|
||||
lookup: lookup,
|
||||
required: requiredLibraries,
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
// requiredLibraries defines a set of libraries and their labels
|
||||
var requiredLibraries = map[string][]string{
|
||||
"utility": {
|
||||
"libnvidia-ml.so", /* Management library */
|
||||
"libnvidia-cfg.so", /* GPU configuration */
|
||||
},
|
||||
"compute": {
|
||||
"libcuda.so", /* CUDA driver library */
|
||||
"libnvidia-opencl.so", /* NVIDIA OpenCL ICD */
|
||||
"libnvidia-ptxjitcompiler.so", /* PTX-SASS JIT compiler (used by libcuda) */
|
||||
"libnvidia-fatbinaryloader.so", /* fatbin loader (used by libcuda) */
|
||||
"libnvidia-allocator.so", /* NVIDIA allocator runtime library */
|
||||
"libnvidia-compiler.so", /* NVVM-PTX compiler for OpenCL (used by libnvidia-opencl) */
|
||||
},
|
||||
"video": {
|
||||
"libvdpau_nvidia.so", /* NVIDIA VDPAU ICD */
|
||||
"libnvidia-encode.so", /* Video encoder */
|
||||
"libnvidia-opticalflow.so", /* NVIDIA Opticalflow library */
|
||||
"libnvcuvid.so", /* Video decoder */
|
||||
},
|
||||
"graphics": {
|
||||
//"libnvidia-egl-wayland.so", /* EGL wayland platform extension (used by libEGL_nvidia) */
|
||||
"libnvidia-eglcore.so", /* EGL core (used by libGLES*[_nvidia] and libEGL_nvidia) */
|
||||
"libnvidia-glcore.so", /* OpenGL core (used by libGL or libGLX_nvidia) */
|
||||
"libnvidia-tls.so", /* Thread local storage (used by libGL or libGLX_nvidia) */
|
||||
"libnvidia-glsi.so", /* OpenGL system interaction (used by libEGL_nvidia) */
|
||||
"libnvidia-fbc.so", /* Framebuffer capture */
|
||||
"libnvidia-ifr.so", /* OpenGL framebuffer capture */
|
||||
"libnvidia-rtcore.so", /* Optix */
|
||||
"libnvoptix.so", /* Optix */
|
||||
},
|
||||
"glvnd": {
|
||||
//"libGLX.so", /* GLX ICD loader */
|
||||
//"libOpenGL.so", /* OpenGL ICD loader */
|
||||
//"libGLdispatch.so", /* OpenGL dispatch (used by libOpenGL, libEGL and libGLES*) */
|
||||
"libGLX_nvidia.so", /* OpenGL/GLX ICD */
|
||||
"libEGL_nvidia.so", /* EGL ICD */
|
||||
"libGLESv2_nvidia.so", /* OpenGL ES v2 ICD */
|
||||
"libGLESv1_CM_nvidia.so", /* OpenGL ES v1 common profile ICD */
|
||||
"libnvidia-glvkspirv.so", /* SPIR-V Lib for Vulkan */
|
||||
"libnvidia-cbl.so", /* VK_NV_ray_tracing */
|
||||
},
|
||||
"compat32": {
|
||||
"libGL.so", /* OpenGL/GLX legacy _or_ compatibility wrapper (GLVND) */
|
||||
"libEGL.so", /* EGL legacy _or_ ICD loader (GLVND) */
|
||||
"libGLESv1_CM.so", /* OpenGL ES v1 common profile legacy _or_ ICD loader (GLVND) */
|
||||
"libGLESv2.so", /* OpenGL ES v2 legacy _or_ ICD loader (GLVND) */
|
||||
},
|
||||
"ngx": {
|
||||
"libnvidia-ngx.so", /* NGX library */
|
||||
},
|
||||
"dxcore": {
|
||||
"libdxcore.so", /* Core library for dxcore support */
|
||||
},
|
||||
}
|
||||
56
pkg/discover/libraries_test.go
Normal file
56
pkg/discover/libraries_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLibraries(t *testing.T) {
|
||||
|
||||
libraryLookup := map[string]string{
|
||||
"libcuda.so": "/lib/libcuda.so.999.99",
|
||||
"libversion.so": "/lib/libversion.so.111.11",
|
||||
"libother.so": "/lib/libother.so.999.99",
|
||||
}
|
||||
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
d, err := NewLibrariesWithLogger(logger, "")
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Override lookup for testing
|
||||
mockLocator := NewLocatorMockFromMap(libraryLookup)
|
||||
d.(*mounts).lookup = mockLocator
|
||||
|
||||
mounts, err := d.Mounts()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.ElementsMatch(t, []Mount{{Path: "/lib/libcuda.so.999.99", Labels: map[string]string{}}}, mounts)
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
125
pkg/discover/mounts.go
Normal file
125
pkg/discover/mounts.go
Normal file
@@ -0,0 +1,125 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
|
||||
)
|
||||
|
||||
const (
|
||||
capabilityLabel = "capability"
|
||||
versionLabel = "version"
|
||||
)
|
||||
|
||||
// mounts is a generic discoverer for Mounts. It is customized by specifying the
|
||||
// required entities as a key-value pair as well as a Locator that is used to
|
||||
// identify the mounts that are to be included.
|
||||
type mounts struct {
|
||||
None
|
||||
logger *log.Logger
|
||||
lookup lookup.Locator
|
||||
required map[string][]string
|
||||
}
|
||||
|
||||
var _ Discover = (*mounts)(nil)
|
||||
|
||||
func (d mounts) Mounts() ([]Mount, error) {
|
||||
mounts, err := d.uniqueMounts()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering mounts: %v", err)
|
||||
}
|
||||
|
||||
return mounts.Slice(), nil
|
||||
}
|
||||
|
||||
func (d mounts) uniqueMounts() (mountsByPath, error) {
|
||||
if d.lookup == nil {
|
||||
return nil, fmt.Errorf("no lookup defined")
|
||||
}
|
||||
|
||||
mounts := make(mountsByPath)
|
||||
|
||||
for id, keys := range d.required {
|
||||
for _, key := range keys {
|
||||
d.logger.Debugf("Locating %v [%v]", key, id)
|
||||
located, err := d.lookup.Locate(key)
|
||||
if err != nil {
|
||||
d.logger.Warnf("Could not locate %v [%v]: %v", key, id, err)
|
||||
continue
|
||||
}
|
||||
d.logger.Infof("Located %v [%v]: %v", key, id, located)
|
||||
for _, p := range located {
|
||||
// TODO: We need to add labels
|
||||
mount := newMount(p)
|
||||
mounts.Put(mount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mounts, nil
|
||||
}
|
||||
|
||||
type mountsByPath map[string]Mount
|
||||
|
||||
func (m mountsByPath) Slice() []Mount {
|
||||
var mounts []Mount
|
||||
for _, mount := range m {
|
||||
mounts = append(mounts, mount)
|
||||
}
|
||||
|
||||
return mounts
|
||||
}
|
||||
|
||||
func (m *mountsByPath) Put(value Mount) {
|
||||
key := value.Path
|
||||
mount, exists := (*m)[key]
|
||||
if !exists {
|
||||
(*m)[key] = value
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range value.Labels {
|
||||
mount.Labels[k] = v
|
||||
}
|
||||
(*m)[key] = mount
|
||||
}
|
||||
|
||||
// NewMountForCapability creates a mount with the specified capability label
|
||||
func NewMountForCapability(path string, capability string) Mount {
|
||||
return newMount(path, capabilityLabel, capability)
|
||||
}
|
||||
|
||||
// NewMountForVersion creates a mount with the specified version label
|
||||
func NewMountForVersion(path string, version string) Mount {
|
||||
return newMount(path, versionLabel, version)
|
||||
}
|
||||
|
||||
func newMount(path string, labels ...string) Mount {
|
||||
l := make(map[string]string)
|
||||
|
||||
for i := 0; i < len(labels)-1; i += 2 {
|
||||
l[labels[i]] = labels[i+1]
|
||||
}
|
||||
|
||||
return Mount{
|
||||
Path: path,
|
||||
Labels: l,
|
||||
}
|
||||
}
|
||||
53
pkg/discover/mounts_test.go
Normal file
53
pkg/discover/mounts_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
|
||||
)
|
||||
|
||||
func TestMountsReturnsErrorForNoLookup(t *testing.T) {
|
||||
d := mounts{}
|
||||
mounts, err := d.Mounts()
|
||||
|
||||
require.Error(t, err)
|
||||
require.Len(t, mounts, 0)
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
|
||||
func NewLocatorMockFromMap(lookupMap map[string]string) *lookup.LocatorMock {
|
||||
return &lookup.LocatorMock{
|
||||
LocateFunc: func(key string) ([]string, error) {
|
||||
value, exists := lookupMap[key]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("key %v not found", key)
|
||||
}
|
||||
return []string{value}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
38
pkg/discover/none.go
Normal file
38
pkg/discover/none.go
Normal file
@@ -0,0 +1,38 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
// None is a null discoverer that returns an empty list of devices and
|
||||
// mounts.
|
||||
type None struct{}
|
||||
|
||||
var _ Discover = (*None)(nil)
|
||||
|
||||
// Devices returns an empty list of devices
|
||||
func (e None) Devices() ([]Device, error) {
|
||||
return []Device{}, nil
|
||||
}
|
||||
|
||||
// Mounts returns an empty list of mounts
|
||||
func (e None) Mounts() ([]Mount, error) {
|
||||
return []Mount{}, nil
|
||||
}
|
||||
|
||||
// Hooks returns an empty list of hooks
|
||||
func (e None) Hooks() ([]Hook, error) {
|
||||
return []Hook{}, nil
|
||||
}
|
||||
39
pkg/discover/none_test.go
Normal file
39
pkg/discover/none_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNone(t *testing.T) {
|
||||
d := None{}
|
||||
|
||||
devices, err := d.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, devices)
|
||||
|
||||
mounts, err := d.Mounts()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, mounts)
|
||||
|
||||
hooks, err := d.Hooks()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, hooks)
|
||||
}
|
||||
71
pkg/discover/nvml_server.go
Normal file
71
pkg/discover/nvml_server.go
Normal file
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvml"
|
||||
)
|
||||
|
||||
type nvmlServer struct {
|
||||
logger *log.Logger
|
||||
composite
|
||||
}
|
||||
|
||||
var _ Discover = (*nvmlServer)(nil)
|
||||
|
||||
// NewNVMLServer constructs a discoverer for server systems using NVML to discover devices
|
||||
func NewNVMLServer(root string) (Discover, error) {
|
||||
return NewNVMLServerWithLogger(log.StandardLogger(), root)
|
||||
}
|
||||
|
||||
// NewNVMLServerWithLogger constructs a discoverer for server systems using NVML to discover devices with
|
||||
// the specified logger
|
||||
func NewNVMLServerWithLogger(logger *log.Logger, root string) (Discover, error) {
|
||||
return createNVMLServer(logger, nvml.New(), root)
|
||||
}
|
||||
|
||||
func createNVMLServer(logger *log.Logger, nvml nvml.Interface, root string) (Discover, error) {
|
||||
d := nvmlServer{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
devices, err := NewNVMLDiscoverWithLogger(logger, nvml)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing NVML device discoverer: %v", err)
|
||||
}
|
||||
|
||||
libraries, err := NewLibrariesWithLogger(logger, root)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing library discoverer: %v", err)
|
||||
}
|
||||
|
||||
d.add(
|
||||
// Device discovery
|
||||
devices,
|
||||
// Mounts discovery
|
||||
libraries,
|
||||
NewBinaryMountsWithLogger(logger, root),
|
||||
NewIPCMountsWithLogger(logger, root),
|
||||
// Hook discovery
|
||||
NewHooksWithLogger(logger),
|
||||
)
|
||||
|
||||
return &d, nil
|
||||
}
|
||||
66
pkg/discover/nvml_server_test.go
Normal file
66
pkg/discover/nvml_server_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvml"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/proc"
|
||||
)
|
||||
|
||||
const (
|
||||
testMajor = 999
|
||||
)
|
||||
|
||||
func TestNVMLServerConstructor(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
nvml := nvml.NewMockNVMLOnLunaServer()
|
||||
|
||||
d, err := createNVMLServer(logger, nvml, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
instance := d.(*nvmlServer)
|
||||
require.Len(t, instance.discoverers, 1+3+1)
|
||||
|
||||
// We need to override the nvidiaDevices member of the nvmlDiscovery
|
||||
// TODO: Use a mock instead, or allow for injection into a constructor
|
||||
instance.discoverers[0].(*nvmlDiscover).nvidiaDevices = &mockNvidiaDevices{}
|
||||
|
||||
devices, err := d.Devices()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, devices)
|
||||
|
||||
_, err = d.Mounts()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
type mockNvidiaDevices struct{}
|
||||
|
||||
var _ proc.NvidiaDevices = (*mockNvidiaDevices)(nil)
|
||||
|
||||
func (d mockNvidiaDevices) Get(name string) (proc.Device, bool) {
|
||||
return proc.Device{Name: name, Major: testMajor}, true
|
||||
}
|
||||
|
||||
func (d mockNvidiaDevices) Exists(string) bool {
|
||||
return false
|
||||
}
|
||||
129
pkg/ensure/devices.go
Normal file
129
pkg/ensure/devices.go
Normal file
@@ -0,0 +1,129 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package ensure
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
type ensureDevices struct {
|
||||
logger *log.Logger
|
||||
discover.Discover
|
||||
lookup lookup.Locator
|
||||
root string
|
||||
}
|
||||
|
||||
// NewEnsureDevices creates a discoverer that wraps the specified discoverer and ensures that the
|
||||
// device nodes for the discoverer are created. If a root is specified, the device nodes
|
||||
// rooted there are also created.
|
||||
func NewEnsureDevices(d discover.Discover, root string) discover.Discover {
|
||||
return NewEnsureDevicesWithLogger(log.StandardLogger(), d, root)
|
||||
}
|
||||
|
||||
// NewEnsureDevicesWithLogger creates a discoverer that wraps the specified discoverer and ensures that the
|
||||
// device nodes for the discoverer are created. If a root is specified, the device nodes
|
||||
// rooted there are also created. The specified logger is used.
|
||||
func NewEnsureDevicesWithLogger(logger *log.Logger, d discover.Discover, root string) discover.Discover {
|
||||
e := ensureDevices{
|
||||
Discover: d,
|
||||
logger: logger,
|
||||
lookup: lookup.NewPathLocatorWithLogger(logger, root),
|
||||
root: root,
|
||||
}
|
||||
return &e
|
||||
}
|
||||
|
||||
func (d ensureDevices) Devices() ([]discover.Device, error) {
|
||||
devices, err := d.Discover.Devices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering devices: %v", err)
|
||||
}
|
||||
for _, di := range devices {
|
||||
for _, dn := range di.DeviceNodes {
|
||||
d.deviceNode(dn)
|
||||
}
|
||||
}
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func (d ensureDevices) deviceNode(dn discover.DeviceNode) error {
|
||||
err := d.device(dn.Path, dn.Major, dn.Minor)
|
||||
if err != nil {
|
||||
d.logger.Errorf("Error creating device node %+v: %v", dn, err)
|
||||
}
|
||||
|
||||
if d.root != "" && d.root != "/" {
|
||||
rootedPath := discover.DevicePath(filepath.Join(d.root, string(dn.Path)))
|
||||
err = d.device(rootedPath, dn.Major, dn.Minor)
|
||||
if err != nil {
|
||||
d.logger.Errorf("Error creating device node %+v: %v", dn, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d ensureDevices) device(path discover.DevicePath, major int, minor int) error {
|
||||
// TODO: We may want to check that the device node has the required permissions
|
||||
_, err := os.Stat(string(path))
|
||||
if err == nil {
|
||||
d.logger.Infof("Device node %v already exists", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
d.logger.Errorf("Error getting info for device node %v: %v", path, err)
|
||||
return fmt.Errorf("error getting device node info: %w", err)
|
||||
}
|
||||
// See: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#runfile-verifications
|
||||
|
||||
// TODO: We should use nvidia-modprobe or a tool based off that instead
|
||||
args := []string{
|
||||
"-m", "666",
|
||||
string(path), "c",
|
||||
fmt.Sprint(major),
|
||||
fmt.Sprint(minor),
|
||||
}
|
||||
|
||||
return d.run("mknod", args...)
|
||||
}
|
||||
|
||||
func (d ensureDevices) run(cmd string, args ...string) error {
|
||||
paths, err := d.lookup.Locate(cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error finding command %v: %v", cmd, err)
|
||||
}
|
||||
if len(paths) == 0 {
|
||||
return fmt.Errorf("command %v not found in path", cmd)
|
||||
}
|
||||
path := paths[0]
|
||||
|
||||
d.logger.Debugf("Running %v", append([]string{path}, args...))
|
||||
err = exec.Command(path, args...).Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error running %v: %v", append([]string{path}, args...), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
22
pkg/ensure/ensure.go
Normal file
22
pkg/ensure/ensure.go
Normal file
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package ensure
|
||||
|
||||
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
|
||||
// Ensure is an alias for Discover
|
||||
type Ensure discover.Discover
|
||||
41
pkg/filter/all.go
Normal file
41
pkg/filter/all.go
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
|
||||
type all struct {
|
||||
selectors []Selector
|
||||
}
|
||||
|
||||
// All returns a selector that evaluates true if EACH of the specified selectors
|
||||
// are selected.
|
||||
func All(selectors ...Selector) Selector {
|
||||
s := all{
|
||||
selectors: selectors,
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s all) Selected(device discover.Device) bool {
|
||||
for _, si := range s.selectors {
|
||||
if !si.Selected(device) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
76
pkg/filter/all_test.go
Normal file
76
pkg/filter/all_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
True := &SelectorMock{
|
||||
SelectedFunc: func(discover.Device) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
False := &SelectorMock{
|
||||
SelectedFunc: func(discover.Device) bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
||||
d := discover.Device{}
|
||||
|
||||
// Ensure that the mocks are set up correctly:
|
||||
require.True(t, True.Selected(d))
|
||||
require.False(t, False.Selected(d))
|
||||
|
||||
emtpy := All()
|
||||
require.True(t, emtpy.Selected(d))
|
||||
|
||||
s00 := All(False, False)
|
||||
require.False(t, s00.Selected(d))
|
||||
|
||||
s01 := All(False, True)
|
||||
require.False(t, s01.Selected(d))
|
||||
|
||||
s10 := All(True, False)
|
||||
require.False(t, s10.Selected(d))
|
||||
|
||||
s11 := All(True, True)
|
||||
require.True(t, s11.Selected(d))
|
||||
}
|
||||
|
||||
type discoverMock struct {
|
||||
discover.None
|
||||
devices []discover.Device
|
||||
devicesError error
|
||||
mounts []discover.Mount
|
||||
mountsError error
|
||||
}
|
||||
|
||||
var _ discover.Discover = (*discoverMock)(nil)
|
||||
|
||||
func (m discoverMock) Devices() ([]discover.Device, error) {
|
||||
return m.devices, m.devicesError
|
||||
}
|
||||
|
||||
func (m discoverMock) Mounts() ([]discover.Mount, error) {
|
||||
return m.mounts, m.mountsError
|
||||
}
|
||||
41
pkg/filter/any.go
Normal file
41
pkg/filter/any.go
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
|
||||
type any struct {
|
||||
selectors []Selector
|
||||
}
|
||||
|
||||
// Any returns a selector that evaluates true if ANY of the specified selectors
|
||||
// are selected
|
||||
func Any(selectors ...Selector) Selector {
|
||||
s := any{
|
||||
selectors: selectors,
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s any) Selected(device discover.Device) bool {
|
||||
for _, si := range s.selectors {
|
||||
if si.Selected(device) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
58
pkg/filter/any_test.go
Normal file
58
pkg/filter/any_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. Any rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
func TestAny(t *testing.T) {
|
||||
True := &SelectorMock{
|
||||
SelectedFunc: func(discover.Device) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
False := &SelectorMock{
|
||||
SelectedFunc: func(discover.Device) bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
||||
d := discover.Device{}
|
||||
|
||||
// Ensure that the mocks are set up correctly:
|
||||
require.True(t, True.Selected(d))
|
||||
require.False(t, False.Selected(d))
|
||||
|
||||
emtpy := Any()
|
||||
require.False(t, emtpy.Selected(d))
|
||||
|
||||
s00 := Any(False, False)
|
||||
require.False(t, s00.Selected(d))
|
||||
|
||||
s01 := Any(False, True)
|
||||
require.True(t, s01.Selected(d))
|
||||
|
||||
s10 := Any(True, False)
|
||||
require.True(t, s10.Selected(d))
|
||||
|
||||
s11 := Any(True, True)
|
||||
require.True(t, s11.Selected(d))
|
||||
}
|
||||
60
pkg/filter/by_id.go
Normal file
60
pkg/filter/by_id.go
Normal file
@@ -0,0 +1,60 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
type devicesByID map[string]struct{}
|
||||
|
||||
var _ Selector = (*devicesByID)(nil)
|
||||
|
||||
// NewDeviceSelector creates a selector for devices based on the specified IDs.
|
||||
func NewDeviceSelector(ids ...string) Selector {
|
||||
deviceIDs := make(devicesByID)
|
||||
|
||||
for _, id := range ids {
|
||||
deviceIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
return deviceIDs
|
||||
}
|
||||
|
||||
// Selected checks whether a specific device is included in the set of devicesIDs
|
||||
// The device is checked by UUID, Index, and PCIBusID and if any of these match
|
||||
// the device is considered selected.
|
||||
func (d devicesByID) Selected(device discover.Device) bool {
|
||||
var exists bool
|
||||
|
||||
_, exists = d[device.UUID]
|
||||
if exists {
|
||||
return true
|
||||
}
|
||||
|
||||
_, exists = d[device.Index]
|
||||
if exists {
|
||||
return true
|
||||
}
|
||||
|
||||
_, exists = d[device.PCIBusID.String()]
|
||||
if exists {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
47
pkg/filter/by_id_test.go
Normal file
47
pkg/filter/by_id_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
func TestDeviceByID(t *testing.T) {
|
||||
device := discover.Device{
|
||||
Index: "index",
|
||||
UUID: "uuid",
|
||||
PCIBusID: discover.PCIBusID("pcibusid"),
|
||||
}
|
||||
|
||||
require.False(t, NewDeviceSelector().Selected(device))
|
||||
require.False(t, NewDeviceSelector("notindex", "notuuid", "notpcibusid").Selected(device))
|
||||
|
||||
require.True(t, NewDeviceSelector("index").Selected(device))
|
||||
require.True(t, NewDeviceSelector("notindex", "index").Selected(device))
|
||||
|
||||
require.True(t, NewDeviceSelector("uuid").Selected(device))
|
||||
require.True(t, NewDeviceSelector("notuuid", "uuid").Selected(device))
|
||||
|
||||
require.True(t, NewDeviceSelector("pcibusid").Selected(device))
|
||||
require.True(t, NewDeviceSelector("notpcibusid", "pcibusid").Selected(device))
|
||||
|
||||
require.True(t, NewDeviceSelector("index", "uuid", "pcibusid").Selected(device))
|
||||
|
||||
}
|
||||
90
pkg/filter/devices.go
Normal file
90
pkg/filter/devices.go
Normal file
@@ -0,0 +1,90 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
const (
|
||||
visibleDevicesAll = "all"
|
||||
visibleDevicesNone = "none"
|
||||
visibleDevicesVoid = "void"
|
||||
)
|
||||
|
||||
type devices struct {
|
||||
discover.Discover
|
||||
logger *log.Logger
|
||||
selector Selector
|
||||
}
|
||||
|
||||
var _ discover.Discover = (*devices)(nil)
|
||||
|
||||
// NewSelectDevicesFrom creates a filter that selects devices based on the value of the
|
||||
// visible devices string.
|
||||
func NewSelectDevicesFrom(d discover.Discover, visibleDevices string, env EnvLookup) discover.Discover {
|
||||
return NewSelectDevicesFromWithLogger(log.StandardLogger(), d, visibleDevices, env)
|
||||
}
|
||||
|
||||
// NewSelectDevicesFromWithLogger creates a filter as for NewSelectDevicesFrom with the
|
||||
// specified logger.
|
||||
func NewSelectDevicesFromWithLogger(logger *log.Logger, d discover.Discover, visibleDevices string, env EnvLookup) discover.Discover {
|
||||
if visibleDevices == "" || visibleDevices == visibleDevicesNone || visibleDevices == visibleDevicesVoid {
|
||||
return &discover.None{}
|
||||
}
|
||||
|
||||
var visibleDeviceSelector Selector
|
||||
if visibleDevices == visibleDevicesAll {
|
||||
visibleDeviceSelector = StandardDevice()
|
||||
} else {
|
||||
visibleDeviceSelector = All(StandardDevice(), NewDeviceSelector(strings.Split(visibleDevices, ",")...))
|
||||
}
|
||||
|
||||
controlDeviceIds := getControlDeviceIDsFromEnvWithLogger(logger, env)
|
||||
controlDeviceSelector := All(ControlDevice(), NewDeviceSelector(controlDeviceIds...))
|
||||
|
||||
vd := devices{
|
||||
Discover: d,
|
||||
logger: logger,
|
||||
selector: Any(visibleDeviceSelector, controlDeviceSelector),
|
||||
}
|
||||
|
||||
return &vd
|
||||
}
|
||||
|
||||
// Devices returns the list of selected devices after filtering based on the
|
||||
// configured selector
|
||||
func (d devices) Devices() ([]discover.Device, error) {
|
||||
devices, err := d.Discover.Devices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering devices: %v", err)
|
||||
}
|
||||
|
||||
var selected []discover.Device
|
||||
for _, di := range devices {
|
||||
if d.selector.Selected(di) {
|
||||
d.logger.Infof("selecting device=%v", di)
|
||||
selected = append(selected, di)
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
104
pkg/filter/devices_test.go
Normal file
104
pkg/filter/devices_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
func TestConstructor(t *testing.T) {
|
||||
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
device0 := discover.Device{
|
||||
Index: "0",
|
||||
UUID: "0",
|
||||
PCIBusID: discover.PCIBusID("0"),
|
||||
}
|
||||
device1 := discover.Device{
|
||||
Index: "1",
|
||||
UUID: "1",
|
||||
PCIBusID: discover.PCIBusID("1"),
|
||||
}
|
||||
device2 := discover.Device{
|
||||
Index: "2",
|
||||
UUID: "2",
|
||||
PCIBusID: discover.PCIBusID("2"),
|
||||
}
|
||||
device3 := discover.Device{
|
||||
Index: "3",
|
||||
UUID: "3",
|
||||
PCIBusID: discover.PCIBusID("3"),
|
||||
}
|
||||
controlDevice := discover.Device{
|
||||
UUID: "CONTROL",
|
||||
}
|
||||
mockDevices := []discover.Device{
|
||||
device0,
|
||||
device1,
|
||||
device2,
|
||||
device3,
|
||||
controlDevice,
|
||||
}
|
||||
d := discoverMock{
|
||||
devices: mockDevices,
|
||||
}
|
||||
|
||||
var ok bool
|
||||
|
||||
withDefaultLogger, ok := NewSelectDevicesFrom(d, "all", nil).(*devices)
|
||||
require.True(t, ok)
|
||||
require.Same(t, log.StandardLogger(), withDefaultLogger.logger)
|
||||
|
||||
_, ok = NewSelectDevicesFromWithLogger(logger, d, "", nil).(*discover.None)
|
||||
require.True(t, ok)
|
||||
|
||||
_, ok = NewSelectDevicesFromWithLogger(logger, d, "void", nil).(*discover.None)
|
||||
require.True(t, ok)
|
||||
|
||||
_, ok = NewSelectDevicesFromWithLogger(logger, d, "none", nil).(*discover.None)
|
||||
require.True(t, ok)
|
||||
|
||||
all, ok := NewSelectDevicesFromWithLogger(logger, d, "all", nil).(*devices)
|
||||
require.True(t, ok)
|
||||
|
||||
devs, err := all.Devices()
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, mockDevices, devs)
|
||||
|
||||
f, ok := NewSelectDevicesFromWithLogger(logger, d, "0", nil).(*devices)
|
||||
require.True(t, ok)
|
||||
|
||||
devs, err = f.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, devs, 2)
|
||||
require.ElementsMatch(t, devs, []discover.Device{device0, controlDevice})
|
||||
|
||||
f, ok = NewSelectDevicesFromWithLogger(logger, d, "0,2", nil).(*devices)
|
||||
require.True(t, ok)
|
||||
|
||||
devs, err = f.Devices()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, devs, 3)
|
||||
|
||||
require.ElementsMatch(t, devs, []discover.Device{device0, device2, controlDevice})
|
||||
}
|
||||
26
pkg/filter/filter.go
Normal file
26
pkg/filter/filter.go
Normal file
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
//go:generate moq -stub -out filter_mock.go . EnvLookup
|
||||
|
||||
// EnvLookup defines an interface that supports the LookupEnv function for getting
|
||||
// environment variable values.
|
||||
// TODO: This belongs in a different package
|
||||
type EnvLookup interface {
|
||||
LookupEnv(string) (string, bool)
|
||||
}
|
||||
77
pkg/filter/filter_mock.go
Normal file
77
pkg/filter/filter_mock.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Ensure, that EnvLookupMock does implement EnvLookup.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ EnvLookup = &EnvLookupMock{}
|
||||
|
||||
// EnvLookupMock is a mock implementation of EnvLookup.
|
||||
//
|
||||
// func TestSomethingThatUsesEnvLookup(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked EnvLookup
|
||||
// mockedEnvLookup := &EnvLookupMock{
|
||||
// LookupEnvFunc: func(s string) (string, bool) {
|
||||
// panic("mock out the LookupEnv method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedEnvLookup in code that requires EnvLookup
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type EnvLookupMock struct {
|
||||
// LookupEnvFunc mocks the LookupEnv method.
|
||||
LookupEnvFunc func(s string) (string, bool)
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// LookupEnv holds details about calls to the LookupEnv method.
|
||||
LookupEnv []struct {
|
||||
// S is the s argument value.
|
||||
S string
|
||||
}
|
||||
}
|
||||
lockLookupEnv sync.RWMutex
|
||||
}
|
||||
|
||||
// LookupEnv calls LookupEnvFunc.
|
||||
func (mock *EnvLookupMock) LookupEnv(s string) (string, bool) {
|
||||
callInfo := struct {
|
||||
S string
|
||||
}{
|
||||
S: s,
|
||||
}
|
||||
mock.lockLookupEnv.Lock()
|
||||
mock.calls.LookupEnv = append(mock.calls.LookupEnv, callInfo)
|
||||
mock.lockLookupEnv.Unlock()
|
||||
if mock.LookupEnvFunc == nil {
|
||||
var (
|
||||
sOut string
|
||||
bOut bool
|
||||
)
|
||||
return sOut, bOut
|
||||
}
|
||||
return mock.LookupEnvFunc(s)
|
||||
}
|
||||
|
||||
// LookupEnvCalls gets all the calls that were made to LookupEnv.
|
||||
// Check the length with:
|
||||
// len(mockedEnvLookup.LookupEnvCalls())
|
||||
func (mock *EnvLookupMock) LookupEnvCalls() []struct {
|
||||
S string
|
||||
} {
|
||||
var calls []struct {
|
||||
S string
|
||||
}
|
||||
mock.lockLookupEnv.RLock()
|
||||
calls = mock.calls.LookupEnv
|
||||
mock.lockLookupEnv.RUnlock()
|
||||
return calls
|
||||
}
|
||||
107
pkg/filter/is_control_device.go
Normal file
107
pkg/filter/is_control_device.go
Normal file
@@ -0,0 +1,107 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
const (
|
||||
devicesAll = "all"
|
||||
)
|
||||
|
||||
type controlDevices struct {
|
||||
discover.Discover
|
||||
logger *log.Logger
|
||||
selector Selector
|
||||
}
|
||||
|
||||
var _ discover.Discover = (*controlDevices)(nil)
|
||||
|
||||
// NewControlDevicesFrom creates a filter that selects devices based on the value of the
|
||||
// visible devices string.
|
||||
func NewControlDevicesFrom(d discover.Discover, env EnvLookup) Selector {
|
||||
return NewControlDevicesFromWithLogger(log.StandardLogger(), d, env)
|
||||
}
|
||||
|
||||
// NewControlDevicesFromWithLogger creates a filter as for NewControlDevicesFrom with the
|
||||
// specified logger.
|
||||
func NewControlDevicesFromWithLogger(logger *log.Logger, d discover.Discover, env EnvLookup) Selector {
|
||||
controlDevices := getControlDeviceIDsFromEnvWithLogger(logger, env)
|
||||
return NewDeviceSelector(controlDevices...)
|
||||
}
|
||||
|
||||
type controlDevice struct{}
|
||||
|
||||
// ControlDevice returns a selector for control devices
|
||||
func ControlDevice() Selector {
|
||||
return &controlDevice{}
|
||||
}
|
||||
|
||||
// Selected returns true for a controll devices and false for standard devices. A control device
|
||||
// has an empty index and PCI bus ID and a non-empty UUID.
|
||||
func (s controlDevice) Selected(device discover.Device) bool {
|
||||
if device.Index != "" {
|
||||
return false
|
||||
}
|
||||
if device.PCIBusID != "" {
|
||||
return false
|
||||
}
|
||||
if device.UUID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func getControlDeviceIDsFromEnvWithLogger(logger *log.Logger, env EnvLookup) []string {
|
||||
controlDevices := []string{discover.ControlDeviceUUID}
|
||||
|
||||
migControlDevices := getMIGControlDevicesFromEnvWithLogger(logger, env)
|
||||
|
||||
return append(controlDevices, migControlDevices...)
|
||||
}
|
||||
|
||||
func getMIGControlDevicesFromEnvWithLogger(logger *log.Logger, env EnvLookup) []string {
|
||||
if env == nil {
|
||||
logger.Debugf("Environment not specified; no MIG Control devices selected")
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var controlDevices []string
|
||||
|
||||
// Add MIG control devices
|
||||
migEnvUUIDMap := map[string]string{
|
||||
discover.MIGConfigDeviceUUID: "NVIDIA_MIG_CONFIG_DEVICES",
|
||||
discover.MIGMonitorDeviceUUID: "NVIDIA_MIG_MONITOR_DEVICES",
|
||||
}
|
||||
for uuid, migEnv := range migEnvUUIDMap {
|
||||
config, exists := env.LookupEnv(migEnv)
|
||||
if !exists {
|
||||
logger.Debugf("Envvar %v not set", migEnv)
|
||||
continue
|
||||
}
|
||||
if config == devicesAll {
|
||||
logger.Infof("Found %v=%v; selecting MIG %v devices", migEnv, config, uuid)
|
||||
controlDevices = append(controlDevices, uuid)
|
||||
} else {
|
||||
logger.Debugf("Found %v=%v; Skipping MIG %v devices (%v != %v)", migEnv, config, uuid, config, devicesAll)
|
||||
}
|
||||
}
|
||||
return controlDevices
|
||||
}
|
||||
123
pkg/filter/is_control_device_test.go
Normal file
123
pkg/filter/is_control_device_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
func TestControlDevice(t *testing.T) {
|
||||
control := ControlDevice()
|
||||
|
||||
pcibusID := discover.PCIBusID("pcibusid")
|
||||
device := discover.Device{
|
||||
Index: "index",
|
||||
UUID: "uuid",
|
||||
PCIBusID: pcibusID,
|
||||
}
|
||||
require.False(t, control.Selected(device))
|
||||
|
||||
require.False(t, control.Selected(
|
||||
discover.Device{UUID: "uuid", PCIBusID: pcibusID},
|
||||
))
|
||||
|
||||
require.False(t, control.Selected(
|
||||
discover.Device{Index: "index", PCIBusID: pcibusID},
|
||||
))
|
||||
|
||||
require.False(t, control.Selected(
|
||||
discover.Device{Index: "index", UUID: "uuid"},
|
||||
))
|
||||
|
||||
require.False(t, control.Selected(discover.Device{}))
|
||||
|
||||
require.True(t, control.Selected(discover.Device{UUID: "uuid"}))
|
||||
}
|
||||
|
||||
func TestGetControlDevicesFromEnv(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
env map[string]string
|
||||
expectedIds []string
|
||||
}{
|
||||
{
|
||||
name: "nil environment",
|
||||
env: nil,
|
||||
expectedIds: []string{"CONTROL"},
|
||||
},
|
||||
{
|
||||
name: "empty environment",
|
||||
env: map[string]string{},
|
||||
expectedIds: []string{"CONTROL"},
|
||||
},
|
||||
{
|
||||
name: "MIG monitor blank",
|
||||
env: map[string]string{"NVIDIA_MIG_MONITOR_DEVICES": ""},
|
||||
expectedIds: []string{"CONTROL"},
|
||||
},
|
||||
{
|
||||
name: "MIG config blank",
|
||||
env: map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": ""},
|
||||
expectedIds: []string{"CONTROL"},
|
||||
},
|
||||
{
|
||||
name: "MIG monitor not all",
|
||||
env: map[string]string{"NVIDIA_MIG_MONITOR_DEVICES": "not-all"},
|
||||
expectedIds: []string{"CONTROL"},
|
||||
},
|
||||
{
|
||||
name: "MIG config not all",
|
||||
env: map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": "not-all"},
|
||||
expectedIds: []string{"CONTROL"},
|
||||
},
|
||||
{
|
||||
name: "MIG monitor all",
|
||||
env: map[string]string{"NVIDIA_MIG_MONITOR_DEVICES": "all"},
|
||||
expectedIds: []string{"CONTROL", "MONITOR"},
|
||||
},
|
||||
{
|
||||
name: "MIG config all",
|
||||
env: map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": "all"},
|
||||
expectedIds: []string{"CONTROL", "CONFIG"},
|
||||
},
|
||||
{
|
||||
name: "MIG config and monitor all",
|
||||
env: map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": "all", "NVIDIA_MIG_MONITOR_DEVICES": "all"},
|
||||
expectedIds: []string{"CONTROL", "CONFIG", "MONITOR"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
t.Run(fmt.Sprintf("%d: %s", i, tc.name), func(t *testing.T) {
|
||||
deviceIDs := getControlDeviceIDsFromEnvWithLogger(logger, &EnvLookupMock{
|
||||
LookupEnvFunc: func(s string) (string, bool) {
|
||||
value, exists := tc.env[s]
|
||||
return value, exists
|
||||
},
|
||||
})
|
||||
require.ElementsMatch(t, tc.expectedIds, deviceIDs)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
41
pkg/filter/is_standard_device.go
Normal file
41
pkg/filter/is_standard_device.go
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
|
||||
type standardDevice struct{}
|
||||
|
||||
// StandardDevice returns a selector for regular (non-control) devices
|
||||
func StandardDevice() Selector {
|
||||
return &standardDevice{}
|
||||
}
|
||||
|
||||
// Selected returns true for a standard device and false for controll devices. A regular device
|
||||
// is expected to have an index, uuid, and PCI bus ID.
|
||||
func (s standardDevice) Selected(device discover.Device) bool {
|
||||
if device.Index == "" {
|
||||
return false
|
||||
}
|
||||
if device.PCIBusID == "" {
|
||||
return false
|
||||
}
|
||||
if device.UUID == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
52
pkg/filter/is_standard_device_test.go
Normal file
52
pkg/filter/is_standard_device_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
)
|
||||
|
||||
func TestStandardDevice(t *testing.T) {
|
||||
standard := StandardDevice()
|
||||
|
||||
pcibusID := discover.PCIBusID("pcibusid")
|
||||
|
||||
device := discover.Device{
|
||||
Index: "index",
|
||||
UUID: "uuid",
|
||||
PCIBusID: pcibusID,
|
||||
}
|
||||
require.True(t, standard.Selected(device))
|
||||
|
||||
require.False(t, standard.Selected(
|
||||
discover.Device{UUID: "uuid", PCIBusID: pcibusID},
|
||||
))
|
||||
|
||||
require.False(t, standard.Selected(
|
||||
discover.Device{Index: "index", PCIBusID: pcibusID},
|
||||
))
|
||||
|
||||
require.False(t, standard.Selected(
|
||||
discover.Device{Index: "index", UUID: "uuid"},
|
||||
))
|
||||
|
||||
require.False(t, standard.Selected(discover.Device{}))
|
||||
|
||||
}
|
||||
27
pkg/filter/selector.go
Normal file
27
pkg/filter/selector.go
Normal file
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package filter
|
||||
|
||||
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
|
||||
//go:generate moq -stub -out selector_mock.go . Selector
|
||||
|
||||
// Selector defines an interface for determining whether a specfied Device is selected
|
||||
// by a particular configuration.
|
||||
type Selector interface {
|
||||
Selected(discover.Device) bool
|
||||
}
|
||||
77
pkg/filter/selector_mock.go
Normal file
77
pkg/filter/selector_mock.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Ensure, that SelectorMock does implement Selector.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ Selector = &SelectorMock{}
|
||||
|
||||
// SelectorMock is a mock implementation of Selector.
|
||||
//
|
||||
// func TestSomethingThatUsesSelector(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked Selector
|
||||
// mockedSelector := &SelectorMock{
|
||||
// SelectedFunc: func(device discover.Device) bool {
|
||||
// panic("mock out the Selected method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedSelector in code that requires Selector
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type SelectorMock struct {
|
||||
// SelectedFunc mocks the Selected method.
|
||||
SelectedFunc func(device discover.Device) bool
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// Selected holds details about calls to the Selected method.
|
||||
Selected []struct {
|
||||
// Device is the device argument value.
|
||||
Device discover.Device
|
||||
}
|
||||
}
|
||||
lockSelected sync.RWMutex
|
||||
}
|
||||
|
||||
// Selected calls SelectedFunc.
|
||||
func (mock *SelectorMock) Selected(device discover.Device) bool {
|
||||
callInfo := struct {
|
||||
Device discover.Device
|
||||
}{
|
||||
Device: device,
|
||||
}
|
||||
mock.lockSelected.Lock()
|
||||
mock.calls.Selected = append(mock.calls.Selected, callInfo)
|
||||
mock.lockSelected.Unlock()
|
||||
if mock.SelectedFunc == nil {
|
||||
var (
|
||||
bOut bool
|
||||
)
|
||||
return bOut
|
||||
}
|
||||
return mock.SelectedFunc(device)
|
||||
}
|
||||
|
||||
// SelectedCalls gets all the calls that were made to Selected.
|
||||
// Check the length with:
|
||||
// len(mockedSelector.SelectedCalls())
|
||||
func (mock *SelectorMock) SelectedCalls() []struct {
|
||||
Device discover.Device
|
||||
} {
|
||||
var calls []struct {
|
||||
Device discover.Device
|
||||
}
|
||||
mock.lockSelected.RLock()
|
||||
calls = mock.calls.Selected
|
||||
mock.lockSelected.RUnlock()
|
||||
return calls
|
||||
}
|
||||
118
pkg/modify/device.go
Normal file
118
pkg/modify/device.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package modify
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
|
||||
)
|
||||
|
||||
// Device is an alias to discover.Device that allows for addition of a Modify method
|
||||
type Device struct {
|
||||
logger *log.Logger
|
||||
discover.Device
|
||||
}
|
||||
|
||||
// ProcMount is an alias to discover.Mount that allows for the addition of a Modify method for
|
||||
// proc paths associated with devices
|
||||
type ProcMount struct {
|
||||
logger *log.Logger
|
||||
discover.ProcPath
|
||||
}
|
||||
|
||||
var _ Modifier = (*Device)(nil)
|
||||
var _ Modifier = (*ProcMount)(nil)
|
||||
|
||||
// Modify applies the modifications required by a Device to the specified OCI specification
|
||||
func (d Device) Modify(spec oci.Spec) error {
|
||||
for _, dn := range d.DeviceNodes {
|
||||
mi := deviceNode{
|
||||
logger: d.logger,
|
||||
DeviceNode: dn,
|
||||
}
|
||||
err := mi.Modify(spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not inject device node %v: %v", dn, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, p := range d.ProcPaths {
|
||||
mi := ProcMount{
|
||||
logger: d.logger,
|
||||
ProcPath: p,
|
||||
}
|
||||
err := mi.Modify(spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not inject proc path %v: %v", p, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type deviceNode struct {
|
||||
logger *log.Logger
|
||||
discover.DeviceNode
|
||||
}
|
||||
|
||||
func (d deviceNode) Modify(spec oci.Spec) error {
|
||||
return spec.Modify(d.specModifier)
|
||||
}
|
||||
|
||||
func (d deviceNode) specModifier(spec *specs.Spec) error {
|
||||
if spec.Linux == nil {
|
||||
d.logger.Debugf("Initializing spec.Linux")
|
||||
spec.Linux = &specs.Linux{}
|
||||
}
|
||||
if spec.Linux.Resources == nil {
|
||||
d.logger.Debugf("Initializing spec.LinuxResources")
|
||||
spec.Linux.Resources = &specs.LinuxResources{}
|
||||
}
|
||||
|
||||
// TODO: These need to be configurable
|
||||
deviceFileMode := os.FileMode(8630)
|
||||
deviceUID := uint32(0)
|
||||
deviceGID := uint32(0)
|
||||
|
||||
deviceMajor := int64(d.Major)
|
||||
deviceMinor := int64(d.Minor)
|
||||
|
||||
d.logger.Infof("Adding device %v", d.Path)
|
||||
ociDevice := specs.LinuxDevice{
|
||||
Path: string(d.Path),
|
||||
Type: "c",
|
||||
Major: deviceMajor,
|
||||
Minor: deviceMinor,
|
||||
FileMode: &deviceFileMode,
|
||||
UID: &deviceUID,
|
||||
GID: &deviceGID,
|
||||
}
|
||||
spec.Linux.Devices = append(spec.Linux.Devices, ociDevice)
|
||||
|
||||
ociDeviceCgroup := specs.LinuxDeviceCgroup{
|
||||
Allow: true,
|
||||
Type: "c",
|
||||
Major: &deviceMajor,
|
||||
Minor: &deviceMinor,
|
||||
Access: "rwm",
|
||||
}
|
||||
|
||||
// TODO: We have to handle the case where we are updating the cgroups for multiple devices
|
||||
// leading to duplicates in the spec
|
||||
spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, ociDeviceCgroup)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Modify applies the modifications required for a Mount to the specified OCI specification
|
||||
func (m ProcMount) Modify(spec oci.Spec) error {
|
||||
return spec.Modify(m.specModifier)
|
||||
}
|
||||
|
||||
func (m ProcMount) specModifier(spec *specs.Spec) error {
|
||||
m.logger.Infof("Mounting read-only proc path %v", m.ProcPath)
|
||||
spec.Linux.ReadonlyPaths = append(spec.Linux.ReadonlyPaths, string(m.ProcPath))
|
||||
return nil
|
||||
}
|
||||
155
pkg/modify/discover.go
Normal file
155
pkg/modify/discover.go
Normal file
@@ -0,0 +1,155 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package modify
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
|
||||
)
|
||||
|
||||
type discoverModifier struct {
|
||||
logger *log.Logger
|
||||
discover discover.Discover
|
||||
root string
|
||||
bundleDir string
|
||||
}
|
||||
|
||||
var _ Modifier = (*discoverModifier)(nil)
|
||||
|
||||
// NewModifierFor creates a Modifier that can be used to apply the modifications to an OCI specification
|
||||
// required by the specified Discover instance.
|
||||
func NewModifierFor(discover discover.Discover, root string, bundleDir string) Modifier {
|
||||
return NewModifierWithLoggerFor(log.StandardLogger(), discover, root, bundleDir)
|
||||
}
|
||||
|
||||
// NewModifierWithLoggerFor creates a Modifier that can be used to apply the modifications to an OCI specification
|
||||
// required by the specified Discover instance.
|
||||
func NewModifierWithLoggerFor(logger *log.Logger, discover discover.Discover, root string, bundleDir string) Modifier {
|
||||
m := discoverModifier{
|
||||
logger: logger,
|
||||
discover: discover,
|
||||
root: root,
|
||||
bundleDir: bundleDir,
|
||||
}
|
||||
|
||||
return &m
|
||||
}
|
||||
|
||||
// Modify applies the modifications for the discovered devices, mounts, etc. to the specified
|
||||
// OCI spec.
|
||||
func (m discoverModifier) Modify(spec oci.Spec) error {
|
||||
m.logger.Infof("Determining required OCI spec modifications")
|
||||
modifiers, err := m.modifiers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error constructing modifiers: %v", err)
|
||||
}
|
||||
|
||||
m.logger.Infof("Applying %v modifications", len(modifiers))
|
||||
for _, mi := range modifiers {
|
||||
err := mi.Modify(spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not apply modifier %v: %v", mi, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m discoverModifier) modifiers() ([]Modifier, error) {
|
||||
var modifiers []Modifier
|
||||
|
||||
deviceModifiers, err := m.deviceModifiers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modifiers = append(modifiers, deviceModifiers...)
|
||||
|
||||
mountModifiers, err := m.mountModifiers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modifiers = append(modifiers, mountModifiers...)
|
||||
|
||||
hookModifiers, err := m.hookModifiers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modifiers = append(modifiers, hookModifiers...)
|
||||
|
||||
return modifiers, nil
|
||||
}
|
||||
|
||||
func (m discoverModifier) deviceModifiers() ([]Modifier, error) {
|
||||
var modifiers []Modifier
|
||||
|
||||
devices, err := m.discover.Devices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering devices: %v", err)
|
||||
}
|
||||
|
||||
for _, d := range devices {
|
||||
m := Device{
|
||||
logger: m.logger,
|
||||
Device: d,
|
||||
}
|
||||
modifiers = append(modifiers, m)
|
||||
}
|
||||
|
||||
return modifiers, nil
|
||||
}
|
||||
|
||||
func (m discoverModifier) mountModifiers() ([]Modifier, error) {
|
||||
var modifiers []Modifier
|
||||
|
||||
mounts, err := m.discover.Mounts()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering mounts: %v", err)
|
||||
}
|
||||
|
||||
for _, mi := range mounts {
|
||||
mm := Mount{
|
||||
logger: m.logger,
|
||||
Mount: mi,
|
||||
root: m.root,
|
||||
}
|
||||
modifiers = append(modifiers, mm)
|
||||
}
|
||||
|
||||
return modifiers, nil
|
||||
}
|
||||
|
||||
func (m discoverModifier) hookModifiers() ([]Modifier, error) {
|
||||
var modifiers []Modifier
|
||||
|
||||
hooks, err := m.discover.Hooks()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error discovering hooks: %v", err)
|
||||
}
|
||||
|
||||
for _, h := range hooks {
|
||||
m := Hook{
|
||||
logger: m.logger,
|
||||
Hook: h,
|
||||
bundleDir: m.bundleDir,
|
||||
}
|
||||
modifiers = append(modifiers, m)
|
||||
}
|
||||
|
||||
return modifiers, nil
|
||||
}
|
||||
81
pkg/modify/hooks.go
Normal file
81
pkg/modify/hooks.go
Normal file
@@ -0,0 +1,81 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package modify
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
|
||||
)
|
||||
|
||||
// Hook is an alias to discover.Hook that allows for addition of a Modify method
|
||||
type Hook struct {
|
||||
logger *log.Logger
|
||||
discover.Hook
|
||||
bundleDir string
|
||||
}
|
||||
|
||||
var _ Modifier = (*Hook)(nil)
|
||||
|
||||
// Modify applies the modifications required by a Hook to the specified OCI specification
|
||||
func (h Hook) Modify(spec oci.Spec) error {
|
||||
return spec.Modify(h.specModifier)
|
||||
}
|
||||
|
||||
func (h Hook) specModifier(spec *specs.Spec) error {
|
||||
if spec.Hooks == nil {
|
||||
h.logger.Debugf("Initializing spec.Hooks")
|
||||
spec.Hooks = &specs.Hooks{}
|
||||
}
|
||||
|
||||
// TODO: This is duplicated in the hook specification
|
||||
const rootPattern = "@Root.Path@"
|
||||
|
||||
rootPath := spec.Root.Path
|
||||
if !filepath.IsAbs(rootPath) {
|
||||
rootPath = filepath.Join(h.bundleDir, rootPath)
|
||||
}
|
||||
|
||||
var args []string
|
||||
for _, a := range h.Args {
|
||||
if strings.Contains(a, rootPattern) {
|
||||
args = append(args, strings.ReplaceAll(a, rootPattern, rootPath))
|
||||
continue
|
||||
}
|
||||
args = append(args, a)
|
||||
}
|
||||
|
||||
specHook := specs.Hook{
|
||||
Path: h.Path,
|
||||
Args: args,
|
||||
}
|
||||
|
||||
h.logger.Infof("Adding %v hook %+v", h.HookName, specHook)
|
||||
switch h.HookName {
|
||||
case "create-container":
|
||||
spec.Hooks.CreateContainer = append(spec.Hooks.CreateContainer, specHook)
|
||||
default:
|
||||
return fmt.Errorf("unexpected hook name: %v", h.HookName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
28
pkg/modify/modify.go
Normal file
28
pkg/modify/modify.go
Normal file
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package modify
|
||||
|
||||
import (
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
|
||||
)
|
||||
|
||||
//go:generate moq -stub -out modify_mock.go . Modifier
|
||||
|
||||
// Modifier defines an interface for modifying an OCI Specification.
|
||||
type Modifier interface {
|
||||
Modify(oci.Spec) error
|
||||
}
|
||||
77
pkg/modify/modify_mock.go
Normal file
77
pkg/modify/modify_mock.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package modify
|
||||
|
||||
import (
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Ensure, that ModifierMock does implement Modifier.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ Modifier = &ModifierMock{}
|
||||
|
||||
// ModifierMock is a mock implementation of Modifier.
|
||||
//
|
||||
// func TestSomethingThatUsesModifier(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked Modifier
|
||||
// mockedModifier := &ModifierMock{
|
||||
// ModifyFunc: func(spec oci.Spec) error {
|
||||
// panic("mock out the Modify method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedModifier in code that requires Modifier
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type ModifierMock struct {
|
||||
// ModifyFunc mocks the Modify method.
|
||||
ModifyFunc func(spec oci.Spec) error
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// Modify holds details about calls to the Modify method.
|
||||
Modify []struct {
|
||||
// Spec is the spec argument value.
|
||||
Spec oci.Spec
|
||||
}
|
||||
}
|
||||
lockModify sync.RWMutex
|
||||
}
|
||||
|
||||
// Modify calls ModifyFunc.
|
||||
func (mock *ModifierMock) Modify(spec oci.Spec) error {
|
||||
callInfo := struct {
|
||||
Spec oci.Spec
|
||||
}{
|
||||
Spec: spec,
|
||||
}
|
||||
mock.lockModify.Lock()
|
||||
mock.calls.Modify = append(mock.calls.Modify, callInfo)
|
||||
mock.lockModify.Unlock()
|
||||
if mock.ModifyFunc == nil {
|
||||
var (
|
||||
errOut error
|
||||
)
|
||||
return errOut
|
||||
}
|
||||
return mock.ModifyFunc(spec)
|
||||
}
|
||||
|
||||
// ModifyCalls gets all the calls that were made to Modify.
|
||||
// Check the length with:
|
||||
// len(mockedModifier.ModifyCalls())
|
||||
func (mock *ModifierMock) ModifyCalls() []struct {
|
||||
Spec oci.Spec
|
||||
} {
|
||||
var calls []struct {
|
||||
Spec oci.Spec
|
||||
}
|
||||
mock.lockModify.RLock()
|
||||
calls = mock.calls.Modify
|
||||
mock.lockModify.RUnlock()
|
||||
return calls
|
||||
}
|
||||
53
pkg/modify/mounts.go
Normal file
53
pkg/modify/mounts.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package modify
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
|
||||
)
|
||||
|
||||
// Mount is an alias to discover.Mount that allows for addition of a Modify method
|
||||
type Mount struct {
|
||||
logger *log.Logger
|
||||
discover.Mount
|
||||
root string
|
||||
}
|
||||
|
||||
var _ Modifier = (*Mount)(nil)
|
||||
|
||||
// Modify applies the modifications required for a Mount to the specified OCI specification
|
||||
func (d Mount) Modify(spec oci.Spec) error {
|
||||
return spec.Modify(d.specModifier)
|
||||
}
|
||||
|
||||
// TODO: We need to ensure that we are correctly mounting the proc paths
|
||||
// Also — I’m not sure how this is done, but we will need a new tempfs mounted at /proc/driver/nvidia/ underneath which all of these other mounted directories get put
|
||||
// Maybe this?
|
||||
// https://github.com/opencontainers/runtime-spec/blob/master/specs-go/config.go#L175 (edited)
|
||||
// specs-go/config.go:175
|
||||
// MaskedPaths []string `json:"maskedPaths,omitempty"`
|
||||
// <https://github.com/opencontainers/runtime-spec|opencontainers/runtime-spec>opencontainers/runtime-spec | Added by GitHub
|
||||
// 13:53
|
||||
// Proabably, given…
|
||||
// https://github.com/opencontainers/runtime-spec/blob/master/config-linux.md#masked-paths (edited)
|
||||
// TODO: We can try masking all of /proc/driver/nvidia and then mounting the paths read-only
|
||||
func (d Mount) specModifier(spec *specs.Spec) error {
|
||||
source := d.Path
|
||||
destination := strings.TrimPrefix(d.Path, d.root)
|
||||
d.logger.Infof("Mounting %v -> %v", source, destination)
|
||||
mount := specs.Mount{
|
||||
Destination: destination,
|
||||
Source: source,
|
||||
Type: "bind",
|
||||
Options: []string{
|
||||
"rbind",
|
||||
"rprivate",
|
||||
},
|
||||
}
|
||||
spec.Mounts = append(spec.Mounts, mount)
|
||||
|
||||
return nil
|
||||
}
|
||||
135
pkg/oci/args.go
Normal file
135
pkg/oci/args.go
Normal file
@@ -0,0 +1,135 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
specFileName = "config.json"
|
||||
)
|
||||
|
||||
// GetBundleDir returns the bundle directory or default depending on the
|
||||
// supplied command line arguments.
|
||||
func GetBundleDir(args []string) (string, error) {
|
||||
bundleDir, err := GetBundleDirFromArgs(args)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting bundle dir from args: %v", err)
|
||||
}
|
||||
|
||||
if bundleDir != "" {
|
||||
return bundleDir, nil
|
||||
}
|
||||
|
||||
defaultBundleDir, err := GetDefaultBundleDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting default bundle dir: %v", err)
|
||||
}
|
||||
|
||||
return defaultBundleDir, nil
|
||||
}
|
||||
|
||||
// GetBundleDirFromArgs checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc.
|
||||
// The following are supported:
|
||||
// --bundle{{SEP}}BUNDLE_PATH
|
||||
// -bundle{{SEP}}BUNDLE_PATH
|
||||
// -b{{SEP}}BUNDLE_PATH
|
||||
// where {{SEP}} is either ' ' or '='
|
||||
func GetBundleDirFromArgs(args []string) (string, error) {
|
||||
var bundleDir string
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
param := args[i]
|
||||
|
||||
parts := strings.SplitN(param, "=", 2)
|
||||
if !IsBundleFlag(parts[0]) {
|
||||
continue
|
||||
}
|
||||
|
||||
// The flag has the format --bundle=/path
|
||||
if len(parts) == 2 {
|
||||
bundleDir = parts[1]
|
||||
continue
|
||||
}
|
||||
|
||||
// The flag has the format --bundle /path
|
||||
if i+1 < len(args) {
|
||||
bundleDir = args[i+1]
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// --bundle / -b was the last element of args
|
||||
return "", fmt.Errorf("bundle option requires an argument")
|
||||
}
|
||||
|
||||
return bundleDir, nil
|
||||
}
|
||||
|
||||
// GetDefaultBundleDir returns the bundle directory that is to be used if no alternative is
|
||||
// specified via the command line, for example.
|
||||
func GetDefaultBundleDir() (string, error) {
|
||||
workingDirectory, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting working directory: %v", err)
|
||||
}
|
||||
return workingDirectory, nil
|
||||
}
|
||||
|
||||
// GetSpecFilePath returns the expected path to the OCI specification file for the given
|
||||
// bundle directory.
|
||||
func GetSpecFilePath(bundleDir string) string {
|
||||
specFilePath := filepath.Join(bundleDir, specFileName)
|
||||
return specFilePath
|
||||
}
|
||||
|
||||
// IsBundleFlag is a helper function that checks wither the specified argument represents
|
||||
// a bundle flag (--bundle or -b)
|
||||
func IsBundleFlag(arg string) bool {
|
||||
if !strings.HasPrefix(arg, "-") {
|
||||
return false
|
||||
}
|
||||
|
||||
trimmed := strings.TrimLeft(arg, "-")
|
||||
return trimmed == "b" || trimmed == "bundle"
|
||||
}
|
||||
|
||||
// HasCreateSubcommand checks the supplied arguments for a 'create' subcommand
|
||||
func HasCreateSubcommand(args []string) bool {
|
||||
var previousWasBundle bool
|
||||
for _, a := range args {
|
||||
// We check for '--bundle create' explicitly to ensure that we
|
||||
// don't inadvertently trigger a modification if the bundle directory
|
||||
// is specified as `create`
|
||||
if !previousWasBundle && IsBundleFlag(a) {
|
||||
previousWasBundle = true
|
||||
continue
|
||||
}
|
||||
|
||||
if !previousWasBundle && a == "create" {
|
||||
return true
|
||||
}
|
||||
|
||||
previousWasBundle = false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
198
pkg/oci/args_test.go
Normal file
198
pkg/oci/args_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package oci
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetBundleDir(t *testing.T) {
|
||||
defaultBundleDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
type expected struct {
|
||||
bundle string
|
||||
isError bool
|
||||
}
|
||||
testCases := []struct {
|
||||
argv []string
|
||||
expected expected
|
||||
}{
|
||||
{
|
||||
argv: []string{},
|
||||
expected: expected{
|
||||
bundle: defaultBundleDir,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"create"},
|
||||
expected: expected{
|
||||
bundle: defaultBundleDir,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--bundle"},
|
||||
expected: expected{
|
||||
isError: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b"},
|
||||
expected: expected{
|
||||
isError: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--bundle", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--not-bundle", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: defaultBundleDir,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--"},
|
||||
expected: expected{
|
||||
bundle: defaultBundleDir,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-bundle", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--bundle=/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b=/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b=/foo/=bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/=bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"create", "-b", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b", "create", "create"},
|
||||
expected: expected{
|
||||
bundle: "create",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b=create", "create"},
|
||||
expected: expected{
|
||||
bundle: "create",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b", "create"},
|
||||
expected: expected{
|
||||
bundle: "create",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
bundle, err := GetBundleDir(tc.argv)
|
||||
|
||||
if tc.expected.isError {
|
||||
require.Errorf(t, err, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultBundleDir(t *testing.T) {
|
||||
defaultBundleDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
bundleDir, err := GetDefaultBundleDir()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, defaultBundleDir, bundleDir)
|
||||
}
|
||||
|
||||
func TestGetSpecFilePathAppendsFilename(t *testing.T) {
|
||||
testCases := []struct {
|
||||
bundleDir string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
bundleDir: "",
|
||||
expected: "config.json",
|
||||
},
|
||||
{
|
||||
bundleDir: "/not/empty/",
|
||||
expected: "/not/empty/config.json",
|
||||
},
|
||||
{
|
||||
bundleDir: "not/absolute",
|
||||
expected: "not/absolute/config.json",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
specPath := GetSpecFilePath(tc.bundleDir)
|
||||
|
||||
require.Equalf(t, tc.expected, specPath, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasCreateSubcommand(t *testing.T) {
|
||||
testCases := []struct {
|
||||
args []string
|
||||
shouldModify bool
|
||||
}{
|
||||
{
|
||||
shouldModify: false,
|
||||
},
|
||||
{
|
||||
args: []string{"create"},
|
||||
shouldModify: true,
|
||||
},
|
||||
{
|
||||
args: []string{"--bundle=create"},
|
||||
shouldModify: false,
|
||||
},
|
||||
{
|
||||
args: []string{"--bundle", "create"},
|
||||
shouldModify: false,
|
||||
},
|
||||
{
|
||||
args: []string{"create"},
|
||||
shouldModify: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
require.Equal(t, tc.shouldModify, HasCreateSubcommand(tc.args), "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
25
pkg/oci/runtime.go
Normal file
25
pkg/oci/runtime.go
Normal file
@@ -0,0 +1,25 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
//go:generate moq -stub -out runtime_mock.go . Runtime
|
||||
|
||||
// Runtime is an interface for a runtime shim. The Exec method accepts a list
|
||||
// of command line arguments, and returns an error / nil.
|
||||
type Runtime interface {
|
||||
Exec([]string) error
|
||||
}
|
||||
61
pkg/oci/runtime_low_level.go
Normal file
61
pkg/oci/runtime_low_level.go
Normal file
@@ -0,0 +1,61 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable.
|
||||
// The executable specified is taken from the list of supplied candidates, with the first match
|
||||
// present in the PATH being selected.
|
||||
func NewLowLevelRuntime(candidates ...string) (Runtime, error) {
|
||||
return NewLowLevelRuntimeWithLogger(log.StandardLogger(), candidates...)
|
||||
}
|
||||
|
||||
// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger.
|
||||
func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) {
|
||||
runtimePath, err := findRuntime(candidates)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error locating runtime: %v", err)
|
||||
}
|
||||
|
||||
return NewRuntimeForPathWithLogger(logger, runtimePath)
|
||||
}
|
||||
|
||||
// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH.
|
||||
// The absolute path to the first match is returned.
|
||||
func findRuntime(candidates []string) (string, error) {
|
||||
if len(candidates) == 0 {
|
||||
return "", fmt.Errorf("at least one runtime candidate must be specified")
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
log.Infof("Looking for runtime binary '%v'", candidate)
|
||||
runcPath, err := exec.LookPath(candidate)
|
||||
if err == nil {
|
||||
log.Infof("Found runtime binary '%v'", runcPath)
|
||||
return runcPath, nil
|
||||
}
|
||||
log.Warnf("Runtime binary '%v' not found: %v", candidate, err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no runtime binary found from candidate list: %v", candidates)
|
||||
}
|
||||
76
pkg/oci/runtime_mock.go
Normal file
76
pkg/oci/runtime_mock.go
Normal file
@@ -0,0 +1,76 @@
|
||||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Ensure, that RuntimeMock does implement Runtime.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ Runtime = &RuntimeMock{}
|
||||
|
||||
// RuntimeMock is a mock implementation of Runtime.
|
||||
//
|
||||
// func TestSomethingThatUsesRuntime(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked Runtime
|
||||
// mockedRuntime := &RuntimeMock{
|
||||
// ExecFunc: func(strings []string) error {
|
||||
// panic("mock out the Exec method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedRuntime in code that requires Runtime
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type RuntimeMock struct {
|
||||
// ExecFunc mocks the Exec method.
|
||||
ExecFunc func(strings []string) error
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// Exec holds details about calls to the Exec method.
|
||||
Exec []struct {
|
||||
// Strings is the strings argument value.
|
||||
Strings []string
|
||||
}
|
||||
}
|
||||
lockExec sync.RWMutex
|
||||
}
|
||||
|
||||
// Exec calls ExecFunc.
|
||||
func (mock *RuntimeMock) Exec(strings []string) error {
|
||||
callInfo := struct {
|
||||
Strings []string
|
||||
}{
|
||||
Strings: strings,
|
||||
}
|
||||
mock.lockExec.Lock()
|
||||
mock.calls.Exec = append(mock.calls.Exec, callInfo)
|
||||
mock.lockExec.Unlock()
|
||||
if mock.ExecFunc == nil {
|
||||
var (
|
||||
errOut error
|
||||
)
|
||||
return errOut
|
||||
}
|
||||
return mock.ExecFunc(strings)
|
||||
}
|
||||
|
||||
// ExecCalls gets all the calls that were made to Exec.
|
||||
// Check the length with:
|
||||
// len(mockedRuntime.ExecCalls())
|
||||
func (mock *RuntimeMock) ExecCalls() []struct {
|
||||
Strings []string
|
||||
} {
|
||||
var calls []struct {
|
||||
Strings []string
|
||||
}
|
||||
mock.lockExec.RLock()
|
||||
calls = mock.calls.Exec
|
||||
mock.lockExec.RUnlock()
|
||||
return calls
|
||||
}
|
||||
70
pkg/oci/runtime_path.go
Normal file
70
pkg/oci/runtime_path.go
Normal file
@@ -0,0 +1,70 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// pathRuntime wraps the path that a binary and defines the semanitcs for how to exec into it.
|
||||
// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the
|
||||
// Runtime internface.
|
||||
type pathRuntime struct {
|
||||
logger *log.Logger
|
||||
path string
|
||||
execRuntime Runtime
|
||||
}
|
||||
|
||||
var _ Runtime = (*pathRuntime)(nil)
|
||||
|
||||
// NewRuntimeForPath creates a Runtime for the specified path with the standard logger
|
||||
func NewRuntimeForPath(path string) (Runtime, error) {
|
||||
return NewRuntimeForPathWithLogger(log.StandardLogger(), path)
|
||||
}
|
||||
|
||||
// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path
|
||||
func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid path '%v': %v", path, err)
|
||||
}
|
||||
if info.IsDir() || info.Mode()&0111 == 0 {
|
||||
return nil, fmt.Errorf("specified path '%v' is not an executable file", path)
|
||||
}
|
||||
|
||||
shim := pathRuntime{
|
||||
logger: logger,
|
||||
path: path,
|
||||
execRuntime: syscallExec{},
|
||||
}
|
||||
|
||||
return &shim, nil
|
||||
}
|
||||
|
||||
// Exec exces into the binary at the path from the pathRuntime struct, passing it the supplied arguments
|
||||
// after ensuring that the first argument is the path of the target binary.
|
||||
func (s pathRuntime) Exec(args []string) error {
|
||||
runtimeArgs := []string{s.path}
|
||||
if len(args) > 1 {
|
||||
runtimeArgs = append(runtimeArgs, args[1:]...)
|
||||
}
|
||||
|
||||
return s.execRuntime.Exec(runtimeArgs)
|
||||
}
|
||||
97
pkg/oci/runtime_path_test.go
Normal file
97
pkg/oci/runtime_path_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
package oci
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPathRuntimeConstructor(t *testing.T) {
|
||||
r, err := NewRuntimeForPath("////an/invalid/path")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, r)
|
||||
|
||||
r, err = NewRuntimeForPath("/tmp")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, r)
|
||||
|
||||
r, err = NewRuntimeForPath("/dev/null")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, r)
|
||||
|
||||
r, err = NewRuntimeForPath("/bin/sh")
|
||||
require.NoError(t, err)
|
||||
|
||||
f, ok := r.(*pathRuntime)
|
||||
require.True(t, ok)
|
||||
|
||||
require.Equal(t, "/bin/sh", f.path)
|
||||
}
|
||||
|
||||
func TestPathRuntimeForwardsArgs(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
testCases := []struct {
|
||||
execRuntimeError error
|
||||
args []string
|
||||
}{
|
||||
{},
|
||||
{
|
||||
args: []string{"shouldBeReplaced"},
|
||||
},
|
||||
{
|
||||
args: []string{"shouldBeReplaced", "arg1"},
|
||||
},
|
||||
{
|
||||
execRuntimeError: fmt.Errorf("exec error"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
mockedRuntime := &RuntimeMock{
|
||||
ExecFunc: func(strings []string) error {
|
||||
return tc.execRuntimeError
|
||||
},
|
||||
}
|
||||
r := pathRuntime{
|
||||
logger: logger,
|
||||
path: "runtime",
|
||||
execRuntime: mockedRuntime,
|
||||
}
|
||||
err := r.Exec(tc.args)
|
||||
|
||||
require.ErrorIs(t, err, tc.execRuntimeError)
|
||||
|
||||
calls := mockedRuntime.ExecCalls()
|
||||
require.Len(t, calls, 1)
|
||||
|
||||
numArgs := len(tc.args)
|
||||
if numArgs == 0 {
|
||||
numArgs = 1
|
||||
}
|
||||
|
||||
require.Len(t, calls[0].Strings, numArgs)
|
||||
require.Equal(t, "runtime", calls[0].Strings[0])
|
||||
|
||||
if numArgs > 1 {
|
||||
require.EqualValues(t, tc.args[1:], calls[0].Strings[1:])
|
||||
}
|
||||
}
|
||||
}
|
||||
38
pkg/oci/runtime_syscall_exec.go
Normal file
38
pkg/oci/runtime_syscall_exec.go
Normal file
@@ -0,0 +1,38 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type syscallExec struct{}
|
||||
|
||||
var _ Runtime = (*syscallExec)(nil)
|
||||
|
||||
func (r syscallExec) Exec(args []string) error {
|
||||
err := syscall.Exec(args[0], args, os.Environ())
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not exec '%v': %v", args[0], err)
|
||||
}
|
||||
|
||||
// syscall.Exec is not expected to return. This is an error state regardless of whether
|
||||
// err is nil or not.
|
||||
return fmt.Errorf("unexpected return from exec '%v'", args[0])
|
||||
}
|
||||
35
pkg/oci/spec.go
Normal file
35
pkg/oci/spec.go
Normal file
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
oci "github.com/opencontainers/runtime-spec/specs-go"
|
||||
)
|
||||
|
||||
// SpecModifier is a function that accepts a pointer to an OCI Srec and returns an
|
||||
// error. The intention is that the function would modify the spec in-place.
|
||||
type SpecModifier func(*oci.Spec) error
|
||||
|
||||
//go:generate moq -stub -out spec_mock.go . Spec
|
||||
|
||||
// Spec defines the operations to be performed on an OCI specification
|
||||
type Spec interface {
|
||||
Load() error
|
||||
Flush() error
|
||||
Modify(SpecModifier) error
|
||||
LookupEnv(string) (string, bool)
|
||||
}
|
||||
153
pkg/oci/spec_file.go
Normal file
153
pkg/oci/spec_file.go
Normal file
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
oci "github.com/opencontainers/runtime-spec/specs-go"
|
||||
)
|
||||
|
||||
type fileSpec struct {
|
||||
*oci.Spec
|
||||
path string
|
||||
}
|
||||
|
||||
var _ Spec = (*fileSpec)(nil)
|
||||
|
||||
// NewSpecFromArgs creates fileSpec based on the command line arguments passed to the
|
||||
// application
|
||||
func NewSpecFromArgs(args []string) (Spec, string, error) {
|
||||
bundleDir, err := GetBundleDir(args)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error getting bundle directory: %v", err)
|
||||
}
|
||||
|
||||
ociSpecPath := GetSpecFilePath(bundleDir)
|
||||
|
||||
ociSpec := NewSpecFromFile(ociSpecPath)
|
||||
|
||||
return ociSpec, bundleDir, nil
|
||||
}
|
||||
|
||||
// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec.
|
||||
// This can be used to read from the file, modify the spec, and write to the
|
||||
// same file.
|
||||
func NewSpecFromFile(filepath string) Spec {
|
||||
oci := fileSpec{
|
||||
path: filepath,
|
||||
}
|
||||
|
||||
return &oci
|
||||
}
|
||||
|
||||
// Load reads the contents of an OCI spec from file to be referenced internally.
|
||||
// The file is opened "read-only"
|
||||
func (s *fileSpec) Load() error {
|
||||
specFile, err := os.Open(s.path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening OCI specification file: %v", err)
|
||||
}
|
||||
defer specFile.Close()
|
||||
|
||||
return s.loadFrom(specFile)
|
||||
}
|
||||
|
||||
// loadFrom reads the contents of the OCI spec from the specified io.Reader.
|
||||
func (s *fileSpec) loadFrom(reader io.Reader) error {
|
||||
decoder := json.NewDecoder(reader)
|
||||
|
||||
var spec oci.Spec
|
||||
err := decoder.Decode(&spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading OCI specification: %v", err)
|
||||
}
|
||||
|
||||
s.Spec = &spec
|
||||
return nil
|
||||
}
|
||||
|
||||
// Modify applies the specified SpecModifier to the stored OCI specification.
|
||||
func (s *fileSpec) Modify(f SpecModifier) error {
|
||||
if s.Spec == nil {
|
||||
return fmt.Errorf("no spec loaded for modification")
|
||||
}
|
||||
return f(s.Spec)
|
||||
}
|
||||
|
||||
// Flush writes the stored OCI specification to the filepath specifed by the path member.
|
||||
// The file is truncated upon opening, overwriting any existing contents.
|
||||
func (s fileSpec) Flush() error {
|
||||
if s.Spec == nil {
|
||||
return fmt.Errorf("no OCI specification loaded")
|
||||
}
|
||||
|
||||
specFile, err := os.Create(s.path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening OCI specification file: %v", err)
|
||||
}
|
||||
defer specFile.Close()
|
||||
|
||||
return s.flushTo(specFile)
|
||||
}
|
||||
|
||||
// flushTo writes the stored OCI specification to the specified io.Writer.
|
||||
func (s fileSpec) flushTo(writer io.Writer) error {
|
||||
if s.Spec == nil {
|
||||
return nil
|
||||
}
|
||||
encoder := json.NewEncoder(writer)
|
||||
|
||||
err := encoder.Encode(s.Spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing OCI specification: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookupEnv mirrors os.LookupEnv for the OCI specification. It
|
||||
// retrieves the value of the environment variable named
|
||||
// by the key. If the variable is present in the environment the
|
||||
// value (which may be empty) is returned and the boolean is true.
|
||||
// Otherwise the returned value will be empty and the boolean will
|
||||
// be false.
|
||||
func (s fileSpec) LookupEnv(key string) (string, bool) {
|
||||
if s.Spec == nil || s.Spec.Process == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, env := range s.Spec.Process.Env {
|
||||
if !strings.HasPrefix(env, key) {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
if parts[0] == key {
|
||||
if len(parts) < 2 {
|
||||
return "", true
|
||||
}
|
||||
return parts[1], true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
252
pkg/oci/spec_file_test.go
Normal file
252
pkg/oci/spec_file_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLookupEnv(t *testing.T) {
|
||||
const envName = "TEST_ENV"
|
||||
testCases := []struct {
|
||||
spec *specs.Spec
|
||||
expectedValue string
|
||||
expectedExits bool
|
||||
}{
|
||||
{
|
||||
// nil spec
|
||||
spec: nil,
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// nil process
|
||||
spec: &specs.Spec{},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// nil env
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// empty env
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// different env set
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// same prefix
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// same suffix
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// set blank
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV="}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: true,
|
||||
},
|
||||
{
|
||||
// set no-equals
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: true,
|
||||
},
|
||||
{
|
||||
// set value
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV=something"}},
|
||||
},
|
||||
expectedValue: "something",
|
||||
expectedExits: true,
|
||||
},
|
||||
{
|
||||
// set with equals
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}},
|
||||
},
|
||||
expectedValue: "something=somethingelse",
|
||||
expectedExits: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
spec := fileSpec{
|
||||
Spec: tc.spec,
|
||||
}
|
||||
|
||||
value, exists := spec.LookupEnv(envName)
|
||||
|
||||
require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc)
|
||||
require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFrom(t *testing.T) {
|
||||
testCases := []struct {
|
||||
contents []byte
|
||||
isError bool
|
||||
spec *specs.Spec
|
||||
}{
|
||||
{
|
||||
contents: []byte{},
|
||||
isError: true,
|
||||
},
|
||||
{
|
||||
contents: []byte("{}"),
|
||||
isError: false,
|
||||
spec: &specs.Spec{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
spec := fileSpec{}
|
||||
err := spec.loadFrom(bytes.NewReader(tc.contents))
|
||||
|
||||
if tc.isError {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoError(t, err, "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
if tc.spec == nil {
|
||||
require.Nil(t, spec.Spec, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.EqualValues(t, tc.spec, spec.Spec, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushTo(t *testing.T) {
|
||||
testCases := []struct {
|
||||
isError bool
|
||||
spec *specs.Spec
|
||||
contents string
|
||||
}{
|
||||
{
|
||||
spec: nil,
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{},
|
||||
contents: "{\"ociVersion\":\"\"}\n",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
buffer := bytes.Buffer{}
|
||||
|
||||
spec := fileSpec{Spec: tc.spec}
|
||||
err := spec.flushTo(&buffer)
|
||||
|
||||
if tc.isError {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoError(t, err, "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
require.EqualValues(t, tc.contents, buffer.String(), "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
// Add a simple test for a writer that returns an error when writing
|
||||
spec := fileSpec{Spec: &specs.Spec{}}
|
||||
err := spec.flushTo(errorWriter{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestModify(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
spec *specs.Spec
|
||||
modifierError error
|
||||
}{
|
||||
{
|
||||
spec: nil,
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{},
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{},
|
||||
modifierError: fmt.Errorf("error in modifier"),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
spec := fileSpec{Spec: tc.spec}
|
||||
|
||||
modifier := func(spec *specs.Spec) error {
|
||||
if tc.modifierError == nil {
|
||||
spec.Version = "updated"
|
||||
}
|
||||
return tc.modifierError
|
||||
}
|
||||
|
||||
err := spec.Modify(modifier)
|
||||
|
||||
if tc.spec == nil {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
} else if tc.modifierError != nil {
|
||||
require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
|
||||
require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoError(t, err, "%d: %v", i, tc)
|
||||
require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// errorWriter implements the io.Writer interface, always returning an error when
|
||||
// writing.
|
||||
type errorWriter struct{}
|
||||
|
||||
func (e errorWriter) Write([]byte) (int, error) {
|
||||
return 0, fmt.Errorf("error writing")
|
||||
}
|
||||
201
pkg/oci/spec_mock.go
Normal file
201
pkg/oci/spec_mock.go
Normal file
@@ -0,0 +1,201 @@
|
||||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Ensure, that SpecMock does implement Spec.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ Spec = &SpecMock{}
|
||||
|
||||
// SpecMock is a mock implementation of Spec.
|
||||
//
|
||||
// func TestSomethingThatUsesSpec(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked Spec
|
||||
// mockedSpec := &SpecMock{
|
||||
// FlushFunc: func() error {
|
||||
// panic("mock out the Flush method")
|
||||
// },
|
||||
// LoadFunc: func() error {
|
||||
// panic("mock out the Load method")
|
||||
// },
|
||||
// LookupEnvFunc: func(s string) (string, bool) {
|
||||
// panic("mock out the LookupEnv method")
|
||||
// },
|
||||
// ModifyFunc: func(specModifier SpecModifier) error {
|
||||
// panic("mock out the Modify method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedSpec in code that requires Spec
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type SpecMock struct {
|
||||
// FlushFunc mocks the Flush method.
|
||||
FlushFunc func() error
|
||||
|
||||
// LoadFunc mocks the Load method.
|
||||
LoadFunc func() error
|
||||
|
||||
// LookupEnvFunc mocks the LookupEnv method.
|
||||
LookupEnvFunc func(s string) (string, bool)
|
||||
|
||||
// ModifyFunc mocks the Modify method.
|
||||
ModifyFunc func(specModifier SpecModifier) error
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// Flush holds details about calls to the Flush method.
|
||||
Flush []struct {
|
||||
}
|
||||
// Load holds details about calls to the Load method.
|
||||
Load []struct {
|
||||
}
|
||||
// LookupEnv holds details about calls to the LookupEnv method.
|
||||
LookupEnv []struct {
|
||||
// S is the s argument value.
|
||||
S string
|
||||
}
|
||||
// Modify holds details about calls to the Modify method.
|
||||
Modify []struct {
|
||||
// SpecModifier is the specModifier argument value.
|
||||
SpecModifier SpecModifier
|
||||
}
|
||||
}
|
||||
lockFlush sync.RWMutex
|
||||
lockLoad sync.RWMutex
|
||||
lockLookupEnv sync.RWMutex
|
||||
lockModify sync.RWMutex
|
||||
}
|
||||
|
||||
// Flush calls FlushFunc.
|
||||
func (mock *SpecMock) Flush() error {
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockFlush.Lock()
|
||||
mock.calls.Flush = append(mock.calls.Flush, callInfo)
|
||||
mock.lockFlush.Unlock()
|
||||
if mock.FlushFunc == nil {
|
||||
var (
|
||||
errOut error
|
||||
)
|
||||
return errOut
|
||||
}
|
||||
return mock.FlushFunc()
|
||||
}
|
||||
|
||||
// FlushCalls gets all the calls that were made to Flush.
|
||||
// Check the length with:
|
||||
// len(mockedSpec.FlushCalls())
|
||||
func (mock *SpecMock) FlushCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockFlush.RLock()
|
||||
calls = mock.calls.Flush
|
||||
mock.lockFlush.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// Load calls LoadFunc.
|
||||
func (mock *SpecMock) Load() error {
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockLoad.Lock()
|
||||
mock.calls.Load = append(mock.calls.Load, callInfo)
|
||||
mock.lockLoad.Unlock()
|
||||
if mock.LoadFunc == nil {
|
||||
var (
|
||||
errOut error
|
||||
)
|
||||
return errOut
|
||||
}
|
||||
return mock.LoadFunc()
|
||||
}
|
||||
|
||||
// LoadCalls gets all the calls that were made to Load.
|
||||
// Check the length with:
|
||||
// len(mockedSpec.LoadCalls())
|
||||
func (mock *SpecMock) LoadCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockLoad.RLock()
|
||||
calls = mock.calls.Load
|
||||
mock.lockLoad.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// LookupEnv calls LookupEnvFunc.
|
||||
func (mock *SpecMock) LookupEnv(s string) (string, bool) {
|
||||
callInfo := struct {
|
||||
S string
|
||||
}{
|
||||
S: s,
|
||||
}
|
||||
mock.lockLookupEnv.Lock()
|
||||
mock.calls.LookupEnv = append(mock.calls.LookupEnv, callInfo)
|
||||
mock.lockLookupEnv.Unlock()
|
||||
if mock.LookupEnvFunc == nil {
|
||||
var (
|
||||
sOut string
|
||||
bOut bool
|
||||
)
|
||||
return sOut, bOut
|
||||
}
|
||||
return mock.LookupEnvFunc(s)
|
||||
}
|
||||
|
||||
// LookupEnvCalls gets all the calls that were made to LookupEnv.
|
||||
// Check the length with:
|
||||
// len(mockedSpec.LookupEnvCalls())
|
||||
func (mock *SpecMock) LookupEnvCalls() []struct {
|
||||
S string
|
||||
} {
|
||||
var calls []struct {
|
||||
S string
|
||||
}
|
||||
mock.lockLookupEnv.RLock()
|
||||
calls = mock.calls.LookupEnv
|
||||
mock.lockLookupEnv.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// Modify calls ModifyFunc.
|
||||
func (mock *SpecMock) Modify(specModifier SpecModifier) error {
|
||||
callInfo := struct {
|
||||
SpecModifier SpecModifier
|
||||
}{
|
||||
SpecModifier: specModifier,
|
||||
}
|
||||
mock.lockModify.Lock()
|
||||
mock.calls.Modify = append(mock.calls.Modify, callInfo)
|
||||
mock.lockModify.Unlock()
|
||||
if mock.ModifyFunc == nil {
|
||||
var (
|
||||
errOut error
|
||||
)
|
||||
return errOut
|
||||
}
|
||||
return mock.ModifyFunc(specModifier)
|
||||
}
|
||||
|
||||
// ModifyCalls gets all the calls that were made to Modify.
|
||||
// Check the length with:
|
||||
// len(mockedSpec.ModifyCalls())
|
||||
func (mock *SpecMock) ModifyCalls() []struct {
|
||||
SpecModifier SpecModifier
|
||||
} {
|
||||
var calls []struct {
|
||||
SpecModifier SpecModifier
|
||||
}
|
||||
mock.lockModify.RLock()
|
||||
calls = mock.calls.Modify
|
||||
mock.lockModify.RUnlock()
|
||||
return calls
|
||||
}
|
||||
82
pkg/runtime/runtime_modifier.go
Normal file
82
pkg/runtime/runtime_modifier.go
Normal file
@@ -0,0 +1,82 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/modify"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
|
||||
)
|
||||
|
||||
type modifyingRuntimeWrapper struct {
|
||||
logger *log.Logger
|
||||
runtime oci.Runtime
|
||||
ociSpec oci.Spec
|
||||
modifier modify.Modifier
|
||||
}
|
||||
|
||||
var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil)
|
||||
|
||||
// NewModifyingRuntimeWrapperWithLogger creates a runtime wrapper that applies the specified modifier to the OCI specification
|
||||
// before invoking the wrapped runtime.
|
||||
func NewModifyingRuntimeWrapperWithLogger(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier modify.Modifier) oci.Runtime {
|
||||
rt := modifyingRuntimeWrapper{
|
||||
logger: logger,
|
||||
runtime: runtime,
|
||||
ociSpec: spec,
|
||||
modifier: modifier,
|
||||
}
|
||||
return &rt
|
||||
}
|
||||
|
||||
// Exec checks whether a modification of the OCI specification is required and modifies it accordingly before exec-ing
|
||||
// into the wrapped runtime.
|
||||
func (r *modifyingRuntimeWrapper) Exec(args []string) error {
|
||||
if oci.HasCreateSubcommand(args) {
|
||||
err := r.modify()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not apply required modification to OCI specification: %v", err)
|
||||
}
|
||||
r.logger.Infof("Applied required modification to OCI specification")
|
||||
} else {
|
||||
r.logger.Infof("No modification of OCI specification required")
|
||||
}
|
||||
|
||||
r.logger.Infof("Forwarding command to runtime")
|
||||
return r.runtime.Exec(args)
|
||||
}
|
||||
|
||||
// modify loads, modifies, and flushes the OCI specification using the defined Modifier
|
||||
func (r *modifyingRuntimeWrapper) modify() error {
|
||||
err := r.ociSpec.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading OCI specification for modification: %v", err)
|
||||
}
|
||||
|
||||
err = r.modifier.Modify(r.ociSpec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error modifying OCI spec: %v", err)
|
||||
}
|
||||
|
||||
err = r.ociSpec.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing modified OCI specification: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
160
pkg/runtime/runtime_modifier_test.go
Normal file
160
pkg/runtime/runtime_modifier_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/modify"
|
||||
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
|
||||
)
|
||||
|
||||
func TestRuntimeModifier(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
testCases := []struct {
|
||||
args []string
|
||||
shouldModify bool
|
||||
}{
|
||||
{},
|
||||
{
|
||||
args: []string{"create"},
|
||||
shouldModify: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
runtimeMock := &oci.RuntimeMock{}
|
||||
specMock := &oci.SpecMock{}
|
||||
modifierMock := &modify.ModifierMock{}
|
||||
|
||||
r := NewModifyingRuntimeWrapperWithLogger(
|
||||
logger,
|
||||
runtimeMock,
|
||||
specMock,
|
||||
modifierMock,
|
||||
)
|
||||
|
||||
err := r.Exec(tc.args)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedCalls := 0
|
||||
if tc.shouldModify {
|
||||
expectedCalls = 1
|
||||
}
|
||||
|
||||
require.Len(t, specMock.LoadCalls(), expectedCalls)
|
||||
require.Len(t, modifierMock.ModifyCalls(), expectedCalls)
|
||||
require.Len(t, specMock.FlushCalls(), expectedCalls)
|
||||
|
||||
require.Len(t, runtimeMock.ExecCalls(), 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeModiferWithLoadError(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
runtimeMock := &oci.RuntimeMock{}
|
||||
specMock := &oci.SpecMock{
|
||||
LoadFunc: specErrorFunc,
|
||||
}
|
||||
modifierMock := &modify.ModifierMock{}
|
||||
|
||||
r := NewModifyingRuntimeWrapperWithLogger(
|
||||
logger,
|
||||
runtimeMock,
|
||||
specMock,
|
||||
modifierMock,
|
||||
)
|
||||
|
||||
err := r.Exec([]string{"create"})
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
require.Len(t, specMock.LoadCalls(), 1)
|
||||
require.Len(t, modifierMock.ModifyCalls(), 0)
|
||||
require.Len(t, specMock.FlushCalls(), 0)
|
||||
|
||||
require.Len(t, runtimeMock.ExecCalls(), 0)
|
||||
}
|
||||
|
||||
func TestRuntimeModiferWithFlushError(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
runtimeMock := &oci.RuntimeMock{}
|
||||
specMock := &oci.SpecMock{
|
||||
FlushFunc: specErrorFunc,
|
||||
}
|
||||
modifierMock := &modify.ModifierMock{}
|
||||
|
||||
r := NewModifyingRuntimeWrapperWithLogger(
|
||||
logger,
|
||||
runtimeMock,
|
||||
specMock,
|
||||
modifierMock,
|
||||
)
|
||||
|
||||
err := r.Exec([]string{"create"})
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
require.Len(t, specMock.LoadCalls(), 1)
|
||||
require.Len(t, modifierMock.ModifyCalls(), 1)
|
||||
require.Len(t, specMock.FlushCalls(), 1)
|
||||
|
||||
require.Len(t, runtimeMock.ExecCalls(), 0)
|
||||
}
|
||||
|
||||
func TestRuntimeModiferWithModifyError(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
runtimeMock := &oci.RuntimeMock{}
|
||||
specMock := &oci.SpecMock{}
|
||||
modifierMock := &modify.ModifierMock{
|
||||
ModifyFunc: modifierErrorFunc,
|
||||
}
|
||||
|
||||
r := NewModifyingRuntimeWrapperWithLogger(
|
||||
logger,
|
||||
runtimeMock,
|
||||
specMock,
|
||||
modifierMock,
|
||||
)
|
||||
|
||||
err := r.Exec([]string{"create"})
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
require.Len(t, specMock.LoadCalls(), 1)
|
||||
require.Len(t, modifierMock.ModifyCalls(), 1)
|
||||
require.Len(t, specMock.FlushCalls(), 0)
|
||||
|
||||
require.Len(t, runtimeMock.ExecCalls(), 0)
|
||||
|
||||
}
|
||||
|
||||
func specErrorFunc() error {
|
||||
return fmt.Errorf("error")
|
||||
}
|
||||
|
||||
func modifierErrorFunc(oci.Spec) error {
|
||||
return fmt.Errorf("error")
|
||||
}
|
||||
Reference in New Issue
Block a user