Compare commits

...

42 Commits

Author SHA1 Message Date
Evan Lezar
206ac20e78 Merge branch 'allow-centos7-aarch64-scan-failure' into 'release-1.13'
Allow failure for centos7-aarch64 scans

See merge request nvidia/container-toolkit/container-toolkit!441
2023-07-12 19:55:58 +00:00
Evan Lezar
c3c23d647f Allow failure for centos7-aarch64 scans
For the release-1.13 branch, we don't build aarch64 images for centos7.

This means that, depending on the docker version, a docker pull fails if
the platform is specified.

As a simple workaround, we allow failure for this scan step.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-07-12 21:54:14 +02:00
Evan Lezar
f2a42dc924 Merge branch 'bump-cuda-12.2.0' into 'release-1.13'
Merge branch 'bump-cuda-12.2.0' into 'main'

See merge request nvidia/container-toolkit/container-toolkit!440
2023-07-11 19:18:28 +00:00
Evan Lezar
752afe8ca9 Bump version to v1.13.4
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-07-11 20:34:38 +02:00
Evan Lezar
78940d0a95 Bump libnvidia-container to v1.13.4
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-07-11 20:34:38 +02:00
Evan Lezar
4e0d6f3934 Merge branch 'bump-cuda-12.2.0' into 'main'
Bump cuda base image to 12.2.0

See merge request nvidia/container-toolkit/container-toolkit!435
2023-07-11 20:34:36 +02:00
Evan Lezar
c5a93b8d70 Merge branch 'cherry-pick-wsl2' into 'release-1.13'
Backport CDI fixes for 1.13.3 release

See merge request nvidia/container-toolkit/container-toolkit!429
2023-06-27 18:40:19 +00:00
Evan Lezar
cc06766f25 Merge branch 'fix-load-kernel-modules' into 'main'
Split internal system package

See merge request nvidia/container-toolkit/container-toolkit!420
2023-06-27 17:36:57 +02:00
Evan Lezar
7c807c2c22 Merge branch 'CNT-4302/cdi-only' into 'main'
Skip additional modifications in CDI mode

See merge request nvidia/container-toolkit/container-toolkit!413
2023-06-27 17:08:14 +02:00
Evan Lezar
89781ad6a3 Merge branch 'use-major-minor-for-cuda-version' into 'main'
Use *.* pattern when locating libcuda.so

See merge request nvidia/container-toolkit/container-toolkit!397
2023-06-27 16:59:35 +02:00
Evan Lezar
f677245d60 Merge branch 'fix-multiple-driver-roots-wsl' into 'main'
Fix bug with multiple driver store paths

See merge request nvidia/container-toolkit/container-toolkit!425
2023-06-27 16:59:33 +02:00
Evan Lezar
9d31bd4cc3 Merge branch 'fix-cdi-permissions' into 'main'
Properly set spec permissions

See merge request nvidia/container-toolkit/container-toolkit!383
2023-06-27 16:27:28 +02:00
Carlos Eduardo Arango Gutierrez
b063fa40b1 Merge branch 'fix-cdi-spec-permissions' into 'main'
Generate CDI specifications with 644 permissions to allow non-root clients to consume them

See merge request nvidia/container-toolkit/container-toolkit!381
2023-06-27 16:27:27 +02:00
Evan Lezar
9b7904e0bb Merge branch 'CNT-4075' into 'release-1.13'
Allow same envars for all runtime configs

See merge request nvidia/container-toolkit/container-toolkit!419
2023-06-27 14:26:29 +00:00
Evan Lezar
a9ccef6090 Ensure common envvars have higher precedence
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:26 +02:00
Carlos Eduardo Arango Gutierrez
a4e13c5197 Add entry to changelog
Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
2023-06-27 16:25:24 +02:00
Evan Lezar
c9a8b7f335 Ensure runtime dir is set for crio cleanup
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
591e610905 Remove unused constants and variables
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
524802df2b Rework restart logic
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
19ca4338f2 Add version info to config CLIs
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
56c533d7d4 Refactor toolking to setup and cleanup configs
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
b9b19494d0 Add runtimeDir as argument
Thsi change adds the --nvidia-runtime-dir as a command line
argument when configuring container runtimes in the toolkit container.
This removes the need to set it via the command line.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
47b6b01f48 Allow same envars for all runtime configs
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
7a70850679 Merge branch 'bump-version-v1.13.3' into 'release-1.13'
Bump version to v1.13.3

See merge request nvidia/container-toolkit/container-toolkit!428
2023-06-27 14:14:02 +00:00
Evan Lezar
6052b3eba3 Update libnvidia-container
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 15:30:46 +02:00
Evan Lezar
4ae775d683 Bump version to 1.13.3
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 15:29:19 +02:00
Evan Lezar
d9cf2682f9 Merge branch 'remove-libnvidia-container-workaround' into 'release-1.13'
Merge branch 'revert-kitmaker-workaround' into 'main'

See merge request nvidia/container-toolkit/container-toolkit!417
2023-06-06 20:19:00 +00:00
Evan Lezar
1e5a6e1fa3 Merge branch 'revert-kitmaker-workaround' into 'main'
Remove workaround to add libnvidia-container0 to kitmaker archive

See merge request nvidia/container-toolkit/container-toolkit!378
2023-06-06 22:17:30 +02:00
Evan Lezar
5b81e30704 Merge branch 'cherry-pick-for-v1.13.2' into 'release-1.13'
Cherry pick changes for 1.13.2

See merge request nvidia/container-toolkit/container-toolkit!407
2023-06-06 19:03:42 +00:00
Evan Lezar
a34b08908e Update libnvidia-container
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-06 20:14:27 +02:00
Evan Lezar
234f7ebd5f Merge branch 'bump-cuda-base-image' into 'main'
Bump CUDA baseimage version to 12.1.1

See merge request nvidia/container-toolkit/container-toolkit!412
2023-06-01 14:49:08 +02:00
Evan Lezar
36d1b7d2a5 Merge branch 'treat-log-errors-as-non-fatal' into 'main'
Ignore errors when creating debug log file

See merge request nvidia/container-toolkit/container-toolkit!404
2023-06-01 14:48:33 +02:00
Evan Lezar
b776bf712e Merge branch 'add-mod-probe' into 'main'
Add option to load NVIDIA kernel modules

See merge request nvidia/container-toolkit/container-toolkit!409
2023-06-01 14:48:33 +02:00
Evan Lezar
90cbe938c3 Update CHANGELOG for cherry-pick
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-01 14:48:33 +02:00
Evan Lezar
8697267e6b Merge branch 'fix-device-symlinks' into 'main'
Fix creation of device symlinks in /dev/char

See merge request nvidia/container-toolkit/container-toolkit!399
2023-06-01 14:48:33 +02:00
Evan Lezar
2d7bb636b9 Merge branch 'CNT-4285/add-runtime-hook-path' into 'main'
Add nvidia-contianer-runtime-hook.path config option

See merge request nvidia/container-toolkit/container-toolkit!401
2023-06-01 14:48:33 +02:00
Evan Lezar
7386f86904 Merge branch 'bump-runc' into 'main'
Bump golang version and update dependencies

See merge request nvidia/container-toolkit/container-toolkit!377
2023-06-01 14:47:53 +02:00
Evan Lezar
2bcb2d633d Bump version to 1.13.2
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-01 08:26:59 +02:00
Evan Lezar
fac0697c93 Update libnvidia-container
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-01 08:26:55 +02:00
Evan Lezar
595692b9bd Set libnvidia-container branch to release-1.13
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-01 08:26:06 +02:00
Evan Lezar
f13d4402bd Merge branch 'update-release-1.13' into 'release-1.13'
Skip updating of components on release-1.13 branch

See merge request nvidia/container-toolkit/container-toolkit!405
2023-05-26 09:09:46 +00:00
Evan Lezar
968dce5a70 Skip updating of components
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-05-24 10:45:26 +02:00
268 changed files with 19826 additions and 3620 deletions

2
.gitmodules vendored
View File

@@ -1,7 +1,7 @@
[submodule "third_party/libnvidia-container"]
path = third_party/libnvidia-container
url = https://gitlab.com/nvidia/container-toolkit/libnvidia-container.git
branch = main
branch = release-1.13
[submodule "third_party/nvidia-container-runtime"]
path = third_party/nvidia-container-runtime
url = https://gitlab.com/nvidia/container-toolkit/container-runtime.git

View File

@@ -146,6 +146,7 @@ scan-centos7-arm64:
needs:
- image-centos7
- scan-centos7-amd64
allow_failure: true
scan-ubuntu20.04-amd64:
extends:

View File

@@ -1,5 +1,27 @@
# NVIDIA Container Toolkit Changelog
## v1.13.4
* [toolkit-container] Bump CUDA base image version to 12.2.0.
## v1.13.3
* Generate CDI specification files with `644` permissions to allow rootless applications (e.g. podman).
* Fix bug causing incorrect nvidia-smi symlink to be created on WSL2 systems with multiple driver roots.
* Fix bug when using driver versions that do not include a patch component in their version number.
* Skip additional modifications in CDI mode.
* Fix loading of kernel modules and creation of device nodes in containerized use cases.
* [toolkit-container] Allow same envars for all runtime configs
## v1.13.2
* Add `nvidia-container-runtime-hook.path` config option to specify NVIDIA Container Runtime Hook path explicitly.
* Fix bug in creation of `/dev/char` symlinks by failing operation if kernel modules are not loaded.
* Add option to load kernel modules when creating device nodes
* Add option to create device nodes when creating `/dev/char` symlinks
* Treat failures to open debug log files as non-fatal.
* Bump CUDA base image version to 12.1.1.
## v1.13.1
* Update `update-ldcache` hook to only update ldcache if it exists.

View File

@@ -172,7 +172,7 @@ func TestDuplicateHook(t *testing.T) {
// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for
// testing.
func addNVIDIAHook(spec *specs.Spec) error {
m := modifier.NewStableRuntimeModifier(logrus.StandardLogger())
m := modifier.NewStableRuntimeModifier(logrus.StandardLogger(), nvidiaHook)
return m.Modify(spec)
}

View File

@@ -251,6 +251,7 @@ func (m command) generateSpec(cfg *config) (spec.Interface, error) {
spec.WithDeviceSpecs(deviceSpecs),
spec.WithEdits(*commonEdits.ContainerEdits),
spec.WithFormat(cfg.format),
spec.WithPermissions(0644),
)
}

View File

@@ -28,29 +28,40 @@ import (
type allPossible struct {
logger *logrus.Logger
driverRoot string
devRoot string
deviceMajors devices.Devices
migCaps nvcaps.MigCaps
}
// newAllPossible returns a new allPossible device node lister.
// This lister lists all possible device nodes for NVIDIA GPUs, control devices, and capability devices.
func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error) {
func newAllPossible(logger *logrus.Logger, devRoot string) (nodeLister, error) {
deviceMajors, err := devices.GetNVIDIADevices()
if err != nil {
return nil, fmt.Errorf("failed reading device majors: %v", err)
}
var requiredMajors []devices.Name
migCaps, err := nvcaps.NewMigCaps()
if err != nil {
return nil, fmt.Errorf("failed to read MIG caps: %v", err)
}
if migCaps == nil {
migCaps = make(nvcaps.MigCaps)
} else {
requiredMajors = append(requiredMajors, devices.NVIDIACaps)
}
requiredMajors = append(requiredMajors, devices.NVIDIAGPU, devices.NVIDIAUVM)
for _, name := range requiredMajors {
if !deviceMajors.Exists(name) {
return nil, fmt.Errorf("missing required device major %s", name)
}
}
l := allPossible{
logger: logger,
driverRoot: driverRoot,
devRoot: devRoot,
deviceMajors: deviceMajors,
migCaps: migCaps,
}
@@ -61,7 +72,7 @@ func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error
// DeviceNodes returns a list of all possible device nodes for NVIDIA GPUs, control devices, and capability devices.
func (m allPossible) DeviceNodes() ([]deviceNode, error) {
gpus, err := nvpci.NewFrom(
filepath.Join(m.driverRoot, nvpci.PCIDevicesRoot),
filepath.Join(m.devRoot, nvpci.PCIDevicesRoot),
).GetGPUs()
if err != nil {
return nil, fmt.Errorf("failed to get GPU information: %v", err)
@@ -69,7 +80,7 @@ func (m allPossible) DeviceNodes() ([]deviceNode, error) {
count := len(gpus)
if count == 0 {
m.logger.Infof("No NVIDIA devices found in %s", m.driverRoot)
m.logger.Infof("No NVIDIA devices found in %s", m.devRoot)
return nil, nil
}
@@ -168,7 +179,7 @@ func (m allPossible) newDeviceNode(deviceName devices.Name, path string, minor i
major, _ := m.deviceMajors.Get(deviceName)
return deviceNode{
path: filepath.Join(m.driverRoot, path),
path: filepath.Join(m.devRoot, path),
major: uint32(major),
minor: uint32(minor),
}

View File

@@ -24,6 +24,8 @@ import (
"strings"
"syscall"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules"
"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
@@ -38,11 +40,13 @@ type command struct {
}
type config struct {
devCharPath string
driverRoot string
dryRun bool
watch bool
createAll bool
devCharPath string
driverRoot string
dryRun bool
watch bool
createAll bool
createDeviceNodes bool
loadKernelModules bool
}
// NewCommand constructs a command sub-command with the specified logger
@@ -97,6 +101,18 @@ func (m command) build() *cli.Command {
Destination: &cfg.createAll,
EnvVars: []string{"CREATE_ALL"},
},
&cli.BoolFlag{
Name: "load-kernel-modules",
Usage: "Load the NVIDIA kernel modules before creating symlinks. This is only applicable when --create-all is set.",
Destination: &cfg.loadKernelModules,
EnvVars: []string{"LOAD_KERNEL_MODULES"},
},
&cli.BoolFlag{
Name: "create-device-nodes",
Usage: "Create the NVIDIA control device nodes in the driver root if they do not exist. This is only applicable when --create-all is set",
Destination: &cfg.createDeviceNodes,
EnvVars: []string{"CREATE_DEVICE_NODES"},
},
&cli.BoolFlag{
Name: "dry-run",
Usage: "If set, the command will not create any symlinks.",
@@ -114,6 +130,16 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error {
return fmt.Errorf("create-all and watch are mutually exclusive")
}
if cfg.loadKernelModules && !cfg.createAll {
m.logger.Warn("load-kernel-modules is only applicable when create-all is set; ignoring")
cfg.loadKernelModules = false
}
if cfg.createDeviceNodes && !cfg.createAll {
m.logger.Warn("create-device-nodes is only applicable when create-all is set; ignoring")
cfg.createDeviceNodes = false
}
return nil
}
@@ -137,6 +163,8 @@ func (m command) run(c *cli.Context, cfg *config) error {
WithDriverRoot(cfg.driverRoot),
WithDryRun(cfg.dryRun),
WithCreateAll(cfg.createAll),
WithLoadKernelModules(cfg.loadKernelModules),
WithCreateDeviceNodes(cfg.createDeviceNodes),
)
if err != nil {
return fmt.Errorf("failed to create symlink creator: %v", err)
@@ -186,12 +214,15 @@ create:
}
type linkCreator struct {
logger *logrus.Logger
lister nodeLister
driverRoot string
devCharPath string
dryRun bool
createAll bool
logger *logrus.Logger
lister nodeLister
driverRoot string
devRoot string
devCharPath string
dryRun bool
createAll bool
createDeviceNodes bool
loadKernelModules bool
}
// Creator is an interface for creating symlinks to /dev/nv* devices in /dev/char.
@@ -214,29 +245,76 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) {
if c.driverRoot == "" {
c.driverRoot = "/"
}
if c.devRoot == "" {
c.devRoot = "/"
}
if c.devCharPath == "" {
c.devCharPath = defaultDevCharPath
}
if err := c.setup(); err != nil {
return nil, err
}
if c.createAll {
lister, err := newAllPossible(c.logger, c.driverRoot)
lister, err := newAllPossible(c.logger, c.devRoot)
if err != nil {
return nil, fmt.Errorf("failed to create all possible device lister: %v", err)
}
c.lister = lister
} else {
c.lister = existing{c.logger, c.driverRoot}
c.lister = existing{c.logger, c.devRoot}
}
return c, nil
}
func (m linkCreator) setup() error {
if !m.loadKernelModules && !m.createDeviceNodes {
return nil
}
if m.loadKernelModules {
modules := nvmodules.New(
nvmodules.WithLogger(m.logger),
nvmodules.WithDryRun(m.dryRun),
nvmodules.WithRoot(m.driverRoot),
)
if err := modules.LoadAll(); err != nil {
return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err)
}
}
if m.createDeviceNodes {
devices, err := nvdevices.New(
nvdevices.WithLogger(m.logger),
nvdevices.WithDryRun(m.dryRun),
nvdevices.WithDevRoot(m.devRoot),
)
if err != nil {
return err
}
if err := devices.CreateNVIDIAControlDevices(); err != nil {
return fmt.Errorf("failed to create NVIDIA device nodes: %v", err)
}
}
return nil
}
// WithDriverRoot sets the driver root path.
// This is the path in which kernel modules must be loaded.
func WithDriverRoot(root string) Option {
return func(c *linkCreator) {
c.driverRoot = root
}
}
// WithDevRoot sets the root path for the /dev directory.
func WithDevRoot(root string) Option {
return func(c *linkCreator) {
c.devRoot = root
}
}
// WithDevCharPath sets the path at which the symlinks will be created.
func WithDevCharPath(path string) Option {
return func(c *linkCreator) {
@@ -265,6 +343,20 @@ func WithCreateAll(createAll bool) Option {
}
}
// WithLoadKernelModules sets the loadKernelModules flag for the linkCreator.
func WithLoadKernelModules(loadKernelModules bool) Option {
return func(lc *linkCreator) {
lc.loadKernelModules = loadKernelModules
}
}
// WithCreateDeviceNodes sets the createDeviceNodes flag for the linkCreator.
func WithCreateDeviceNodes(createDeviceNodes bool) Option {
return func(lc *linkCreator) {
lc.createDeviceNodes = createDeviceNodes
}
}
// CreateLinks creates symlinks for all NVIDIA device nodes found in the driver root.
func (m linkCreator) CreateLinks() error {
deviceNodes, err := m.lister.DeviceNodes()

View File

@@ -30,8 +30,8 @@ type nodeLister interface {
}
type existing struct {
logger *logrus.Logger
driverRoot string
logger *logrus.Logger
devRoot string
}
// DeviceNodes returns a list of NVIDIA device nodes in the specified root.
@@ -39,7 +39,7 @@ type existing struct {
func (m existing) DeviceNodes() ([]deviceNode, error) {
locator := lookup.NewCharDeviceLocator(
lookup.WithLogger(m.logger),
lookup.WithRoot(m.driverRoot),
lookup.WithRoot(m.devRoot),
lookup.WithOptional(true),
)
@@ -54,7 +54,7 @@ func (m existing) DeviceNodes() ([]deviceNode, error) {
}
if len(devices) == 0 && len(capDevices) == 0 {
m.logger.Infof("No NVIDIA devices found in %s", m.driverRoot)
m.logger.Infof("No NVIDIA devices found in %s", m.devRoot)
return nil, nil
}

View File

@@ -19,7 +19,8 @@ package createdevicenodes
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
)
@@ -34,6 +35,8 @@ type options struct {
dryRun bool
control bool
loadKernelModules bool
}
// NewCommand constructs a command sub-command with the specified logger
@@ -72,6 +75,11 @@ func (m command) build() *cli.Command {
Usage: "create all control device nodes: nvidiactl, nvidia-modeset, nvidia-uvm, nvidia-uvm-tools",
Destination: &opts.control,
},
&cli.BoolFlag{
Name: "load-kernel-modules",
Usage: "load the NVIDIA Kernel Modules before creating devices nodes",
Destination: &opts.loadKernelModules,
},
&cli.BoolFlag{
Name: "dry-run",
Usage: "if set, the command will not create any symlinks.",
@@ -89,18 +97,29 @@ func (m command) validateFlags(r *cli.Context, opts *options) error {
}
func (m command) run(c *cli.Context, opts *options) error {
s, err := system.New(
system.WithLogger(m.logger),
system.WithDryRun(opts.dryRun),
)
if err != nil {
return fmt.Errorf("failed to create library: %v", err)
if opts.loadKernelModules {
modules := nvmodules.New(
nvmodules.WithLogger(m.logger),
nvmodules.WithDryRun(opts.dryRun),
nvmodules.WithRoot(opts.driverRoot),
)
if err := modules.LoadAll(); err != nil {
return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err)
}
}
if opts.control {
devices, err := nvdevices.New(
nvdevices.WithLogger(m.logger),
nvdevices.WithDryRun(opts.dryRun),
nvdevices.WithDevRoot(opts.driverRoot),
)
if err != nil {
return err
}
m.logger.Infof("Creating control device nodes at %s", opts.driverRoot)
if err := s.CreateNVIDIAControlDeviceNodesAt(opts.driverRoot); err != nil {
return fmt.Errorf("failed to create control device nodes: %v", err)
if err := devices.CreateNVIDIAControlDevices(); err != nil {
return fmt.Errorf("failed to create NVIDIA control device nodes: %v", err)
}
}
return nil

20
go.mod
View File

@@ -1,31 +1,30 @@
module github.com/NVIDIA/nvidia-container-toolkit
go 1.18
go 1.20
require (
github.com/BurntSushi/toml v1.0.0
github.com/BurntSushi/toml v1.2.1
github.com/NVIDIA/go-nvml v0.12.0-0
github.com/container-orchestrated-devices/container-device-interface v0.5.4-0.20230111111500-5b3b5d81179a
github.com/fsnotify/fsnotify v1.5.4
github.com/opencontainers/runtime-spec v1.0.3-0.20220825212826-86290f6a00fb
github.com/opencontainers/runtime-spec v1.1.0-rc.2
github.com/pelletier/go-toml v1.9.4
github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.7.0
github.com/stretchr/testify v1.8.1
github.com/urfave/cli/v2 v2.3.0
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230209143738-95328d8c4438
golang.org/x/mod v0.5.0
golang.org/x/sys v0.0.0-20220927170352-d9d178bc13c6
sigs.k8s.io/yaml v1.3.0
golang.org/x/sys v0.7.0
)
require (
github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/opencontainers/runc v1.1.4 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/opencontainers/runc v1.1.6 // indirect
github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626 // indirect
github.com/opencontainers/selinux v1.10.1 // indirect
github.com/opencontainers/selinux v1.11.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 // indirect
@@ -33,4 +32,5 @@ require (
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
sigs.k8s.io/yaml v1.3.0 // indirect
)

72
go.sum
View File

@@ -1,86 +1,76 @@
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.0.0 h1:dtDWrepsVPfW9H/4y7dDgFc2MBUSeJhlaDtK13CxFlU=
github.com/BurntSushi/toml v1.0.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/NVIDIA/go-nvml v0.11.6-0.0.20220823120812-7e2082095e82 h1:x751Xx1tdxkiA/sdkv2J769n21UbYKzVOpe9S/h1M3k=
github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak=
github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/NVIDIA/go-nvml v0.11.6-0.0.20220823120812-7e2082095e82/go.mod h1:hy7HYeQy335x6nEss0Ne3PYqleRa6Ct+VKD9RQ4nyFs=
github.com/NVIDIA/go-nvml v0.12.0-0 h1:eHYNHbzAsMgWYshf6dEmTY66/GCXnORJFnzm3TNH4mc=
github.com/NVIDIA/go-nvml v0.12.0-0/go.mod h1:hy7HYeQy335x6nEss0Ne3PYqleRa6Ct+VKD9RQ4nyFs=
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ=
github.com/checkpoint-restore/go-criu/v5 v5.3.0/go.mod h1:E/eQpaFtUKGOOSEBZgmKAcn+zUUwWxqcaKZlF54wK8E=
github.com/cilium/ebpf v0.7.0/go.mod h1:/oI2+1shJiTGAMgl6/RgJr36Eo1jzrRcAWbcXO2usCA=
github.com/container-orchestrated-devices/container-device-interface v0.5.4-0.20230111111500-5b3b5d81179a h1:sP3PcgyIkRlHqfF3Jfpe/7G8kf/qpzG4C8r94y9hLbE=
github.com/container-orchestrated-devices/container-device-interface v0.5.4-0.20230111111500-5b3b5d81179a/go.mod h1:xMRa4fJgXzSDFUCURSimOUgoSc+odohvO3uXT9xjqH0=
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.1 h1:r/myEWzV9lfsM1tFLgDyu0atFtJ1fXn261LKYj/3DxU=
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/cyphar/filepath-securejoin v0.2.3/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k=
github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI=
github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.0.6/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mndrix/tap-go v0.0.0-20171203230836-629fa407e90b/go.mod h1:pzzDgJWZ34fGzaAZGFW22KVZDfyrYW+QABMrWnJBnSs=
github.com/moby/sys/mountinfo v0.5.0/go.mod h1:3bMD3Rg+zkqx8MRYPi7Pyb0Ie97QEBmdxbhnCLlSvSU=
github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ=
github.com/opencontainers/runc v1.1.4 h1:nRCz/8sKg6K6jgYAFLDlXzPeITBZJyX28DBVhWD+5dg=
github.com/opencontainers/runc v1.1.4/go.mod h1:1J5XiS+vdZ3wCyZybsuxXZWGrgSr8fFJHLXuG2PsnNg=
github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.0.3-0.20220825212826-86290f6a00fb h1:1xSVPOd7/UA+39/hXEGnBJ13p6JFB0E1EvQFlrRDOXI=
github.com/opencontainers/runc v1.1.6 h1:XbhB8IfG/EsnhNvZtNdLB0GBw92GYEFvKlhaJk9jUgA=
github.com/opencontainers/runc v1.1.6/go.mod h1:CbUumNnWCuTGFukNXahoo/RFBZvDAgRh/smNYNOhA50=
github.com/opencontainers/runtime-spec v1.0.3-0.20220825212826-86290f6a00fb/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.1.0-rc.2 h1:ucBtEms2tamYYW/SvGpvq9yUN0NEVL6oyLEwDcTSrk8=
github.com/opencontainers/runtime-spec v1.1.0-rc.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626 h1:DmNGcqH3WDbV5k8OJ+esPWbqUOX5rMLR2PMvziDMJi0=
github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626/go.mod h1:BRHJJd0E+cx42OybVYSgUvZmU0B8P9gZuRXlZUP7TKI=
github.com/opencontainers/selinux v1.9.1/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI=
github.com/opencontainers/selinux v1.10.0/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI=
github.com/opencontainers/selinux v1.10.1 h1:09LIPVRP3uuZGQvgR+SgMSNBd1Eb3vlRbGqQpoHsF8w=
github.com/opencontainers/selinux v1.10.1/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI=
github.com/opencontainers/selinux v1.11.0 h1:+5Zbo97w3Lbmb3PeqQtpmTkMwsW5nRI3YaLpt7tQ7oU=
github.com/opencontainers/selinux v1.11.0/go.mod h1:E5dMC3VPuVvVHDYmi78qvhJp8+M586T4DlDRYpFkyec=
github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM=
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/seccomp/libseccomp-golang v0.9.2-0.20220502022130-f33da4d89646/go.mod h1:JA8cRccbGaA1s33RQf7Y1+q9gHmZX1yB/z9WDN1C6fg=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 h1:kdXcSzyDtseVEc4yCz2qF8ZrQvIDBJLl4S1c3GCXmoI=
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
github.com/urfave/cli v1.19.1/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
@@ -88,35 +78,19 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230119114711-6fe07bb33342 h1:083n9fJt2dWOpJd/X/q9Xgl5XtQLL22uSFYbzVqJssg=
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230119114711-6fe07bb33342/go.mod h1:GStidGxhaqJhYFW1YpOnLvYCbL2EsM0od7IW4u7+JgU=
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230209143738-95328d8c4438 h1:+qRai7XRl8omFQVCeHcaWzL542Yw64vfmuXG+79ZCIc=
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230209143738-95328d8c4438/go.mod h1:GStidGxhaqJhYFW1YpOnLvYCbL2EsM0od7IW4u7+JgU=
golang.org/x/mod v0.5.0 h1:UG21uOlmZabA4fW5i7ZX6bjw1xELEGg/ZLgZq9auk/Q=
golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220927170352-d9d178bc13c6 h1:cy1ko5847T/lJ45eyg/7uLprIE/amW5IXxGtEnQdYMI=
golang.org/x/sys v0.0.0-20220927170352-d9d178bc13c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=

View File

@@ -21,13 +21,19 @@ import (
"io"
"os"
"path"
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/pelletier/go-toml"
"github.com/sirupsen/logrus"
)
const (
configOverride = "XDG_CONFIG_HOME"
configFilePath = "nvidia-container-runtime/config.toml"
nvidiaContainerRuntimeHookExecutable = "nvidia-container-runtime-hook"
nvidiaContainerRuntimeHookDefaultPath = "/usr/bin/nvidia-container-runtime-hook"
)
var (
@@ -124,3 +130,41 @@ func getDefaultConfig() *Config {
return &c
}
// ResolveNVIDIAContainerRuntimeHookPath resolves the path the nvidia-container-runtime-hook binary.
func ResolveNVIDIAContainerRuntimeHookPath(logger *logrus.Logger, nvidiaContainerRuntimeHookPath string) string {
return resolveWithDefault(
logger,
"NVIDIA Container Runtime Hook",
nvidiaContainerRuntimeHookPath,
nvidiaContainerRuntimeHookDefaultPath,
)
}
// resolveWithDefault resolves the path to the specified binary.
// If an absolute path is specified, it is used directly without searching for the binary.
// If the binary cannot be found in the path, the specified default is used instead.
func resolveWithDefault(logger *logrus.Logger, label string, path string, defaultPath string) string {
if filepath.IsAbs(path) {
logger.Debugf("Using specified %v path %v", label, path)
return path
}
if path == "" {
path = filepath.Base(defaultPath)
}
logger.Debugf("Locating %v as %v", label, path)
lookup := lookup.NewExecutableLocator(logger, "")
resolvedPath := defaultPath
targets, err := lookup.Locate(path)
if err != nil {
logger.Warnf("Failed to locate %v: %v", path, err)
} else {
logger.Debugf("Found %v candidates: %v", path, targets)
resolvedPath = targets[0]
}
logger.Debugf("Using %v path %v", label, path)
return resolvedPath
}

View File

@@ -76,6 +76,9 @@ func TestGetConfig(t *testing.T) {
},
},
},
NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{
Path: "nvidia-container-runtime-hook",
},
NVIDIACTKConfig: CTKConfig{
Path: "nvidia-ctk",
},
@@ -95,6 +98,7 @@ func TestGetConfig(t *testing.T) {
"nvidia-container-runtime.modes.cdi.default-kind = \"example.vendor.com/device\"",
"nvidia-container-runtime.modes.cdi.annotation-prefixes = [\"cdi.k8s.io/\", \"example.vendor.com/\",]",
"nvidia-container-runtime.modes.csv.mount-spec-path = \"/not/etc/nvidia-container-runtime/host-files-for-container.d\"",
"nvidia-container-runtime-hook.path = \"/foo/bar/nvidia-container-runtime-hook\"",
"nvidia-ctk.path = \"/foo/bar/nvidia-ctk\"",
},
expectedConfig: &Config{
@@ -120,6 +124,9 @@ func TestGetConfig(t *testing.T) {
},
},
},
NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{
Path: "/foo/bar/nvidia-container-runtime-hook",
},
NVIDIACTKConfig: CTKConfig{
Path: "/foo/bar/nvidia-ctk",
},
@@ -143,6 +150,8 @@ func TestGetConfig(t *testing.T) {
"annotation-prefixes = [\"cdi.k8s.io/\", \"example.vendor.com/\",]",
"[nvidia-container-runtime.modes.csv]",
"mount-spec-path = \"/not/etc/nvidia-container-runtime/host-files-for-container.d\"",
"[nvidia-container-runtime-hook]",
"path = \"/foo/bar/nvidia-container-runtime-hook\"",
"[nvidia-ctk]",
"path = \"/foo/bar/nvidia-ctk\"",
},
@@ -169,6 +178,9 @@ func TestGetConfig(t *testing.T) {
},
},
},
NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{
Path: "/foo/bar/nvidia-container-runtime-hook",
},
NVIDIACTKConfig: CTKConfig{
Path: "/foo/bar/nvidia-ctk",
},

View File

@@ -24,6 +24,9 @@ import (
// RuntimeHookConfig stores the config options for the NVIDIA Container Runtime
type RuntimeHookConfig struct {
// Path specifies the path to the NVIDIA Container Runtime hook binary.
// If an executable name is specified, this will be resolved in the path.
Path string `toml:"path"`
// SkipModeDetection disables the mode check for the runtime hook.
SkipModeDetection bool `toml:"skip-mode-detection"`
}
@@ -55,6 +58,7 @@ func getRuntimeHookConfigFrom(toml *toml.Tree) (*RuntimeHookConfig, error) {
// GetDefaultRuntimeHookConfig defines the default values for the config
func GetDefaultRuntimeHookConfig() *RuntimeHookConfig {
c := RuntimeHookConfig{
Path: NVIDIAContainerRuntimeHookExecutable,
SkipModeDetection: false,
}

View File

@@ -13,25 +13,25 @@ var _ Discover = &DiscoverMock{}
// DiscoverMock is a mock implementation of Discover.
//
// func TestSomethingThatUsesDiscover(t *testing.T) {
// func TestSomethingThatUsesDiscover(t *testing.T) {
//
// // make and configure a mocked Discover
// mockedDiscover := &DiscoverMock{
// DevicesFunc: func() ([]Device, error) {
// panic("mock out the Devices method")
// },
// HooksFunc: func() ([]Hook, error) {
// panic("mock out the Hooks method")
// },
// MountsFunc: func() ([]Mount, error) {
// panic("mock out the Mounts method")
// },
// }
// // make and configure a mocked Discover
// mockedDiscover := &DiscoverMock{
// DevicesFunc: func() ([]Device, error) {
// panic("mock out the Devices method")
// },
// HooksFunc: func() ([]Hook, error) {
// panic("mock out the Hooks method")
// },
// MountsFunc: func() ([]Mount, error) {
// panic("mock out the Mounts method")
// },
// }
//
// // use mockedDiscover in code that requires Discover
// // and then make assertions.
// // use mockedDiscover in code that requires Discover
// // and then make assertions.
//
// }
// }
type DiscoverMock struct {
// DevicesFunc mocks the Devices method.
DevicesFunc func() ([]Device, error)
@@ -78,7 +78,8 @@ func (mock *DiscoverMock) Devices() ([]Device, error) {
// DevicesCalls gets all the calls that were made to Devices.
// Check the length with:
// len(mockedDiscover.DevicesCalls())
//
// len(mockedDiscover.DevicesCalls())
func (mock *DiscoverMock) DevicesCalls() []struct {
} {
var calls []struct {
@@ -108,7 +109,8 @@ func (mock *DiscoverMock) Hooks() ([]Hook, error) {
// HooksCalls gets all the calls that were made to Hooks.
// Check the length with:
// len(mockedDiscover.HooksCalls())
//
// len(mockedDiscover.HooksCalls())
func (mock *DiscoverMock) HooksCalls() []struct {
} {
var calls []struct {
@@ -138,7 +140,8 @@ func (mock *DiscoverMock) Mounts() ([]Mount, error) {
// MountsCalls gets all the calls that were made to Mounts.
// Check the length with:
// len(mockedDiscover.MountsCalls())
//
// len(mockedDiscover.MountsCalls())
func (mock *DiscoverMock) MountsCalls() []struct {
} {
var calls []struct {

View File

@@ -271,7 +271,7 @@ func newXorgDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath s
libCudaPaths, err := cuda.New(
cuda.WithLogger(logger),
cuda.WithDriverRoot(driverRoot),
).Locate(".*.*.*")
).Locate(".*.*")
if err != nil {
return nil, fmt.Errorf("failed to locate libcuda.so: %v", err)
}

View File

@@ -18,6 +18,7 @@ package devices
import (
"bufio"
"errors"
"fmt"
"io"
"os"
@@ -72,7 +73,14 @@ func (d devices) Get(name Name) (Major, bool) {
// GetNVIDIADevices returns the set of NVIDIA Devices on the machine
func GetNVIDIADevices() (Devices, error) {
devicesFile, err := os.Open(procDevicesPath)
return nvidiaDevices(procDevicesPath)
}
// nvidiaDevices returns the set of NVIDIA Devices from the specified devices file.
// This is useful for testing since we may be testing on a system where `/proc/devices` does
// contain a reference to NVIDIA devices.
func nvidiaDevices(devicesPath string) (Devices, error) {
devicesFile, err := os.Open(devicesPath)
if os.IsNotExist(err) {
return nil, nil
}
@@ -81,20 +89,28 @@ func GetNVIDIADevices() (Devices, error) {
}
defer devicesFile.Close()
return nvidiaDeviceFrom(devicesFile), nil
return nvidiaDeviceFrom(devicesFile)
}
func nvidiaDeviceFrom(reader io.Reader) devices {
var errNoNvidiaDevices = errors.New("no NVIDIA devices found")
func nvidiaDeviceFrom(reader io.Reader) (devices, error) {
allDevices := devicesFrom(reader)
nvidiaDevices := make(devices)
var hasNvidiaDevices bool
for n, d := range allDevices {
if !strings.HasPrefix(string(n), nvidiaDevicePrefix) {
continue
}
nvidiaDevices[n] = d
hasNvidiaDevices = true
}
return nvidiaDevices
if !hasNvidiaDevices {
return nil, errNoNvidiaDevices
}
return nvidiaDevices, nil
}
func devicesFrom(reader io.Reader) devices {

View File

@@ -45,21 +45,23 @@ func TestNvidiaDevices(t *testing.T) {
func TestProcessDeviceFile(t *testing.T) {
testCases := []struct {
lines []string
expected devices
lines []string
expected devices
expectedError error
}{
{[]string{}, make(devices)},
{[]string{"Not a valid line:"}, make(devices)},
{[]string{"195 nvidia-frontend"}, devices{"nvidia-frontend": 195}},
{[]string{"195 nvidia-frontend", "235 nvidia-caps"}, devices{"nvidia-frontend": 195, "nvidia-caps": 235}},
{[]string{" 195 nvidia-frontend"}, devices{"nvidia-frontend": 195}},
{[]string{"Not a valid line:", "", "195 nvidia-frontend"}, devices{"nvidia-frontend": 195}},
{[]string{"195 not-nvidia-frontend"}, make(devices)},
{lines: []string{}, expectedError: errNoNvidiaDevices},
{lines: []string{"Not a valid line:"}, expectedError: errNoNvidiaDevices},
{lines: []string{"195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{lines: []string{"195 nvidia-frontend", "235 nvidia-caps"}, expected: devices{"nvidia-frontend": 195, "nvidia-caps": 235}},
{lines: []string{" 195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{lines: []string{"Not a valid line:", "", "195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{lines: []string{"195 not-nvidia-frontend"}, expectedError: errNoNvidiaDevices},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("testcase %d", i), func(t *testing.T) {
contents := strings.NewReader(strings.Join(tc.lines, "\n"))
d := nvidiaDeviceFrom(contents)
d, err := nvidiaDeviceFrom(contents)
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expected, d)
})

View File

@@ -13,29 +13,23 @@ var _ Locator = &LocatorMock{}
// LocatorMock is a mock implementation of Locator.
//
// func TestSomethingThatUsesLocator(t *testing.T) {
// func TestSomethingThatUsesLocator(t *testing.T) {
//
// // make and configure a mocked Locator
// mockedLocator := &LocatorMock{
// LocateFunc: func(s string) ([]string, error) {
// panic("mock out the Locate method")
// },
// RelativeFunc: func(s string) (string, error) {
// panic("mock out the Relative method")
// },
// }
// // make and configure a mocked Locator
// mockedLocator := &LocatorMock{
// LocateFunc: func(s string) ([]string, error) {
// panic("mock out the Locate method")
// },
// }
//
// // use mockedLocator in code that requires Locator
// // and then make assertions.
// // use mockedLocator in code that requires Locator
// // and then make assertions.
//
// }
// }
type LocatorMock struct {
// LocateFunc mocks the Locate method.
LocateFunc func(s string) ([]string, error)
// RelativeFunc mocks the Relative method.
RelativeFunc func(s string) (string, error)
// calls tracks calls to the methods.
calls struct {
// Locate holds details about calls to the Locate method.
@@ -43,14 +37,8 @@ type LocatorMock struct {
// S is the s argument value.
S string
}
// Relative holds details about calls to the Relative method.
Relative []struct {
// S is the s argument value.
S string
}
}
lockLocate sync.RWMutex
lockRelative sync.RWMutex
lockLocate sync.RWMutex
}
// Locate calls LocateFunc.
@@ -75,7 +63,8 @@ func (mock *LocatorMock) Locate(s string) ([]string, error) {
// LocateCalls gets all the calls that were made to Locate.
// Check the length with:
// len(mockedLocator.LocateCalls())
//
// len(mockedLocator.LocateCalls())
func (mock *LocatorMock) LocateCalls() []struct {
S string
} {
@@ -87,38 +76,3 @@ func (mock *LocatorMock) LocateCalls() []struct {
mock.lockLocate.RUnlock()
return calls
}
// Relative calls RelativeFunc.
func (mock *LocatorMock) Relative(s string) (string, error) {
callInfo := struct {
S string
}{
S: s,
}
mock.lockRelative.Lock()
mock.calls.Relative = append(mock.calls.Relative, callInfo)
mock.lockRelative.Unlock()
if mock.RelativeFunc == nil {
var (
sOut string
errOut error
)
return sOut, errOut
}
return mock.RelativeFunc(s)
}
// RelativeCalls gets all the calls that were made to Relative.
// Check the length with:
// len(mockedLocator.RelativeCalls())
func (mock *LocatorMock) RelativeCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockRelative.RLock()
calls = mock.calls.Relative
mock.lockRelative.RUnlock()
return calls
}

View File

@@ -17,10 +17,8 @@
package modifier
import (
"fmt"
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
@@ -28,8 +26,11 @@ import (
// NewStableRuntimeModifier creates an OCI spec modifier that inserts the NVIDIA Container Runtime Hook into an OCI
// spec. The specified logger is used to capture log output.
func NewStableRuntimeModifier(logger *logrus.Logger) oci.SpecModifier {
m := stableRuntimeModifier{logger: logger}
func NewStableRuntimeModifier(logger *logrus.Logger, nvidiaContainerRuntimeHookPath string) oci.SpecModifier {
m := stableRuntimeModifier{
logger: logger,
nvidiaContainerRuntimeHookPath: nvidiaContainerRuntimeHookPath,
}
return &m
}
@@ -37,7 +38,8 @@ func NewStableRuntimeModifier(logger *logrus.Logger) oci.SpecModifier {
// stableRuntimeModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a
// prestart hook. If the hook is already present, no modification is made.
type stableRuntimeModifier struct {
logger *logrus.Logger
logger *logrus.Logger
nvidiaContainerRuntimeHookPath string
}
// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook
@@ -53,18 +55,9 @@ func (m stableRuntimeModifier) Modify(spec *specs.Spec) error {
}
}
// We create a locator and look for the NVIDIA Container Runtime Hook in the path.
candidates, err := lookup.NewExecutableLocator(m.logger, "").Locate(config.NVIDIAContainerRuntimeHookExecutable)
if err != nil {
return fmt.Errorf("failed to locate NVIDIA Container Runtime Hook: %v", err)
}
path := candidates[0]
if len(candidates) > 1 {
m.logger.Debugf("Using %v from multiple NVIDIA Container Runtime Hook candidates: %v", path, candidates)
}
path := m.nvidiaContainerRuntimeHookPath
m.logger.Infof("Using prestart hook path: %v", path)
args := []string{path}
args := []string{filepath.Base(path)}
if spec.Hooks == nil {
spec.Hooks = &specs.Hooks{}
}

View File

@@ -79,7 +79,7 @@ func TestAddHookModifier(t *testing.T) {
Prestart: []specs.Hook{
{
Path: testHookPath,
Args: []string{testHookPath, "prestart"},
Args: []string{"nvidia-container-runtime-hook", "prestart"},
},
},
},
@@ -95,7 +95,7 @@ func TestAddHookModifier(t *testing.T) {
Prestart: []specs.Hook{
{
Path: testHookPath,
Args: []string{testHookPath, "prestart"},
Args: []string{"nvidia-container-runtime-hook", "prestart"},
},
},
},
@@ -141,7 +141,7 @@ func TestAddHookModifier(t *testing.T) {
},
{
Path: testHookPath,
Args: []string{testHookPath, "prestart"},
Args: []string{"nvidia-container-runtime-hook", "prestart"},
},
},
},
@@ -154,7 +154,7 @@ func TestAddHookModifier(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
m := NewStableRuntimeModifier(logger)
m := NewStableRuntimeModifier(logger, testHookPath)
err := m.Modify(&tc.spec)
if tc.expectedError != nil {

View File

@@ -13,19 +13,19 @@ var _ Runtime = &RuntimeMock{}
// RuntimeMock is a mock implementation of Runtime.
//
// func TestSomethingThatUsesRuntime(t *testing.T) {
// func TestSomethingThatUsesRuntime(t *testing.T) {
//
// // make and configure a mocked Runtime
// mockedRuntime := &RuntimeMock{
// ExecFunc: func(strings []string) error {
// panic("mock out the Exec method")
// },
// }
// // make and configure a mocked Runtime
// mockedRuntime := &RuntimeMock{
// ExecFunc: func(strings []string) error {
// panic("mock out the Exec method")
// },
// }
//
// // use mockedRuntime in code that requires Runtime
// // and then make assertions.
// // use mockedRuntime in code that requires Runtime
// // and then make assertions.
//
// }
// }
type RuntimeMock struct {
// ExecFunc mocks the Exec method.
ExecFunc func(strings []string) error
@@ -62,7 +62,8 @@ func (mock *RuntimeMock) Exec(strings []string) error {
// ExecCalls gets all the calls that were made to Exec.
// Check the length with:
// len(mockedRuntime.ExecCalls())
//
// len(mockedRuntime.ExecCalls())
func (mock *RuntimeMock) ExecCalls() []struct {
Strings []string
} {

View File

@@ -30,8 +30,9 @@ type SpecModifier interface {
Modify(*specs.Spec) error
}
//go:generate moq -stub -out spec_mock.go . Spec
// Spec defines the operations to be performed on an OCI specification
//
//go:generate moq -stub -out spec_mock.go . Spec
type Spec interface {
Load() (*specs.Spec, error)
Flush() error

View File

@@ -14,28 +14,28 @@ var _ Spec = &SpecMock{}
// SpecMock is a mock implementation of Spec.
//
// func TestSomethingThatUsesSpec(t *testing.T) {
// func TestSomethingThatUsesSpec(t *testing.T) {
//
// // make and configure a mocked Spec
// mockedSpec := &SpecMock{
// FlushFunc: func() error {
// panic("mock out the Flush method")
// },
// LoadFunc: func() (*specs.Spec, error) {
// panic("mock out the Load method")
// },
// LookupEnvFunc: func(s string) (string, bool) {
// panic("mock out the LookupEnv method")
// },
// ModifyFunc: func(specModifier SpecModifier) error {
// panic("mock out the Modify method")
// },
// }
// // make and configure a mocked Spec
// mockedSpec := &SpecMock{
// FlushFunc: func() error {
// panic("mock out the Flush method")
// },
// LoadFunc: func() (*specs.Spec, error) {
// panic("mock out the Load method")
// },
// LookupEnvFunc: func(s string) (string, bool) {
// panic("mock out the LookupEnv method")
// },
// ModifyFunc: func(specModifier SpecModifier) error {
// panic("mock out the Modify method")
// },
// }
//
// // use mockedSpec in code that requires Spec
// // and then make assertions.
// // use mockedSpec in code that requires Spec
// // and then make assertions.
//
// }
// }
type SpecMock struct {
// FlushFunc mocks the Flush method.
FlushFunc func() error
@@ -92,7 +92,8 @@ func (mock *SpecMock) Flush() error {
// FlushCalls gets all the calls that were made to Flush.
// Check the length with:
// len(mockedSpec.FlushCalls())
//
// len(mockedSpec.FlushCalls())
func (mock *SpecMock) FlushCalls() []struct {
} {
var calls []struct {
@@ -122,7 +123,8 @@ func (mock *SpecMock) Load() (*specs.Spec, error) {
// LoadCalls gets all the calls that were made to Load.
// Check the length with:
// len(mockedSpec.LoadCalls())
//
// len(mockedSpec.LoadCalls())
func (mock *SpecMock) LoadCalls() []struct {
} {
var calls []struct {
@@ -155,7 +157,8 @@ func (mock *SpecMock) LookupEnv(s string) (string, bool) {
// LookupEnvCalls gets all the calls that were made to LookupEnv.
// Check the length with:
// len(mockedSpec.LookupEnvCalls())
//
// len(mockedSpec.LookupEnvCalls())
func (mock *SpecMock) LookupEnvCalls() []struct {
S string
} {
@@ -189,7 +192,8 @@ func (mock *SpecMock) Modify(specModifier SpecModifier) error {
// ModifyCalls gets all the calls that were made to Modify.
// Check the length with:
// len(mockedSpec.ModifyCalls())
//
// len(mockedSpec.ModifyCalls())
func (mock *SpecMock) ModifyCalls() []struct {
SpecModifier SpecModifier
} {

View File

@@ -13,22 +13,22 @@ var _ Constraint = &ConstraintMock{}
// ConstraintMock is a mock implementation of Constraint.
//
// func TestSomethingThatUsesConstraint(t *testing.T) {
// func TestSomethingThatUsesConstraint(t *testing.T) {
//
// // make and configure a mocked Constraint
// mockedConstraint := &ConstraintMock{
// AssertFunc: func() error {
// panic("mock out the Assert method")
// },
// StringFunc: func() string {
// panic("mock out the String method")
// },
// }
// // make and configure a mocked Constraint
// mockedConstraint := &ConstraintMock{
// AssertFunc: func() error {
// panic("mock out the Assert method")
// },
// StringFunc: func() string {
// panic("mock out the String method")
// },
// }
//
// // use mockedConstraint in code that requires Constraint
// // and then make assertions.
// // use mockedConstraint in code that requires Constraint
// // and then make assertions.
//
// }
// }
type ConstraintMock struct {
// AssertFunc mocks the Assert method.
AssertFunc func() error
@@ -67,7 +67,8 @@ func (mock *ConstraintMock) Assert() error {
// AssertCalls gets all the calls that were made to Assert.
// Check the length with:
// len(mockedConstraint.AssertCalls())
//
// len(mockedConstraint.AssertCalls())
func (mock *ConstraintMock) AssertCalls() []struct {
} {
var calls []struct {
@@ -96,7 +97,8 @@ func (mock *ConstraintMock) String() string {
// StringCalls gets all the calls that were made to String.
// Check the length with:
// len(mockedConstraint.StringCalls())
//
// len(mockedConstraint.StringCalls())
func (mock *ConstraintMock) StringCalls() []struct {
} {
var calls []struct {

View File

@@ -16,8 +16,9 @@
package constraints
//go:generate moq -stub -out constraint_mock.go . Constraint
// Constraint represents a constraint that is to be evaluated
//
//go:generate moq -stub -out constraint_mock.go . Constraint
type Constraint interface {
String() string
Assert() error

View File

@@ -23,8 +23,9 @@ import (
"golang.org/x/mod/semver"
)
//go:generate moq -stub -out property_mock.go . Property
// Property represents a property that is used to check requirements
//
//go:generate moq -stub -out property_mock.go . Property
type Property interface {
Name() string
Value() (string, error)

View File

@@ -13,31 +13,31 @@ var _ Property = &PropertyMock{}
// PropertyMock is a mock implementation of Property.
//
// func TestSomethingThatUsesProperty(t *testing.T) {
// func TestSomethingThatUsesProperty(t *testing.T) {
//
// // make and configure a mocked Property
// mockedProperty := &PropertyMock{
// CompareToFunc: func(s string) (int, error) {
// panic("mock out the CompareTo method")
// },
// NameFunc: func() string {
// panic("mock out the Name method")
// },
// StringFunc: func() string {
// panic("mock out the String method")
// },
// ValidateFunc: func(s string) error {
// panic("mock out the Validate method")
// },
// ValueFunc: func() (string, error) {
// panic("mock out the Value method")
// },
// }
// // make and configure a mocked Property
// mockedProperty := &PropertyMock{
// CompareToFunc: func(s string) (int, error) {
// panic("mock out the CompareTo method")
// },
// NameFunc: func() string {
// panic("mock out the Name method")
// },
// StringFunc: func() string {
// panic("mock out the String method")
// },
// ValidateFunc: func(s string) error {
// panic("mock out the Validate method")
// },
// ValueFunc: func() (string, error) {
// panic("mock out the Value method")
// },
// }
//
// // use mockedProperty in code that requires Property
// // and then make assertions.
// // use mockedProperty in code that requires Property
// // and then make assertions.
//
// }
// }
type PropertyMock struct {
// CompareToFunc mocks the CompareTo method.
CompareToFunc func(s string) (int, error)
@@ -105,7 +105,8 @@ func (mock *PropertyMock) CompareTo(s string) (int, error) {
// CompareToCalls gets all the calls that were made to CompareTo.
// Check the length with:
// len(mockedProperty.CompareToCalls())
//
// len(mockedProperty.CompareToCalls())
func (mock *PropertyMock) CompareToCalls() []struct {
S string
} {
@@ -136,7 +137,8 @@ func (mock *PropertyMock) Name() string {
// NameCalls gets all the calls that were made to Name.
// Check the length with:
// len(mockedProperty.NameCalls())
//
// len(mockedProperty.NameCalls())
func (mock *PropertyMock) NameCalls() []struct {
} {
var calls []struct {
@@ -165,7 +167,8 @@ func (mock *PropertyMock) String() string {
// StringCalls gets all the calls that were made to String.
// Check the length with:
// len(mockedProperty.StringCalls())
//
// len(mockedProperty.StringCalls())
func (mock *PropertyMock) StringCalls() []struct {
} {
var calls []struct {
@@ -197,7 +200,8 @@ func (mock *PropertyMock) Validate(s string) error {
// ValidateCalls gets all the calls that were made to Validate.
// Check the length with:
// len(mockedProperty.ValidateCalls())
//
// len(mockedProperty.ValidateCalls())
func (mock *PropertyMock) ValidateCalls() []struct {
S string
} {
@@ -229,7 +233,8 @@ func (mock *PropertyMock) Value() (string, error) {
// ValueCalls gets all the calls that were made to Value.
// Check the length with:
// len(mockedProperty.ValueCalls())
//
// len(mockedProperty.ValueCalls())
func (mock *PropertyMock) ValueCalls() []struct {
} {
var calls []struct {

View File

@@ -17,6 +17,7 @@
package runtime
import (
"errors"
"fmt"
"io"
"os"
@@ -43,7 +44,7 @@ func NewLogger() *Logger {
}
// Update constructs a Logger with a preddefined formatter
func (l *Logger) Update(filename string, logLevel string, argv []string) error {
func (l *Logger) Update(filename string, logLevel string, argv []string) {
configFromArgs := parseArgs(argv)
@@ -61,7 +62,7 @@ func (l *Logger) Update(filename string, logLevel string, argv []string) error {
if !configFromArgs.version {
configLogFile, err := createLogFile(filename)
if err != nil {
return fmt.Errorf("error opening debug log file: %v", err)
argLogFileError = errors.Join(argLogFileError, err)
}
if configLogFile != nil {
logFiles = append(logFiles, configLogFile)
@@ -71,7 +72,7 @@ func (l *Logger) Update(filename string, logLevel string, argv []string) error {
if argLogFile != nil {
logFiles = append(logFiles, argLogFile)
}
argLogFileError = err
argLogFileError = errors.Join(argLogFileError, err)
}
defer func() {
if argLogFileError != nil {
@@ -119,8 +120,6 @@ func (l *Logger) Update(filename string, logLevel string, argv []string) error {
previousLogger: l.Logger,
logFiles: logFiles,
}
return nil
}
// Reset closes the log file (if any) and resets the logger output to what it
@@ -157,11 +156,16 @@ func (l *Logger) Reset() error {
}
func createLogFile(filename string) (*os.File, error) {
if filename != "" && filename != os.DevNull {
return os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if filename == "" || filename == os.DevNull {
return nil, nil
}
return nil, nil
if dir := filepath.Dir(filepath.Clean(filename)); dir != "." {
err := os.MkdirAll(dir, 0755)
if err != nil {
return nil, err
}
}
return os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
}
type loggerConfig struct {

View File

@@ -44,18 +44,11 @@ func (r rt) Run(argv []string) (rerr error) {
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
if r.modeOverride != "" {
cfg.NVIDIAContainerRuntimeConfig.Mode = r.modeOverride
}
err = r.logger.Update(
r.logger.Update(
cfg.NVIDIAContainerRuntimeConfig.DebugFilePath,
cfg.NVIDIAContainerRuntimeConfig.LogLevel,
argv,
)
if err != nil {
return fmt.Errorf("failed to set up logger: %v", err)
}
defer func() {
if rerr != nil {
r.logger.Errorf("%v", rerr)
@@ -63,6 +56,13 @@ func (r rt) Run(argv []string) (rerr error) {
r.logger.Reset()
}()
// We apply some config updates here to ensure that the config is valid in
// all cases.
if r.modeOverride != "" {
cfg.NVIDIAContainerRuntimeConfig.Mode = r.modeOverride
}
cfg.NVIDIAContainerRuntimeHookConfig.Path = config.ResolveNVIDIAContainerRuntimeHookPath(r.logger.Logger, cfg.NVIDIAContainerRuntimeHookConfig.Path)
// Print the config to the output.
configJSON, err := json.MarshalIndent(cfg, "", " ")
if err == nil {

View File

@@ -61,10 +61,15 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) {
modeModifier, err := newModeModifier(logger, cfg, ociSpec, argv)
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode)
modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, argv)
if err != nil {
return nil, err
}
// For CDI mode we make no additional modifications.
if mode == "cdi" {
return modeModifier, nil
}
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, ociSpec)
if err != nil {
@@ -96,10 +101,10 @@ func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec
return modifiers, nil
}
func newModeModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) {
switch info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode) {
func newModeModifier(logger *logrus.Logger, mode string, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) {
switch mode {
case "legacy":
return modifier.NewStableRuntimeModifier(logger), nil
return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil
case "csv":
return modifier.NewCSVModifier(logger, cfg, ociSpec)
case "cdi":

View File

@@ -0,0 +1,154 @@
/**
# Copyright (c) NVIDIA CORPORATIOm. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvdevices
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
"github.com/sirupsen/logrus"
)
var errInvalidDeviceNode = errors.New("invalid device node")
// Interface provides a set of utilities for interacting with NVIDIA devices on the system.
type Interface struct {
devices.Devices
logger *logrus.Logger
dryRun bool
// devRoot is the root directory where device nodes are expected to exist.
devRoot string
mknoder
}
// New constructs a new Interface struct with the specified options.
func New(opts ...Option) (*Interface, error) {
i := &Interface{}
for _, opt := range opts {
opt(i)
}
if i.logger == nil {
i.logger = logrus.StandardLogger()
}
if i.devRoot == "" {
i.devRoot = "/"
}
if i.Devices == nil {
devices, err := devices.GetNVIDIADevices()
if err != nil {
return nil, fmt.Errorf("failed to create devices info: %v", err)
}
i.Devices = devices
}
if i.dryRun {
i.mknoder = &mknodLogger{i.logger}
} else {
i.mknoder = &mknodUnix{}
}
return i, nil
}
// CreateNVIDIAControlDevices creates the NVIDIA control device nodes at the configured devRoot.
func (m *Interface) CreateNVIDIAControlDevices() error {
controlNodes := []string{"nvidiactl", "nvidia-modeset", "nvidia-uvm", "nvidia-uvm-tools"}
for _, node := range controlNodes {
err := m.CreateNVIDIADevice(node)
if err != nil {
return fmt.Errorf("failed to create device node %s: %w", node, err)
}
}
return nil
}
// CreateNVIDIADevice creates the specified NVIDIA device node at the configured devRoot.
func (m *Interface) CreateNVIDIADevice(node string) error {
node = filepath.Base(node)
if !strings.HasPrefix(node, "nvidia") {
return fmt.Errorf("invalid device node %q: %w", node, errInvalidDeviceNode)
}
major, err := m.Major(node)
if err != nil {
return fmt.Errorf("failed to determine major: %w", err)
}
minor, err := m.Minor(node)
if err != nil {
return fmt.Errorf("failed to determine minor: %w", err)
}
return m.createDeviceNode(filepath.Join("dev", node), int(major), int(minor))
}
// createDeviceNode creates the specified device node with the require major and minor numbers.
// If a devRoot is configured, this is prepended to the path.
func (m *Interface) createDeviceNode(path string, major int, minor int) error {
path = filepath.Join(m.devRoot, path)
if _, err := os.Stat(path); err == nil {
m.logger.Infof("Skipping: %s already exists", path)
return nil
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to stat %s: %v", path, err)
}
return m.Mknode(path, major, minor)
}
// Major returns the major number for the specified NVIDIA device node.
// If the device node is not supported, an error is returned.
func (m *Interface) Major(node string) (int64, error) {
var valid bool
var major devices.Major
switch node {
case "nvidia-uvm", "nvidia-uvm-tools":
major, valid = m.Get(devices.NVIDIAUVM)
case "nvidia-modeset", "nvidiactl":
major, valid = m.Get(devices.NVIDIAGPU)
}
if valid {
return int64(major), nil
}
return 0, errInvalidDeviceNode
}
// Minor returns the minor number for the specified NVIDIA device node.
// If the device node is not supported, an error is returned.
func (m *Interface) Minor(node string) (int64, error) {
switch node {
case "nvidia-modeset":
return devices.NVIDIAModesetMinor, nil
case "nvidia-uvm-tools":
return devices.NVIDIAUVMToolsMinor, nil
case "nvidia-uvm":
return devices.NVIDIAUVMMinor, nil
case "nvidiactl":
return devices.NVIDIACTLMinor, nil
}
return 0, errInvalidDeviceNode
}

View File

@@ -0,0 +1,133 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvdevices
import (
"errors"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestCreateControlDevices(t *testing.T) {
logger, _ := testlog.NewNullLogger()
nvidiaDevices := &devices.DevicesMock{
GetFunc: func(name devices.Name) (devices.Major, bool) {
devices := map[devices.Name]devices.Major{
"nvidia-frontend": 195,
"nvidia-uvm": 243,
}
return devices[name], true
},
}
mknodeError := errors.New("mknode error")
testCases := []struct {
description string
root string
devices devices.Devices
mknodeError error
expectedError error
expectedCalls []struct {
S string
N1 int
N2 int
}
}{
{
description: "no root specified",
root: "",
devices: nvidiaDevices,
mknodeError: nil,
expectedCalls: []struct {
S string
N1 int
N2 int
}{
{"/dev/nvidiactl", 195, 255},
{"/dev/nvidia-modeset", 195, 254},
{"/dev/nvidia-uvm", 243, 0},
{"/dev/nvidia-uvm-tools", 243, 1},
},
},
{
description: "some root specified",
root: "/some/root",
devices: nvidiaDevices,
mknodeError: nil,
expectedCalls: []struct {
S string
N1 int
N2 int
}{
{"/some/root/dev/nvidiactl", 195, 255},
{"/some/root/dev/nvidia-modeset", 195, 254},
{"/some/root/dev/nvidia-uvm", 243, 0},
{"/some/root/dev/nvidia-uvm-tools", 243, 1},
},
},
{
description: "mknod error returns error",
devices: nvidiaDevices,
mknodeError: mknodeError,
expectedError: mknodeError,
// We expect the first call to this to fail, and the rest to be skipped
expectedCalls: []struct {
S string
N1 int
N2 int
}{
{"/dev/nvidiactl", 195, 255},
},
},
{
description: "missing major returns error",
devices: &devices.DevicesMock{
GetFunc: func(name devices.Name) (devices.Major, bool) {
return 0, false
},
},
expectedError: errInvalidDeviceNode,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
mknode := &mknoderMock{
MknodeFunc: func(string, int, int) error {
return tc.mknodeError
},
}
d, _ := New(
WithLogger(logger),
WithDevRoot(tc.root),
WithDevices(tc.devices),
)
d.mknoder = mknode
err := d.CreateNVIDIAControlDevices()
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expectedCalls, mknode.MknodeCalls())
})
}
}

View File

@@ -0,0 +1,46 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvdevices
import (
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
//go:generate moq -stub -out mknod_mock.go . mknoder
type mknoder interface {
Mknode(string, int, int) error
}
type mknodLogger struct {
*logrus.Logger
}
func (m *mknodLogger) Mknode(path string, major, minor int) error {
m.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor)
return nil
}
type mknodUnix struct{}
func (m *mknodUnix) Mknode(path string, major, minor int) error {
err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(uint32(major), uint32(minor))))
if err != nil {
return err
}
return unix.Chmod(path, 0666)
}

View File

@@ -0,0 +1,89 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package nvdevices
import (
"sync"
)
// Ensure, that mknoderMock does implement mknoder.
// If this is not the case, regenerate this file with moq.
var _ mknoder = &mknoderMock{}
// mknoderMock is a mock implementation of mknoder.
//
// func TestSomethingThatUsesmknoder(t *testing.T) {
//
// // make and configure a mocked mknoder
// mockedmknoder := &mknoderMock{
// MknodeFunc: func(s string, n1 int, n2 int) error {
// panic("mock out the Mknode method")
// },
// }
//
// // use mockedmknoder in code that requires mknoder
// // and then make assertions.
//
// }
type mknoderMock struct {
// MknodeFunc mocks the Mknode method.
MknodeFunc func(s string, n1 int, n2 int) error
// calls tracks calls to the methods.
calls struct {
// Mknode holds details about calls to the Mknode method.
Mknode []struct {
// S is the s argument value.
S string
// N1 is the n1 argument value.
N1 int
// N2 is the n2 argument value.
N2 int
}
}
lockMknode sync.RWMutex
}
// Mknode calls MknodeFunc.
func (mock *mknoderMock) Mknode(s string, n1 int, n2 int) error {
callInfo := struct {
S string
N1 int
N2 int
}{
S: s,
N1: n1,
N2: n2,
}
mock.lockMknode.Lock()
mock.calls.Mknode = append(mock.calls.Mknode, callInfo)
mock.lockMknode.Unlock()
if mock.MknodeFunc == nil {
var (
errOut error
)
return errOut
}
return mock.MknodeFunc(s, n1, n2)
}
// MknodeCalls gets all the calls that were made to Mknode.
// Check the length with:
//
// len(mockedmknoder.MknodeCalls())
func (mock *mknoderMock) MknodeCalls() []struct {
S string
N1 int
N2 int
} {
var calls []struct {
S string
N1 int
N2 int
}
mock.lockMknode.RLock()
calls = mock.calls.Mknode
mock.lockMknode.RUnlock()
return calls
}

View File

@@ -0,0 +1,53 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvdevices
import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
"github.com/sirupsen/logrus"
)
// Option is a function that sets an option on the Interface struct.
type Option func(*Interface)
// WithDryRun sets the dry run option for the Interface struct.
func WithDryRun(dryRun bool) Option {
return func(i *Interface) {
i.dryRun = dryRun
}
}
// WithLogger sets the logger for the Interface struct.
func WithLogger(logger *logrus.Logger) Option {
return func(i *Interface) {
i.logger = logger
}
}
// WithDevRoot sets the root directory for the NVIDIA device nodes.
func WithDevRoot(devRoot string) Option {
return func(i *Interface) {
i.devRoot = devRoot
}
}
// WithDevices sets the devices for the Interface struct.
func WithDevices(devices devices.Devices) Option {
return func(i *Interface) {
i.Devices = devices
}
}

View File

@@ -0,0 +1,49 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvmodules
import (
"fmt"
"os/exec"
"strings"
"github.com/sirupsen/logrus"
)
//go:generate moq -stub -out cmd_mock.go . cmder
type cmder interface {
// Run executes the command and returns the stdout, stderr, and an error if any
Run(string, ...string) error
}
type cmderLogger struct {
*logrus.Logger
}
func (c *cmderLogger) Run(cmd string, args ...string) error {
c.Infof("Running: %v %v", cmd, strings.Join(args, " "))
return nil
}
type cmderExec struct{}
func (c *cmderExec) Run(cmd string, args ...string) error {
if output, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
return fmt.Errorf("%w; output=%v", err, string(output))
}
return nil
}

View File

@@ -0,0 +1,83 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package nvmodules
import (
"sync"
)
// Ensure, that cmderMock does implement cmder.
// If this is not the case, regenerate this file with moq.
var _ cmder = &cmderMock{}
// cmderMock is a mock implementation of cmder.
//
// func TestSomethingThatUsescmder(t *testing.T) {
//
// // make and configure a mocked cmder
// mockedcmder := &cmderMock{
// RunFunc: func(s string, strings ...string) error {
// panic("mock out the Run method")
// },
// }
//
// // use mockedcmder in code that requires cmder
// // and then make assertions.
//
// }
type cmderMock struct {
// RunFunc mocks the Run method.
RunFunc func(s string, strings ...string) error
// calls tracks calls to the methods.
calls struct {
// Run holds details about calls to the Run method.
Run []struct {
// S is the s argument value.
S string
// Strings is the strings argument value.
Strings []string
}
}
lockRun sync.RWMutex
}
// Run calls RunFunc.
func (mock *cmderMock) Run(s string, strings ...string) error {
callInfo := struct {
S string
Strings []string
}{
S: s,
Strings: strings,
}
mock.lockRun.Lock()
mock.calls.Run = append(mock.calls.Run, callInfo)
mock.lockRun.Unlock()
if mock.RunFunc == nil {
var (
errOut error
)
return errOut
}
return mock.RunFunc(s, strings...)
}
// RunCalls gets all the calls that were made to Run.
// Check the length with:
//
// len(mockedcmder.RunCalls())
func (mock *cmderMock) RunCalls() []struct {
S string
Strings []string
} {
var calls []struct {
S string
Strings []string
}
mock.lockRun.RLock()
calls = mock.calls.Run
mock.lockRun.RUnlock()
return calls
}

View File

@@ -0,0 +1,93 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvmodules
import (
"fmt"
"strings"
"github.com/sirupsen/logrus"
)
// Interface provides a set of utilities for interacting with NVIDIA modules on the system.
type Interface struct {
logger *logrus.Logger
dryRun bool
root string
cmder
}
// New constructs a new Interface struct with the specified options.
func New(opts ...Option) *Interface {
m := &Interface{}
for _, opt := range opts {
opt(m)
}
if m.logger == nil {
m.logger = logrus.StandardLogger()
}
if m.root == "" {
m.root = "/"
}
if m.dryRun {
m.cmder = &cmderLogger{m.logger}
} else {
m.cmder = &cmderExec{}
}
return m
}
// LoadAll loads all the NVIDIA kernel modules.
func (m *Interface) LoadAll() error {
modules := []string{"nvidia", "nvidia-uvm", "nvidia-modeset"}
for _, module := range modules {
if err := m.Load(module); err != nil {
return fmt.Errorf("failed to load module %s: %w", module, err)
}
}
return nil
}
var errInvalidModule = fmt.Errorf("invalid module")
// Load loads the specified NVIDIA kernel module.
// If the root is specified we first chroot into this root.
func (m *Interface) Load(module string) error {
if !strings.HasPrefix(module, "nvidia") {
return errInvalidModule
}
var args []string
if m.root != "/" {
args = append(args, "chroot", m.root)
}
args = append(args, "/sbin/modprobe", module)
m.logger.Debugf("Loading kernel module %s: %v", module, args)
err := m.Run(args[0], args[1:]...)
if err != nil {
m.logger.Debugf("Failed to load kernel module %s: %v", module, err)
return err
}
return nil
}

View File

@@ -0,0 +1,178 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvmodules
import (
"errors"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestLoadAll(t *testing.T) {
logger, _ := testlog.NewNullLogger()
runError := errors.New("run error")
testCases := []struct {
description string
root string
runError error
expectedError error
expectedCalls []struct {
S string
Strings []string
}
}{
{
description: "no root specified",
root: "",
expectedCalls: []struct {
S string
Strings []string
}{
{"/sbin/modprobe", []string{"nvidia"}},
{"/sbin/modprobe", []string{"nvidia-uvm"}},
{"/sbin/modprobe", []string{"nvidia-modeset"}},
},
},
{
description: "root causes chroot",
root: "/some/root",
expectedCalls: []struct {
S string
Strings []string
}{
{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia"}},
{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia-uvm"}},
{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia-modeset"}},
},
},
{
description: "run failure is returned",
root: "",
runError: runError,
expectedError: runError,
expectedCalls: []struct {
S string
Strings []string
}{
{"/sbin/modprobe", []string{"nvidia"}},
},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
cmder := &cmderMock{
RunFunc: func(cmd string, args ...string) error {
return tc.runError
},
}
m := New(
WithLogger(logger),
WithRoot(tc.root),
)
m.cmder = cmder
err := m.LoadAll()
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expectedCalls, cmder.RunCalls())
})
}
}
func TestLoad(t *testing.T) {
logger, _ := testlog.NewNullLogger()
runError := errors.New("run error")
testCases := []struct {
description string
root string
module string
runError error
expectedError error
expectedCalls []struct {
S string
Strings []string
}
}{
{
description: "no root specified",
root: "",
module: "nvidia",
expectedCalls: []struct {
S string
Strings []string
}{
{"/sbin/modprobe", []string{"nvidia"}},
},
},
{
description: "root causes chroot",
root: "/some/root",
module: "nvidia",
expectedCalls: []struct {
S string
Strings []string
}{
{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia"}},
},
},
{
description: "run failure is returned",
root: "",
module: "nvidia",
runError: runError,
expectedError: runError,
expectedCalls: []struct {
S string
Strings []string
}{
{"/sbin/modprobe", []string{"nvidia"}},
},
},
{
description: "module prefis is checked",
module: "not-nvidia",
expectedError: errInvalidModule,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
cmder := &cmderMock{
RunFunc: func(cmd string, args ...string) error {
return tc.runError
},
}
m := New(
WithLogger(logger),
WithRoot(tc.root),
)
m.cmder = cmder
err := m.Load(tc.module)
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expectedCalls, cmder.RunCalls())
})
}
}

View File

@@ -14,23 +14,30 @@
# limitations under the License.
**/
package system
package nvmodules
import "github.com/sirupsen/logrus"
// Option is a functional option for the system command
// Option is a function that sets an option on the Interface struct.
type Option func(*Interface)
// WithLogger sets the logger for the system command
// WithDryRun sets the dry run option for the Interface struct.
func WithDryRun(dryRun bool) Option {
return func(i *Interface) {
i.dryRun = dryRun
}
}
// WithLogger sets the logger for the Interface struct.
func WithLogger(logger *logrus.Logger) Option {
return func(i *Interface) {
i.logger = logger
}
}
// WithDryRun sets the dry run flag
func WithDryRun(dryRun bool) Option {
// WithRoot sets the root directory for the NVIDIA device nodes.
func WithRoot(root string) Option {
return func(i *Interface) {
i.dryRun = dryRun
i.root = root
}
}

View File

@@ -1,149 +0,0 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package system
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// Interface is the interface for the system command
type Interface struct {
logger *logrus.Logger
dryRun bool
nvidiaDevices nvidiaDevices
}
// New constructs a system command with the specified options
func New(opts ...Option) (*Interface, error) {
i := &Interface{
logger: logrus.StandardLogger(),
}
for _, opt := range opts {
opt(i)
}
devices, err := devices.GetNVIDIADevices()
if err != nil {
return nil, fmt.Errorf("failed to create devices info: %v", err)
}
i.nvidiaDevices = nvidiaDevices{devices}
return i, nil
}
// CreateNVIDIAControlDeviceNodesAt creates the NVIDIA control device nodes associated with the NVIDIA driver at the specified root.
func (m *Interface) CreateNVIDIAControlDeviceNodesAt(root string) error {
controlNodes := []string{"/dev/nvidiactl", "/dev/nvidia-modeset", "/dev/nvidia-uvm", "/dev/nvidia-uvm-tools"}
for _, node := range controlNodes {
path := filepath.Join(root, node)
err := m.CreateNVIDIADeviceNode(path)
if err != nil {
return fmt.Errorf("failed to create device node %s: %v", path, err)
}
}
return nil
}
// CreateNVIDIADeviceNode creates a specified device node associated with the NVIDIA driver.
func (m *Interface) CreateNVIDIADeviceNode(path string) error {
node := filepath.Base(path)
if !strings.HasPrefix(node, "nvidia") {
return fmt.Errorf("invalid device node %q", node)
}
major, err := m.nvidiaDevices.Major(node)
if err != nil {
return fmt.Errorf("failed to determine major: %v", err)
}
minor, err := m.nvidiaDevices.Minor(node)
if err != nil {
return fmt.Errorf("failed to determine minor: %v", err)
}
return m.createDeviceNode(path, int(major), int(minor))
}
func (m *Interface) createDeviceNode(path string, major int, minor int) error {
if m.dryRun {
m.logger.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor)
return nil
}
if _, err := os.Stat(path); err == nil {
m.logger.Infof("Skipping: %s already exists", path)
return nil
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to stat %s: %v", path, err)
}
err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(uint32(major), uint32(minor))))
if err != nil {
return err
}
return unix.Chmod(path, 0666)
}
type nvidiaDevices struct {
devices.Devices
}
// Major returns the major number for the specified NVIDIA device node.
// If the device node is not supported, an error is returned.
func (n *nvidiaDevices) Major(node string) (int64, error) {
var valid bool
var major devices.Major
switch node {
case "nvidia-uvm", "nvidia-uvm-tools":
major, valid = n.Get(devices.NVIDIAUVM)
case "nvidia-modeset", "nvidiactl":
major, valid = n.Get(devices.NVIDIAGPU)
}
if !valid {
return 0, fmt.Errorf("invalid device node %q", node)
}
return int64(major), nil
}
// Minor returns the minor number for the specified NVIDIA device node.
// If the device node is not supported, an error is returned.
func (n *nvidiaDevices) Minor(node string) (int64, error) {
switch node {
case "nvidia-modeset":
return devices.NVIDIAModesetMinor, nil
case "nvidia-uvm-tools":
return devices.NVIDIAUVMToolsMinor, nil
case "nvidia-uvm":
return devices.NVIDIAUVMMinor, nil
case "nvidiactl":
return devices.NVIDIACTLMinor, nil
}
return 0, fmt.Errorf("invalid device node %q", node)
}

View File

@@ -67,7 +67,6 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi
if len(searchPaths) > 1 {
logger.Warnf("Found multiple driver store paths: %v", searchPaths)
}
driverStorePath := searchPaths[0]
searchPaths = append(searchPaths, "/usr/lib/wsl/lib")
libraries := discover.NewMounts(
@@ -83,12 +82,11 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi
requiredDriverStoreFiles,
)
// On WSL2 the driver store location is used unchanged.
// For this reason we need to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the driver store.
target := filepath.Join(driverStorePath, "nvidia-smi")
link := "/usr/bin/nvidia-smi"
links := []string{fmt.Sprintf("%s::%s", target, link)}
symlinkHook := discover.CreateCreateSymlinkHook(nvidiaCTKPath, links)
symlinkHook := nvidiaSMISimlinkHook{
logger: logger,
mountsFrom: libraries,
nvidiaCTKPath: nvidiaCTKPath,
}
cfg := &discover.Config{
DriverRoot: driverRoot,
@@ -104,3 +102,39 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi
return d, nil
}
type nvidiaSMISimlinkHook struct {
discover.None
logger *logrus.Logger
mountsFrom discover.Discover
nvidiaCTKPath string
}
// Hooks returns a hook that creates a symlink to nvidia-smi in the driver store.
// On WSL2 the driver store location is used unchanged, for this reason we need
// to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the
// driver store.
func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) {
mounts, err := m.mountsFrom.Mounts()
if err != nil {
return nil, fmt.Errorf("failed to discover mounts: %w", err)
}
var target string
for _, mount := range mounts {
if filepath.Base(mount.Path) == "nvidia-smi" {
target = mount.Path
break
}
}
if target == "" {
m.logger.Warningf("Failed to find nvidia-smi in mounts: %v", mounts)
return nil, nil
}
link := "/usr/bin/nvidia-smi"
links := []string{fmt.Sprintf("%s::%s", target, link)}
symlinkHook := discover.CreateCreateSymlinkHook(m.nvidiaCTKPath, links)
return symlinkHook.Hooks()
}

View File

@@ -0,0 +1,163 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvcdi
import (
"errors"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/stretchr/testify/require"
testlog "github.com/sirupsen/logrus/hooks/test"
)
func TestNvidiaSMISymlinkHook(t *testing.T) {
logger, _ := testlog.NewNullLogger()
errMounts := errors.New("mounts error")
testCases := []struct {
description string
mounts discover.Discover
expectedError error
expectedHooks []discover.Hook
}{
{
description: "mounts error is returned",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
return nil, errMounts
},
},
expectedError: errMounts,
},
{
description: "no mounts returns no hooks",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
return nil, nil
},
},
},
{
description: "no nvidia-smi returns no hooks",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/not-nvidia-smi"},
{Path: "/also-not-nvidia-smi"},
}
return mounts, nil
},
},
},
{
description: "nvidia-smi must be in path",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/not-nvidia-smi", HostPath: "nvidia-smi"},
{Path: "/also-not-nvidia-smi", HostPath: "not-nvidia-smi"},
}
return mounts, nil
},
},
},
{
description: "nvidia-smi returns hook",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "nvidia-smi"},
}
return mounts, nil
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
"--link", "nvidia-smi::/usr/bin/nvidia-smi"},
},
},
},
{
description: "checks basename",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/some/path/nvidia-smi"},
{Path: "/nvidia-smi/but-not"},
}
return mounts, nil
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
},
},
},
{
description: "returns first match",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/some/path/nvidia-smi"},
{Path: "/another/path/nvidia-smi"},
}
return mounts, nil
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
m := nvidiaSMISimlinkHook{
logger: logger,
mountsFrom: tc.mounts,
nvidiaCTKPath: "nvidia-ctk",
}
devices, err := m.Devices()
require.NoError(t, err)
require.Empty(t, devices)
mounts, err := m.Mounts()
require.NoError(t, err)
require.Empty(t, mounts)
hooks, err := m.Hooks()
require.ErrorIs(t, err, tc.expectedError)
require.Equal(t, tc.expectedHooks, hooks)
})
}
}

View File

@@ -88,7 +88,7 @@ func (m *managementlib) getCudaVersion() (string, error) {
libCudaPaths, err := cuda.New(
cuda.WithLogger(m.logger),
cuda.WithDriverRoot(m.driverRoot),
).Locate(".*.*.*")
).Locate(".*.*")
if err != nil {
return "", fmt.Errorf("failed to locate libcuda.so: %v", err)
}

View File

@@ -18,6 +18,7 @@ package spec
import (
"fmt"
"os"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
@@ -33,6 +34,7 @@ type builder struct {
edits specs.ContainerEdits
format string
noSimplify bool
permissions os.FileMode
}
// newBuilder creates a new spec builder with the supplied options
@@ -60,7 +62,9 @@ func newBuilder(opts ...Option) *builder {
if s.format == "" {
s.format = FormatYAML
}
if s.permissions == 0 {
s.permissions = 0600
}
return s
}
@@ -92,8 +96,9 @@ func (o *builder) Build() (*spec, error) {
}
s := spec{
Spec: raw,
format: o.format,
Spec: raw,
format: o.format,
permissions: o.permissions,
}
return &s, nil
@@ -157,3 +162,10 @@ func WithRawSpec(raw *specs.Spec) Option {
o.raw = raw
}
}
// WithPermissions sets the permissions for the generated spec file
func WithPermissions(permissions os.FileMode) Option {
return func(o *builder) {
o.permissions = permissions
}
}

View File

@@ -28,7 +28,8 @@ import (
type spec struct {
*specs.Spec
format string
format string
permissions os.FileMode
}
var _ Interface = (*spec)(nil)
@@ -51,7 +52,15 @@ func (s *spec) Save(path string) error {
cdi.WithSpecDirs(specDir),
)
return registry.SpecDB().WriteSpec(s.Raw(), filepath.Base(path))
if err := registry.SpecDB().WriteSpec(s.Raw(), filepath.Base(path)); err != nil {
return fmt.Errorf("failed to write spec: %w", err)
}
if err := os.Chmod(path, s.permissions); err != nil {
return fmt.Errorf("failed to set permissions on spec file: %w", err)
}
return nil
}
// WriteTo writes the spec to the specified writer.

View File

@@ -46,14 +46,6 @@ else
targets=${all[@]}
fi
echo "Updating components"
"${SCRIPTS_DIR}/update-components.sh"
if [[ -n $(git status -s third_party) && ${ALLOW_LOCAL_COMPONENT_CHANGES} != "true" ]]; then
echo "ERROR: Building with local component changes."
echo "Commit pending changes or rerun with ALLOW_LOCAL_COMPONENT_CHANGES='true'"
exit 1
fi
eval $(${SCRIPTS_DIR}/get-component-versions.sh)

View File

@@ -94,19 +94,6 @@ function extract-all() {
local dist=$1
echo "Extracting packages for ${dist} from ${PACKAGE_IMAGE}"
if [ $dist == "ubuntu18.04" ]; then
set -x
# We need to publish the libnvidia-container0 packages to the kitmaker repository as a once off operation.
# We include the packages here so that these will be added to the archive for the ubuntu18.04 arm64 packages.
mkdir -p "${ARTIFACTS_DIR}/packages/ubuntu18.04/arm64/"
curl -L "https://nvidia.github.io/libnvidia-container/ubuntu18.04/arm64/libnvidia-container0_0.10.0+jetpack_arm64.deb" \
--output "${ARTIFACTS_DIR}/packages/ubuntu18.04/arm64/libnvidia-container0_0.10.0+jetpack_arm64.deb"
curl -L "https://nvidia.github.io/libnvidia-container/ubuntu18.04/arm64/libnvidia-container0_0.11.0+jetpack_arm64.deb" \
--output "${ARTIFACTS_DIR}/packages/ubuntu18.04/arm64/libnvidia-container0_0.11.0+jetpack_arm64.deb"
set +x
fi
# Extract every file for the specified dist-arch combiniation in MANIFEST.txt
grep "/${dist}/" "${ARTIFACTS_DIR}/manifest.txt" | while read -r f ; do
package_name="$(basename "$f")"

View File

@@ -27,7 +27,7 @@ testing::crio::hook_created() {
}
testing::crio::hook_cleanup() {
testing::docker_run::toolkit::shell 'crio cleanup'
testing::docker_run::toolkit::shell 'crio cleanup --nvidia-runtime-dir /run/nvidia/toolkit/'
test -z "$(ls -A "${shared_dir}${CRIO_HOOKS_DIR}")"
}

View File

@@ -0,0 +1,169 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package container
import (
"fmt"
"os"
"os/exec"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
)
const (
restartModeNone = "none"
restartModeSignal = "signal"
restartModeSystemd = "systemd"
)
// Options defines the shared options for the CLIs to configure containers runtimes.
type Options struct {
Config string
Socket string
RuntimeName string
RuntimeDir string
SetAsDefault bool
RestartMode string
HostRootMount string
}
// ParseArgs parses the command line arguments to the CLI
func ParseArgs(c *cli.Context, o *Options) error {
if o.RuntimeDir != "" {
logrus.Debug("Runtime directory already set; ignoring arguments")
return nil
}
args := c.Args()
logrus.Infof("Parsing arguments: %v", args.Slice())
if c.NArg() != 1 {
return fmt.Errorf("incorrect number of arguments")
}
o.RuntimeDir = args.Get(0)
logrus.Infof("Successfully parsed arguments")
return nil
}
// Configure applies the options to the specified config
func (o Options) Configure(cfg engine.Interface) error {
err := o.UpdateConfig(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
return o.flush(cfg)
}
// Unconfigure removes the options from the specified config
func (o Options) Unconfigure(cfg engine.Interface) error {
err := o.RevertConfig(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
return o.flush(cfg)
}
// flush flushes the specified config to disk
func (o Options) flush(cfg engine.Interface) error {
logrus.Infof("Flushing config to %v", o.Config)
n, err := cfg.Save(o.Config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
logrus.Infof("Config file is empty, removed")
}
return nil
}
// UpdateConfig updates the specified config to include the nvidia runtimes
func (o Options) UpdateConfig(cfg engine.Interface) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.RuntimeName),
operator.WithSetAsDefault(o.SetAsDefault),
operator.WithRoot(o.RuntimeDir),
)
for name, runtime := range runtimes {
err := cfg.AddRuntime(name, runtime.Path, runtime.SetAsDefault)
if err != nil {
return fmt.Errorf("failed to update runtime %q: %v", name, err)
}
}
return nil
}
// RevertConfig reverts the specified config to remove the nvidia runtimes
func (o Options) RevertConfig(cfg engine.Interface) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.RuntimeName),
operator.WithSetAsDefault(o.SetAsDefault),
operator.WithRoot(o.RuntimeDir),
)
for name := range runtimes {
err := cfg.RemoveRuntime(name)
if err != nil {
return fmt.Errorf("failed to remove runtime %q: %v", name, err)
}
}
return nil
}
// Restart restarts the specified service
func (o Options) Restart(service string, withSignal func(string) error) error {
switch o.RestartMode {
case restartModeNone:
logrus.Warnf("Skipping restart of %v due to --restart-mode=%v", service, o.RestartMode)
return nil
case restartModeSignal:
return withSignal(o.Socket)
case restartModeSystemd:
return o.SystemdRestart(service)
}
return fmt.Errorf("invalid restart mode specified: %v", o.RestartMode)
}
// SystemdRestart restarts the specified service using systemd
func (o Options) SystemdRestart(service string) error {
var args []string
var msg string
if o.HostRootMount != "" {
msg = " on host"
args = append(args, "chroot", o.HostRootMount)
}
args = append(args, "systemctl", "restart", service)
logrus.Infof("Restarting %v%v using systemd: %v", service, msg, args)
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("error restarting %v using systemd: %v", service, err)
}
return nil
}

View File

@@ -21,6 +21,7 @@ import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/containerd"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
"github.com/pelletier/go-toml"
"github.com/stretchr/testify/require"
)
@@ -31,7 +32,7 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
testCases := []struct {
legacyConfig bool
setAsDefault bool
runtimeClass string
runtimeName string
expectedDefaultRuntimeName interface{}
expectedDefaultRuntimeBinary interface{}
}{
@@ -51,14 +52,14 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
{
legacyConfig: true,
setAsDefault: true,
runtimeClass: "NAME",
runtimeName: "NAME",
expectedDefaultRuntimeName: nil,
expectedDefaultRuntimeBinary: "/test/runtime/dir/nvidia-container-runtime",
},
{
legacyConfig: true,
setAsDefault: true,
runtimeClass: "nvidia-experimental",
runtimeName: "nvidia-experimental",
expectedDefaultRuntimeName: nil,
expectedDefaultRuntimeBinary: "/test/runtime/dir/nvidia-container-runtime.experimental",
},
@@ -77,14 +78,14 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
{
legacyConfig: false,
setAsDefault: true,
runtimeClass: "NAME",
runtimeName: "NAME",
expectedDefaultRuntimeName: "NAME",
expectedDefaultRuntimeBinary: nil,
},
{
legacyConfig: false,
setAsDefault: true,
runtimeClass: "nvidia-experimental",
runtimeName: "nvidia-experimental",
expectedDefaultRuntimeName: "nvidia-experimental",
expectedDefaultRuntimeBinary: nil,
},
@@ -93,11 +94,13 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{
useLegacyConfig: tc.legacyConfig,
setAsDefault: tc.setAsDefault,
runtimeClass: tc.runtimeClass,
Options: container.Options{
RuntimeName: tc.runtimeName,
RuntimeDir: runtimeDir,
SetAsDefault: tc.setAsDefault,
},
runtimeType: runtimeType,
runtimeDir: runtimeDir,
useLegacyConfig: tc.legacyConfig,
}
config, err := toml.TreeFromMap(map[string]interface{}{})
@@ -109,7 +112,7 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
RuntimeType: runtimeType,
}
err = UpdateConfig(v1, o)
err = o.UpdateConfig(v1)
require.NoError(t, err, "%d: %v", i, tc)
defaultRuntimeName := v1.GetPath([]string{"plugins", "cri", "containerd", "default_runtime_name"})
@@ -138,11 +141,11 @@ func TestUpdateV1Config(t *testing.T) {
const runtimeDir = "/test/runtime/dir"
testCases := []struct {
runtimeClass string
runtimeName string
expectedConfig map[string]interface{}
}{
{
runtimeClass: "nvidia",
runtimeName: "nvidia",
expectedConfig: map[string]interface{}{
"version": int64(1),
"plugins": map[string]interface{}{
@@ -200,7 +203,7 @@ func TestUpdateV1Config(t *testing.T) {
},
},
{
runtimeClass: "NAME",
runtimeName: "NAME",
expectedConfig: map[string]interface{}{
"version": int64(1),
"plugins": map[string]interface{}{
@@ -258,7 +261,7 @@ func TestUpdateV1Config(t *testing.T) {
},
},
{
runtimeClass: "nvidia-experimental",
runtimeName: "nvidia-experimental",
expectedConfig: map[string]interface{}{
"version": int64(1),
"plugins": map[string]interface{}{
@@ -320,9 +323,10 @@ func TestUpdateV1Config(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{
runtimeClass: tc.runtimeClass,
runtimeType: runtimeType,
runtimeDir: runtimeDir,
Options: container.Options{
RuntimeName: tc.runtimeName,
RuntimeDir: runtimeDir,
},
}
config, err := toml.TreeFromMap(map[string]interface{}{})
@@ -335,7 +339,7 @@ func TestUpdateV1Config(t *testing.T) {
ContainerAnnotations: []string{"cdi.k8s.io/*"},
}
err = UpdateConfig(v1, o)
err = o.UpdateConfig(v1)
require.NoError(t, err)
expected, err := toml.TreeFromMap(tc.expectedConfig)
@@ -350,11 +354,11 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
const runtimeDir = "/test/runtime/dir"
testCases := []struct {
runtimeClass string
runtimeName string
expectedConfig map[string]interface{}
}{
{
runtimeClass: "nvidia",
runtimeName: "nvidia",
expectedConfig: map[string]interface{}{
"version": int64(1),
"plugins": map[string]interface{}{
@@ -426,7 +430,7 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
},
},
{
runtimeClass: "NAME",
runtimeName: "NAME",
expectedConfig: map[string]interface{}{
"version": int64(1),
"plugins": map[string]interface{}{
@@ -498,7 +502,7 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
},
},
{
runtimeClass: "nvidia-experimental",
runtimeName: "nvidia-experimental",
expectedConfig: map[string]interface{}{
"version": int64(1),
"plugins": map[string]interface{}{
@@ -574,9 +578,10 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{
runtimeClass: tc.runtimeClass,
runtimeType: runtimeType,
runtimeDir: runtimeDir,
Options: container.Options{
RuntimeName: tc.runtimeName,
RuntimeDir: runtimeDir,
},
}
config, err := toml.TreeFromMap(runcConfigMapV1("/runc-binary"))
@@ -589,7 +594,7 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
ContainerAnnotations: []string{"cdi.k8s.io/*"},
}
err = UpdateConfig(v1, o)
err = o.UpdateConfig(v1)
require.NoError(t, err)
expected, err := toml.TreeFromMap(tc.expectedConfig)
@@ -653,7 +658,9 @@ func TestRevertV1Config(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{
runtimeClass: "nvidia",
Options: container.Options{
RuntimeName: "nvidia",
},
}
config, err := toml.TreeFromMap(tc.config)
@@ -668,7 +675,7 @@ func TestRevertV1Config(t *testing.T) {
RuntimeType: runtimeType,
}
err = RevertConfig(v1, o)
err = o.RevertConfig(v1)
require.NoError(t, err, "%d: %v", i, tc)
configContents, _ := toml.Marshal(config)

View File

@@ -21,6 +21,7 @@ import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/containerd"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
"github.com/pelletier/go-toml"
"github.com/stretchr/testify/require"
)
@@ -34,38 +35,38 @@ func TestUpdateV2ConfigDefaultRuntime(t *testing.T) {
testCases := []struct {
setAsDefault bool
runtimeClass string
runtimeName string
expectedDefaultRuntimeName interface{}
}{
{},
{
setAsDefault: false,
runtimeClass: "nvidia",
runtimeName: "nvidia",
expectedDefaultRuntimeName: nil,
},
{
setAsDefault: false,
runtimeClass: "NAME",
runtimeName: "NAME",
expectedDefaultRuntimeName: nil,
},
{
setAsDefault: false,
runtimeClass: "nvidia-experimental",
runtimeName: "nvidia-experimental",
expectedDefaultRuntimeName: nil,
},
{
setAsDefault: true,
runtimeClass: "nvidia",
runtimeName: "nvidia",
expectedDefaultRuntimeName: "nvidia",
},
{
setAsDefault: true,
runtimeClass: "NAME",
runtimeName: "NAME",
expectedDefaultRuntimeName: "NAME",
},
{
setAsDefault: true,
runtimeClass: "nvidia-experimental",
runtimeName: "nvidia-experimental",
expectedDefaultRuntimeName: "nvidia-experimental",
},
}
@@ -73,9 +74,11 @@ func TestUpdateV2ConfigDefaultRuntime(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{
setAsDefault: tc.setAsDefault,
runtimeClass: tc.runtimeClass,
runtimeDir: runtimeDir,
Options: container.Options{
RuntimeName: tc.runtimeName,
RuntimeDir: runtimeDir,
SetAsDefault: tc.setAsDefault,
},
}
config, err := toml.TreeFromMap(map[string]interface{}{})
@@ -86,7 +89,7 @@ func TestUpdateV2ConfigDefaultRuntime(t *testing.T) {
RuntimeType: runtimeType,
}
err = UpdateConfig(v2, o)
err = o.UpdateConfig(v2)
require.NoError(t, err)
defaultRuntimeName := config.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "containerd", "default_runtime_name"})
@@ -100,11 +103,11 @@ func TestUpdateV2Config(t *testing.T) {
const expectedVersion = int64(2)
testCases := []struct {
runtimeClass string
runtimeName string
expectedConfig map[string]interface{}
}{
{
runtimeClass: "nvidia",
runtimeName: "nvidia",
expectedConfig: map[string]interface{}{
"version": int64(2),
"plugins": map[string]interface{}{
@@ -158,7 +161,7 @@ func TestUpdateV2Config(t *testing.T) {
},
},
{
runtimeClass: "NAME",
runtimeName: "NAME",
expectedConfig: map[string]interface{}{
"version": int64(2),
"plugins": map[string]interface{}{
@@ -212,7 +215,7 @@ func TestUpdateV2Config(t *testing.T) {
},
},
{
runtimeClass: "nvidia-experimental",
runtimeName: "nvidia-experimental",
expectedConfig: map[string]interface{}{
"version": int64(2),
"plugins": map[string]interface{}{
@@ -270,9 +273,11 @@ func TestUpdateV2Config(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{
runtimeClass: tc.runtimeClass,
runtimeType: runtimeType,
runtimeDir: runtimeDir,
Options: container.Options{
RuntimeName: tc.runtimeName,
RuntimeDir: runtimeDir,
},
runtimeType: runtimeType,
}
config, err := toml.TreeFromMap(map[string]interface{}{})
@@ -284,7 +289,7 @@ func TestUpdateV2Config(t *testing.T) {
ContainerAnnotations: []string{"cdi.k8s.io/*"},
}
err = UpdateConfig(v2, o)
err = o.UpdateConfig(v2)
require.NoError(t, err)
expected, err := toml.TreeFromMap(tc.expectedConfig)
@@ -300,11 +305,11 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
const runtimeDir = "/test/runtime/dir"
testCases := []struct {
runtimeClass string
runtimeName string
expectedConfig map[string]interface{}
}{
{
runtimeClass: "nvidia",
runtimeName: "nvidia",
expectedConfig: map[string]interface{}{
"version": int64(2),
"plugins": map[string]interface{}{
@@ -372,7 +377,7 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
},
},
{
runtimeClass: "NAME",
runtimeName: "NAME",
expectedConfig: map[string]interface{}{
"version": int64(2),
"plugins": map[string]interface{}{
@@ -440,7 +445,7 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
},
},
{
runtimeClass: "nvidia-experimental",
runtimeName: "nvidia-experimental",
expectedConfig: map[string]interface{}{
"version": int64(2),
"plugins": map[string]interface{}{
@@ -512,9 +517,10 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{
runtimeClass: tc.runtimeClass,
runtimeType: runtimeType,
runtimeDir: runtimeDir,
Options: container.Options{
RuntimeName: tc.runtimeName,
RuntimeDir: runtimeDir,
},
}
config, err := toml.TreeFromMap(runcConfigMapV2("/runc-binary"))
@@ -526,7 +532,7 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
ContainerAnnotations: []string{"cdi.k8s.io/*"},
}
err = UpdateConfig(v2, o)
err = o.UpdateConfig(v2)
require.NoError(t, err)
expected, err := toml.TreeFromMap(tc.expectedConfig)
@@ -585,7 +591,9 @@ func TestRevertV2Config(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{
runtimeClass: "nvidia",
Options: container.Options{
RuntimeName: "nvidia",
},
}
config, err := toml.TreeFromMap(tc.config)
@@ -599,7 +607,7 @@ func TestRevertV2Config(t *testing.T) {
RuntimeType: runtimeType,
}
err = RevertConfig(v2, o)
err = o.RevertConfig(v2)
require.NoError(t, err)
configContents, _ := toml.Marshal(config)

View File

@@ -20,33 +20,23 @@ import (
"fmt"
"net"
"os"
"os/exec"
"syscall"
"time"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/containerd"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
log "github.com/sirupsen/logrus"
cli "github.com/urfave/cli/v2"
)
const (
restartModeSignal = "signal"
restartModeSystemd = "systemd"
restartModeNone = "none"
nvidiaRuntimeName = "nvidia"
nvidiaRuntimeBinary = "nvidia-container-runtime"
nvidiaExperimentalRuntimeName = "nvidia-experimental"
nvidiaExperimentalRuntimeBinary = "nvidia-container-runtime.experimental"
defaultConfig = "/etc/containerd/config.toml"
defaultSocket = "/run/containerd/containerd.sock"
defaultRuntimeClass = "nvidia"
defaultRuntmeType = "io.containerd.runc.v2"
defaultSetAsDefault = true
defaultRestartMode = restartModeSignal
defaultRestartMode = "signal"
defaultHostRootMount = "/host"
reloadBackoff = 5 * time.Second
@@ -55,23 +45,13 @@ const (
socketMessageToGetPID = ""
)
// nvidiaRuntimeBinaries defines a map of runtime names to binary names
var nvidiaRuntimeBinaries = map[string]string{
nvidiaRuntimeName: nvidiaRuntimeBinary,
nvidiaExperimentalRuntimeName: nvidiaExperimentalRuntimeBinary,
}
// options stores the configuration from the command line or environment variables
type options struct {
config string
socket string
runtimeClass string
runtimeType string
setAsDefault bool
restartMode string
hostRootMount string
runtimeDir string
container.Options
// containerd-specific options
useLegacyConfig bool
runtimeType string
ContainerRuntimeModesCDIAnnotationPrefixes cli.StringSlice
}
@@ -83,7 +63,7 @@ func main() {
c := cli.NewApp()
c.Name = "containerd"
c.Usage = "Update a containerd config with the nvidia-container-runtime"
c.Version = "0.1.0"
c.Version = info.GetVersionString()
// Create the 'setup' subcommand
setup := cli.Command{}
@@ -93,6 +73,9 @@ func main() {
setup.Action = func(c *cli.Context) error {
return Setup(c, &options)
}
setup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Create the 'cleanup' subcommand
cleanup := cli.Command{}
@@ -102,6 +85,9 @@ func main() {
cleanup.Action = func(c *cli.Context) error {
return Cleanup(c, &options)
}
cleanup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Register the subcommands with the top-level CLI
c.Commands = []*cli.Command{
@@ -116,57 +102,53 @@ func main() {
commonFlags := []cli.Flag{
&cli.StringFlag{
Name: "config",
Aliases: []string{"c"},
Usage: "Path to the containerd config file",
Value: defaultConfig,
Destination: &options.config,
EnvVars: []string{"CONTAINERD_CONFIG"},
Destination: &options.Config,
EnvVars: []string{"RUNTIME_CONFIG", "CONTAINERD_CONFIG"},
},
&cli.StringFlag{
Name: "socket",
Aliases: []string{"s"},
Usage: "Path to the containerd socket file",
Value: defaultSocket,
Destination: &options.socket,
EnvVars: []string{"CONTAINERD_SOCKET"},
},
&cli.StringFlag{
Name: "runtime-class",
Aliases: []string{"r"},
Usage: "The name of the runtime class to set for the nvidia-container-runtime",
Value: defaultRuntimeClass,
Destination: &options.runtimeClass,
EnvVars: []string{"CONTAINERD_RUNTIME_CLASS"},
},
&cli.StringFlag{
Name: "runtime-type",
Usage: "The runtime_type to use for the configured runtime classes",
Value: defaultRuntmeType,
Destination: &options.runtimeType,
EnvVars: []string{"CONTAINERD_RUNTIME_TYPE"},
},
// The flags below are only used by the 'setup' command.
&cli.BoolFlag{
Name: "set-as-default",
Aliases: []string{"d"},
Usage: "Set nvidia-container-runtime as the default runtime",
Value: defaultSetAsDefault,
Destination: &options.setAsDefault,
EnvVars: []string{"CONTAINERD_SET_AS_DEFAULT"},
Hidden: true,
Destination: &options.Socket,
EnvVars: []string{"RUNTIME_SOCKET", "CONTAINERD_SOCKET"},
},
&cli.StringFlag{
Name: "restart-mode",
Usage: "Specify how containerd should be restarted; If 'none' is selected, it will not be restarted [signal | systemd | none]",
Value: defaultRestartMode,
Destination: &options.restartMode,
EnvVars: []string{"CONTAINERD_RESTART_MODE"},
Destination: &options.RestartMode,
EnvVars: []string{"RUNTIME_RESTART_MODE", "CONTAINERD_RESTART_MODE"},
},
&cli.StringFlag{
Name: "runtime-name",
Aliases: []string{"nvidia-runtime-name", "runtime-class"},
Usage: "The name of the runtime class to set for the nvidia-container-runtime",
Value: defaultRuntimeClass,
Destination: &options.RuntimeName,
EnvVars: []string{"NVIDIA_RUNTIME_NAME", "CONTAINERD_RUNTIME_CLASS"},
},
&cli.StringFlag{
Name: "nvidia-runtime-dir",
Aliases: []string{"runtime-dir"},
Usage: "The path where the nvidia-container-runtime binaries are located. If this is not specified, the first argument will be used instead",
Destination: &options.RuntimeDir,
EnvVars: []string{"NVIDIA_RUNTIME_DIR"},
},
&cli.BoolFlag{
Name: "set-as-default",
Usage: "Set nvidia-container-runtime as the default runtime",
Value: defaultSetAsDefault,
Destination: &options.SetAsDefault,
EnvVars: []string{"NVIDIA_RUNTIME_SET_AS_DEFAULT", "CONTAINERD_SET_AS_DEFAULT"},
Hidden: true,
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting containerd using systemd",
Value: defaultHostRootMount,
Destination: &options.hostRootMount,
Destination: &options.HostRootMount,
EnvVars: []string{"HOST_ROOT_MOUNT"},
},
&cli.BoolFlag{
@@ -175,6 +157,13 @@ func main() {
Destination: &options.useLegacyConfig,
EnvVars: []string{"CONTAINERD_USE_LEGACY_CONFIG"},
},
&cli.StringFlag{
Name: "runtime-type",
Usage: "The runtime_type to use for the configured runtime classes",
Value: defaultRuntmeType,
Destination: &options.runtimeType,
EnvVars: []string{"CONTAINERD_RUNTIME_TYPE"},
},
&cli.StringSliceFlag{
Name: "nvidia-container-runtime-modes.cdi.annotation-prefixes",
Destination: &options.ContainerRuntimeModesCDIAnnotationPrefixes,
@@ -196,14 +185,8 @@ func main() {
func Setup(c *cli.Context, o *options) error {
log.Infof("Starting 'setup' for %v", c.App.Name)
runtimeDir, err := ParseArgs(c)
if err != nil {
return fmt.Errorf("unable to parse args: %v", err)
}
o.runtimeDir = runtimeDir
cfg, err := containerd.New(
containerd.WithPath(o.config),
containerd.WithPath(o.Config),
containerd.WithRuntimeType(o.runtimeType),
containerd.WithUseLegacyConfig(o.useLegacyConfig),
containerd.WithContainerAnnotations(o.containerAnnotationsFromCDIPrefixes()...),
@@ -212,18 +195,9 @@ func Setup(c *cli.Context, o *options) error {
return fmt.Errorf("unable to load config: %v", err)
}
err = UpdateConfig(cfg, o)
err = o.Configure(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
log.Infof("Flushing containerd config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
return fmt.Errorf("unable to configure containerd: %v", err)
}
err = RestartContainerd(o)
@@ -240,13 +214,8 @@ func Setup(c *cli.Context, o *options) error {
func Cleanup(c *cli.Context, o *options) error {
log.Infof("Starting 'cleanup' for %v", c.App.Name)
_, err := ParseArgs(c)
if err != nil {
return fmt.Errorf("unable to parse args: %v", err)
}
cfg, err := containerd.New(
containerd.WithPath(o.config),
containerd.WithPath(o.Config),
containerd.WithRuntimeType(o.runtimeType),
containerd.WithUseLegacyConfig(o.useLegacyConfig),
containerd.WithContainerAnnotations(o.containerAnnotationsFromCDIPrefixes()...),
@@ -255,18 +224,9 @@ func Cleanup(c *cli.Context, o *options) error {
return fmt.Errorf("unable to load config: %v", err)
}
err = RevertConfig(cfg, o)
err = o.Unconfigure(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
log.Infof("Flushing containerd config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
return fmt.Errorf("unable to unconfigure containerd: %v", err)
}
err = RestartContainerd(o)
@@ -279,80 +239,18 @@ func Cleanup(c *cli.Context, o *options) error {
return nil
}
// ParseArgs parses the command line arguments to the CLI
func ParseArgs(c *cli.Context) (string, error) {
args := c.Args()
log.Infof("Parsing arguments: %v", args.Slice())
if args.Len() != 1 {
return "", fmt.Errorf("incorrect number of arguments")
}
runtimeDir := args.Get(0)
log.Infof("Successfully parsed arguments")
return runtimeDir, nil
}
// UpdateConfig updates the containerd config to include the nvidia-container-runtime
func UpdateConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeClass),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for class, runtime := range runtimes {
err := cfg.AddRuntime(class, runtime.Path, runtime.SetAsDefault)
if err != nil {
return fmt.Errorf("unable to update config for runtime class '%v': %v", class, err)
}
}
return nil
}
// RevertConfig reverts the containerd config to remove the nvidia-container-runtime
func RevertConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeClass),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for class := range runtimes {
err := cfg.RemoveRuntime(class)
if err != nil {
return fmt.Errorf("unable to revert config for runtime class '%v': %v", class, err)
}
}
return nil
}
// RestartContainerd restarts containerd depending on the value of restartModeFlag
func RestartContainerd(o *options) error {
switch o.restartMode {
case restartModeNone:
log.Warnf("Skipping sending signal to containerd due to --restart-mode=%v", o.restartMode)
return nil
case restartModeSignal:
err := SignalContainerd(o)
if err != nil {
return fmt.Errorf("unable to signal containerd: %v", err)
}
case restartModeSystemd:
return RestartContainerdSystemd(o.hostRootMount)
default:
return fmt.Errorf("Invalid restart mode specified: %v", o.restartMode)
}
return nil
return o.Restart("containerd", SignalContainerd)
}
// SignalContainerd sends a SIGHUP signal to the containerd daemon
func SignalContainerd(o *options) error {
func SignalContainerd(socket string) error {
log.Infof("Sending SIGHUP signal to containerd")
// Wrap the logic to perform the SIGHUP in a function so we can retry it on failure
retriable := func() error {
conn, err := net.Dial("unix", o.socket)
conn, err := net.Dial("unix", socket)
if err != nil {
return fmt.Errorf("unable to dial: %v", err)
}
@@ -426,24 +324,6 @@ func SignalContainerd(o *options) error {
return nil
}
// RestartContainerdSystemd restarts containerd using systemctl
func RestartContainerdSystemd(hostRootMount string) error {
log.Infof("Restarting containerd using systemd and host root mounted at %v", hostRootMount)
command := "chroot"
args := []string{hostRootMount, "systemctl", "restart", "containerd"}
cmd := exec.Command(command, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("error restarting containerd using systemd: %v", err)
}
return nil
}
// containerAnnotationsFromCDIPrefixes returns the container annotations to set for the given CDI prefixes.
func (o *options) containerAnnotationsFromCDIPrefixes() []string {
var annotations []string

View File

@@ -20,21 +20,17 @@ import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/crio"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
log "github.com/sirupsen/logrus"
cli "github.com/urfave/cli/v2"
)
const (
restartModeSystemd = "systemd"
restartModeNone = "none"
defaultConfigMode = "hook"
// Hook-based settings
@@ -43,25 +39,22 @@ const (
// Config-based settings
defaultConfig = "/etc/crio/crio.conf"
defaultSocket = "/var/run/crio/crio.sock"
defaultRuntimeClass = "nvidia"
defaultSetAsDefault = true
defaultRestartMode = restartModeSystemd
defaultRestartMode = "systemd"
defaultHostRootMount = "/host"
)
// options stores the configuration from the command linek or environment variables
type options struct {
container.Options
configMode string
// hook-specific options
hooksDir string
hookFilename string
runtimeDir string
config string
runtimeClass string
setAsDefault bool
restartMode string
hostRootMount string
}
func main() {
@@ -71,8 +64,7 @@ func main() {
c := cli.NewApp()
c.Name = "crio"
c.Usage = "Update cri-o hooks to include the NVIDIA runtime hook"
c.ArgsUsage = "<toolkit_dirname>"
c.Version = "0.1.0"
c.Version = info.GetVersionString()
// Create the 'setup' subcommand
setup := cli.Command{}
@@ -83,7 +75,7 @@ func main() {
return Setup(c, &options)
}
setup.Before = func(c *cli.Context) error {
return ParseArgs(c, &options)
return container.ParseArgs(c, &options.Options)
}
// Create the 'cleanup' subcommand
@@ -93,6 +85,10 @@ func main() {
cleanup.Action = func(c *cli.Context) error {
return Cleanup(c, &options)
}
cleanup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Register the subcommands with the top-level CLI
c.Commands = []*cli.Command{
&setup,
@@ -104,9 +100,61 @@ func main() {
// only require the user to specify one set of flags for both 'startup'
// and 'cleanup' to simplify things.
commonFlags := []cli.Flag{
&cli.StringFlag{
Name: "config",
Usage: "Path to the cri-o config file",
Value: defaultConfig,
Destination: &options.Config,
EnvVars: []string{"RUNTIME_CONFIG", "CRIO_CONFIG"},
},
&cli.StringFlag{
Name: "socket",
Usage: "Path to the crio socket file",
Value: "",
Destination: &options.Socket,
EnvVars: []string{"RUNTIME_SOCKET", "CRIO_SOCKET"},
// Note: We hide this option since restarting cri-o via a socket is not supported.
Hidden: true,
},
&cli.StringFlag{
Name: "restart-mode",
Usage: "Specify how cri-o should be restarted; If 'none' is selected, it will not be restarted [systemd | none]",
Value: defaultRestartMode,
Destination: &options.RestartMode,
EnvVars: []string{"RUNTIME_RESTART_MODE", "CRIO_RESTART_MODE"},
},
&cli.StringFlag{
Name: "runtime-name",
Aliases: []string{"nvidia-runtime-name", "runtime-class"},
Usage: "The name of the runtime class to set for the nvidia-container-runtime",
Value: defaultRuntimeClass,
Destination: &options.RuntimeName,
EnvVars: []string{"NVIDIA_RUNTIME_NAME", "CRIO_RUNTIME_CLASS"},
},
&cli.StringFlag{
Name: "nvidia-runtime-dir",
Aliases: []string{"runtime-dir"},
Usage: "The path where the nvidia-container-runtime binaries are located. If this is not specified, the first argument will be used instead",
Destination: &options.RuntimeDir,
EnvVars: []string{"NVIDIA_RUNTIME_DIR"},
},
&cli.BoolFlag{
Name: "set-as-default",
Usage: "Set nvidia-container-runtime as the default runtime",
Value: defaultSetAsDefault,
Destination: &options.SetAsDefault,
EnvVars: []string{"NVIDIA_RUNTIME_SET_AS_DEFAULT", "CRIO_SET_AS_DEFAULT"},
Hidden: true,
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting crio using systemd",
Value: defaultHostRootMount,
Destination: &options.HostRootMount,
EnvVars: []string{"HOST_ROOT_MOUNT"},
},
&cli.StringFlag{
Name: "hooks-dir",
Aliases: []string{"d"},
Usage: "path to the cri-o hooks directory",
Value: defaultHooksDir,
Destination: &options.hooksDir,
@@ -115,7 +163,6 @@ func main() {
},
&cli.StringFlag{
Name: "hook-filename",
Aliases: []string{"f"},
Usage: "filename of the cri-o hook that will be created / removed in the hooks directory",
Value: defaultHookFilename,
Destination: &options.hookFilename,
@@ -129,43 +176,6 @@ func main() {
Destination: &options.configMode,
EnvVars: []string{"CRIO_CONFIG_MODE"},
},
&cli.StringFlag{
Name: "config",
Usage: "Path to the cri-o config file",
Value: defaultConfig,
Destination: &options.config,
EnvVars: []string{"CRIO_CONFIG"},
},
&cli.StringFlag{
Name: "runtime-class",
Usage: "The name of the runtime class to set for the nvidia-container-runtime",
Value: defaultRuntimeClass,
Destination: &options.runtimeClass,
EnvVars: []string{"CRIO_RUNTIME_CLASS"},
},
// The flags below are only used by the 'setup' command.
&cli.BoolFlag{
Name: "set-as-default",
Usage: "Set nvidia-container-runtime as the default runtime",
Value: defaultSetAsDefault,
Destination: &options.setAsDefault,
EnvVars: []string{"CRIO_SET_AS_DEFAULT"},
Hidden: true,
},
&cli.StringFlag{
Name: "restart-mode",
Usage: "Specify how cri-o should be restarted; If 'none' is selected, it will not be restarted [systemd | none]",
Value: defaultRestartMode,
Destination: &options.restartMode,
EnvVars: []string{"CRIO_RESTART_MODE"},
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting crio using systemd",
Value: defaultHostRootMount,
Destination: &options.hostRootMount,
EnvVars: []string{"HOST_ROOT_MOUNT"},
},
}
// Update the subcommand flags with the common subcommand flags
@@ -202,7 +212,7 @@ func setupHook(o *options) error {
}
hookPath := getHookPath(o.hooksDir, o.hookFilename)
err = createHook(o.runtimeDir, hookPath)
err = createHook(o.RuntimeDir, hookPath)
if err != nil {
return fmt.Errorf("error creating hook: %v", err)
}
@@ -215,24 +225,15 @@ func setupConfig(o *options) error {
log.Infof("Updating config file")
cfg, err := crio.New(
crio.WithPath(o.config),
crio.WithPath(o.Config),
)
if err != nil {
return fmt.Errorf("unable to load config: %v", err)
}
err = UpdateConfig(cfg, o)
err = o.Configure(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
log.Infof("Flushing cri-o config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
return fmt.Errorf("unable to configure cri-o: %v", err)
}
err = RestartCrio(o)
@@ -275,24 +276,15 @@ func cleanupConfig(o *options) error {
log.Infof("Reverting config file modifications")
cfg, err := crio.New(
crio.WithPath(o.config),
crio.WithPath(o.Config),
)
if err != nil {
return fmt.Errorf("unable to load config: %v", err)
}
err = RevertConfig(cfg, o)
err = o.Unconfigure(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
log.Infof("Flushing cri-o config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
return fmt.Errorf("unable to unconfigure cri-o: %v", err)
}
err = RestartCrio(o)
@@ -303,20 +295,6 @@ func cleanupConfig(o *options) error {
return nil
}
// ParseArgs parses the command line arguments to the CLI
func ParseArgs(c *cli.Context, o *options) error {
args := c.Args()
log.Infof("Parsing arguments: %v", args.Slice())
if c.NArg() != 1 {
return fmt.Errorf("incorrect number of arguments")
}
o.runtimeDir = args.Get(0)
log.Infof("Successfully parsed arguments")
return nil
}
func createHook(toolkitDir string, hookPath string) error {
hook, err := os.Create(hookPath)
if err != nil {
@@ -357,66 +335,7 @@ func generateOciHook(toolkitDir string) podmanHook {
return hook
}
// UpdateConfig updates the cri-o config to include the NVIDIA Container Runtime
func UpdateConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeClass),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for class, runtime := range runtimes {
err := cfg.AddRuntime(class, runtime.Path, runtime.SetAsDefault)
if err != nil {
return fmt.Errorf("unable to update config for runtime class '%v': %v", class, err)
}
}
return nil
}
// RevertConfig reverts the cri-o config to remove the NVIDIA Container Runtime
func RevertConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeClass),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for class := range runtimes {
err := cfg.RemoveRuntime(class)
if err != nil {
return fmt.Errorf("unable to revert config for runtime class '%v': %v", class, err)
}
}
return nil
}
// RestartCrio restarts crio depending on the value of restartModeFlag
func RestartCrio(o *options) error {
switch o.restartMode {
case restartModeNone:
log.Warnf("Skipping restart of crio due to --restart-mode=%v", o.restartMode)
return nil
case restartModeSystemd:
return RestartCrioSystemd(o.hostRootMount)
default:
return fmt.Errorf("invalid restart mode specified: %v", o.restartMode)
}
}
// RestartCrioSystemd restarts cri-o using systemctl
func RestartCrioSystemd(hostRootMount string) error {
log.Infof("Restarting cri-o using systemd and host root mounted at %v", hostRootMount)
command := "chroot"
args := []string{hostRootMount, "systemctl", "restart", "crio"}
cmd := exec.Command(command, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("error restarting crio using systemd: %v", err)
}
return nil
return o.Restart("crio", func(string) error { return fmt.Errorf("supporting crio via signal is unsupported") })
}

View File

@@ -23,50 +23,31 @@ import (
"syscall"
"time"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/docker"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
log "github.com/sirupsen/logrus"
cli "github.com/urfave/cli/v2"
)
const (
restartModeSignal = "signal"
restartModeNone = "none"
nvidiaRuntimeName = "nvidia"
nvidiaRuntimeBinary = "nvidia-container-runtime"
nvidiaExperimentalRuntimeName = "nvidia-experimental"
nvidiaExperimentalRuntimeBinary = "nvidia-container-runtime.experimental"
defaultConfig = "/etc/docker/daemon.json"
defaultSocket = "/var/run/docker.sock"
defaultSetAsDefault = true
// defaultRuntimeName specifies the NVIDIA runtime to be use as the default runtime if setting the default runtime is enabled
defaultRuntimeName = nvidiaRuntimeName
defaultRestartMode = restartModeSignal
defaultRuntimeName = "nvidia"
defaultRestartMode = "signal"
defaultHostRootMount = "/host"
reloadBackoff = 5 * time.Second
maxReloadAttempts = 6
defaultDockerRuntime = "runc"
socketMessageToGetPID = "GET /info HTTP/1.0\r\n\r\n"
)
// nvidiaRuntimeBinaries defines a map of runtime names to binary names
var nvidiaRuntimeBinaries = map[string]string{
nvidiaRuntimeName: nvidiaRuntimeBinary,
nvidiaExperimentalRuntimeName: nvidiaExperimentalRuntimeBinary,
}
// options stores the configuration from the command line or environment variables
type options struct {
config string
socket string
runtimeName string
setAsDefault bool
runtimeDir string
restartMode string
container.Options
}
func main() {
@@ -76,7 +57,7 @@ func main() {
c := cli.NewApp()
c.Name = "docker"
c.Usage = "Update docker config with the nvidia runtime"
c.Version = "0.1.0"
c.Version = info.GetVersionString()
// Create the 'setup' subcommand
setup := cli.Command{}
@@ -86,6 +67,9 @@ func main() {
setup.Action = func(c *cli.Context) error {
return Setup(c, &options)
}
setup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Create the 'cleanup' subcommand
cleanup := cli.Command{}
@@ -95,6 +79,9 @@ func main() {
cleanup.Action = func(c *cli.Context) error {
return Cleanup(c, &options)
}
cleanup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Register the subcommands with the top-level CLI
c.Commands = []*cli.Command{
@@ -109,44 +96,57 @@ func main() {
commonFlags := []cli.Flag{
&cli.StringFlag{
Name: "config",
Aliases: []string{"c"},
Usage: "Path to docker config file",
Value: defaultConfig,
Destination: &options.config,
EnvVars: []string{"DOCKER_CONFIG"},
Destination: &options.Config,
EnvVars: []string{"RUNTIME_CONFIG", "DOCKER_CONFIG"},
},
&cli.StringFlag{
Name: "socket",
Aliases: []string{"s"},
Usage: "Path to the docker socket file",
Value: defaultSocket,
Destination: &options.socket,
EnvVars: []string{"DOCKER_SOCKET"},
},
// The flags below are only used by the 'setup' command.
&cli.StringFlag{
Name: "runtime-name",
Aliases: []string{"r"},
Usage: "Specify the name of the `nvidia` runtime. If set-as-default is selected, the runtime is used as the default runtime.",
Value: defaultRuntimeName,
Destination: &options.runtimeName,
EnvVars: []string{"DOCKER_RUNTIME_NAME"},
},
&cli.BoolFlag{
Name: "set-as-default",
Aliases: []string{"d"},
Usage: "Set the `nvidia` runtime as the default runtime. If --runtime-name is specified as `nvidia-experimental` the experimental runtime is set as the default runtime instead",
Value: defaultSetAsDefault,
Destination: &options.setAsDefault,
EnvVars: []string{"DOCKER_SET_AS_DEFAULT"},
Hidden: true,
Destination: &options.Socket,
EnvVars: []string{"RUNTIME_SOCKET", "DOCKER_SOCKET"},
},
&cli.StringFlag{
Name: "restart-mode",
Usage: "Specify how docker should be restarted; If 'none' is selected it will not be restarted [signal | none]",
Usage: "Specify how docker should be restarted; If 'none' is selected it will not be restarted [signal | systemd | none ]",
Value: defaultRestartMode,
Destination: &options.restartMode,
EnvVars: []string{"DOCKER_RESTART_MODE"},
Destination: &options.RestartMode,
EnvVars: []string{"RUNTIME_RESTART_MODE", "DOCKER_RESTART_MODE"},
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting docker using systemd",
Value: defaultHostRootMount,
Destination: &options.HostRootMount,
EnvVars: []string{"HOST_ROOT_MOUNT"},
// Restart using systemd is currently not supported.
// We hide this option for the time being.
Hidden: true,
},
&cli.StringFlag{
Name: "runtime-name",
Aliases: []string{"nvidia-runtime-name", "runtime-class"},
Usage: "Specify the name of the `nvidia` runtime. If set-as-default is selected, the runtime is used as the default runtime.",
Value: defaultRuntimeName,
Destination: &options.RuntimeName,
EnvVars: []string{"NVIDIA_RUNTIME_NAME", "DOCKER_RUNTIME_NAME"},
},
&cli.StringFlag{
Name: "nvidia-runtime-dir",
Aliases: []string{"runtime-dir"},
Usage: "The path where the nvidia-container-runtime binaries are located. If this is not specified, the first argument will be used instead",
Destination: &options.RuntimeDir,
EnvVars: []string{"NVIDIA_RUNTIME_DIR"},
},
&cli.BoolFlag{
Name: "set-as-default",
Usage: "Set the `nvidia` runtime as the default runtime. If --runtime-name is specified as `nvidia-experimental` the experimental runtime is set as the default runtime instead",
Value: defaultSetAsDefault,
Destination: &options.SetAsDefault,
EnvVars: []string{"NVIDIA_RUNTIME_SET_AS_DEFAULT", "DOCKER_SET_AS_DEFAULT"},
Hidden: true,
},
}
@@ -165,28 +165,16 @@ func main() {
func Setup(c *cli.Context, o *options) error {
log.Infof("Starting 'setup' for %v", c.App.Name)
runtimeDir, err := ParseArgs(c)
if err != nil {
return fmt.Errorf("unable to parse args: %v", err)
}
o.runtimeDir = runtimeDir
cfg, err := docker.New(
docker.WithPath(o.config),
docker.WithPath(o.Config),
)
if err != nil {
return fmt.Errorf("unable to load config: %v", err)
}
err = UpdateConfig(cfg, o)
err = o.Configure(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
log.Infof("Flushing docker config to %v", o.config)
_, err = cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
return fmt.Errorf("unable to configure docker: %v", err)
}
err = RestartDocker(o)
@@ -203,30 +191,16 @@ func Setup(c *cli.Context, o *options) error {
func Cleanup(c *cli.Context, o *options) error {
log.Infof("Starting 'cleanup' for %v", c.App.Name)
_, err := ParseArgs(c)
if err != nil {
return fmt.Errorf("unable to parse args: %v", err)
}
cfg, err := docker.New(
docker.WithPath(o.config),
docker.WithPath(o.Config),
)
if err != nil {
return fmt.Errorf("unable to load config: %v", err)
}
err = RevertConfig(cfg, o)
err = o.Unconfigure(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
log.Infof("Flushing docker config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
return fmt.Errorf("unable to unconfigure docker: %v", err)
}
err = RestartDocker(o)
@@ -239,69 +213,9 @@ func Cleanup(c *cli.Context, o *options) error {
return nil
}
// ParseArgs parses the command line arguments to the CLI
func ParseArgs(c *cli.Context) (string, error) {
args := c.Args()
log.Infof("Parsing arguments: %v", args.Slice())
if args.Len() != 1 {
return "", fmt.Errorf("incorrect number of arguments")
}
runtimeDir := args.Get(0)
log.Infof("Successfully parsed arguments")
return runtimeDir, nil
}
// UpdateConfig updates the docker config to include the nvidia runtimes
func UpdateConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeName),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for name, runtime := range runtimes {
err := cfg.AddRuntime(name, runtime.Path, runtime.SetAsDefault)
if err != nil {
return fmt.Errorf("failed to update runtime %q: %v", name, err)
}
}
return nil
}
// RevertConfig reverts the docker config to remove the nvidia runtime
func RevertConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeName),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for name := range runtimes {
err := cfg.RemoveRuntime(name)
if err != nil {
return fmt.Errorf("failed to remove runtime %q: %v", name, err)
}
}
return nil
}
// RestartDocker restarts docker depending on the value of restartModeFlag
func RestartDocker(o *options) error {
switch o.restartMode {
case restartModeNone:
log.Warnf("Skipping sending signal to docker due to --restart-mode=%v", o.restartMode)
case restartModeSignal:
err := SignalDocker(o.socket)
if err != nil {
return fmt.Errorf("unable to signal docker: %v", err)
}
default:
return fmt.Errorf("invalid restart mode specified: %v", o.restartMode)
}
return nil
return o.Restart("docker", SignalDocker)
}
// SignalDocker sends a SIGHUP signal to docker daemon

View File

@@ -21,6 +21,7 @@ import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/docker"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
"github.com/stretchr/testify/require"
)
@@ -56,14 +57,16 @@ func TestUpdateConfigDefaultRuntime(t *testing.T) {
for i, tc := range testCases {
o := &options{
setAsDefault: tc.setAsDefault,
runtimeName: tc.runtimeName,
runtimeDir: runtimeDir,
Options: container.Options{
RuntimeName: tc.runtimeName,
RuntimeDir: runtimeDir,
SetAsDefault: tc.setAsDefault,
},
}
config := docker.Config(map[string]interface{}{})
err := UpdateConfig(&config, o)
err := o.UpdateConfig(&config)
require.NoError(t, err, "%d: %v", i, tc)
defaultRuntimeName := config["default-runtime"]
@@ -314,13 +317,15 @@ func TestUpdateConfig(t *testing.T) {
}
for i, tc := range testCases {
options := &options{
setAsDefault: tc.setAsDefault,
runtimeName: tc.runtimeName,
runtimeDir: runtimeDir,
o := &options{
Options: container.Options{
RuntimeName: tc.runtimeName,
RuntimeDir: runtimeDir,
SetAsDefault: tc.setAsDefault,
},
}
err := UpdateConfig(&tc.config, options)
err := o.UpdateConfig(&tc.config)
require.NoError(t, err, "%d: %v", i, tc)
configContent, err := json.MarshalIndent(tc.config, "", " ")
@@ -457,7 +462,8 @@ func TestRevertConfig(t *testing.T) {
}
for i, tc := range testCases {
err := RevertConfig(&tc.config, &options{})
o := &options{}
err := o.RevertConfig(&tc.config)
require.NoError(t, err, "%d: %v", i, tc)

View File

@@ -23,7 +23,7 @@ import (
"path/filepath"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
@@ -305,7 +305,7 @@ func Install(cli *cli.Context, opts *options) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA container CLI: %v", err))
}
_, err = installRuntimeHook(opts.toolkitRoot, toolkitConfigPath)
nvidiaContainerRuntimeHookPath, err := installRuntimeHook(opts.toolkitRoot, toolkitConfigPath)
if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error installing NVIDIA container runtime hook: %v", err)
} else if err != nil {
@@ -319,7 +319,7 @@ func Install(cli *cli.Context, opts *options) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA Container Toolkit CLI: %v", err))
}
err = installToolkitConfig(cli, toolkitConfigPath, nvidiaContainerCliExecutable, nvidiaCTKPath, opts)
err = installToolkitConfig(cli, toolkitConfigPath, nvidiaContainerCliExecutable, nvidiaCTKPath, nvidiaContainerRuntimeHookPath, opts)
if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error installing NVIDIA container toolkit config: %v", err)
} else if err != nil {
@@ -379,7 +379,7 @@ func installLibrary(libName string, toolkitRoot string) error {
// installToolkitConfig installs the config file for the NVIDIA container toolkit ensuring
// that the settings are updated to match the desired install and nvidia driver directories.
func installToolkitConfig(c *cli.Context, toolkitConfigPath string, nvidiaContainerCliExecutablePath string, nvidiaCTKPath string, opts *options) error {
func installToolkitConfig(c *cli.Context, toolkitConfigPath string, nvidiaContainerCliExecutablePath string, nvidiaCTKPath string, nvidaContainerRuntimeHookPath string, opts *options) error {
log.Infof("Installing NVIDIA container toolkit config '%v'", toolkitConfigPath)
config, err := loadConfig(nvidiaContainerToolkitConfigSource)
@@ -410,6 +410,7 @@ func installToolkitConfig(c *cli.Context, toolkitConfigPath string, nvidiaContai
// Set nvidia-ctk options
"nvidia-ctk.path": nvidiaCTKPath,
// Set the nvidia-container-runtime-hook options
"nvidia-container-runtime-hook.path": nvidaContainerRuntimeHookPath,
"nvidia-container-runtime-hook.skip-mode-detection": opts.ContainerRuntimeHookSkipModeDetection,
}
for key, value := range configValues {
@@ -682,11 +683,13 @@ func generateCDISpec(opts *options, nvidiaCTKPath string) error {
}
log.Infof("Creating control device nodes at %v", opts.DriverRootCtrPath)
s, err := system.New()
devices, err := nvdevices.New(
nvdevices.WithDevRoot(opts.DriverRootCtrPath),
)
if err != nil {
return fmt.Errorf("failed to create library: %v", err)
}
if err := s.CreateNVIDIAControlDeviceNodesAt(opts.DriverRootCtrPath); err != nil {
if err := devices.CreateNVIDIAControlDevices(); err != nil {
return fmt.Errorf("failed to create control device nodes: %v", err)
}

View File

@@ -1,2 +1,2 @@
toml.test
/toml.test
/toml-test

View File

@@ -1 +0,0 @@
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0).

View File

@@ -1,6 +1,5 @@
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml`
packages.
reflection interface similar to Go's standard library `json` and `xml` packages.
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0).
@@ -10,7 +9,7 @@ See the [releases page](https://github.com/BurntSushi/toml/releases) for a
changelog; this information is also in the git tag annotations (e.g. `git show
v0.4.0`).
This library requires Go 1.13 or newer; install it with:
This library requires Go 1.13 or newer; add it to your go.mod with:
% go get github.com/BurntSushi/toml@latest
@@ -19,16 +18,7 @@ It also comes with a TOML validator CLI tool:
% go install github.com/BurntSushi/toml/cmd/tomlv@latest
% tomlv some-toml-file.toml
### Testing
This package passes all tests in [toml-test] for both the decoder and the
encoder.
[toml-test]: https://github.com/BurntSushi/toml-test
### Examples
This package works similar to how the Go standard library handles XML and JSON.
Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys and
values:
@@ -40,7 +30,7 @@ Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z
```
Which could be defined in Go as:
Which can be decoded with:
```go
type Config struct {
@@ -48,20 +38,15 @@ type Config struct {
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
DOB time.Time
}
```
And then decoded with:
```go
var conf Config
err := toml.Decode(tomlData, &conf)
// handle error
_, err := toml.Decode(tomlData, &conf)
```
You can also use struct tags if your struct field name doesn't map to a TOML
key value directly:
You can also use struct tags if your struct field name doesn't map to a TOML key
value directly:
```toml
some_key_NAME = "wat"
@@ -73,139 +58,63 @@ type TOML struct {
}
```
Beware that like other most other decoders **only exported fields** are
considered when encoding and decoding; private fields are silently ignored.
Beware that like other decoders **only exported fields** are considered when
encoding and decoding; private fields are silently ignored.
### Using the `Marshaler` and `encoding.TextUnmarshaler` interfaces
Here's an example that automatically parses duration strings into
`time.Duration` values:
Here's an example that automatically parses values in a `mail.Address`:
```toml
[[song]]
name = "Thunder Road"
duration = "4m49s"
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
contacts = [
"Donald Duck <donald@duckburg.com>",
"Scrooge McDuck <scrooge@duckburg.com>",
]
```
Which can be decoded with:
Can be decoded with:
```go
type song struct {
Name string
Duration duration
}
type songs struct {
Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
log.Fatal(err)
// Create address type which satisfies the encoding.TextUnmarshaler interface.
type address struct {
*mail.Address
}
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```
And you'll also need a `duration` type that satisfies the
`encoding.TextUnmarshaler` interface:
```go
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
func (a *address) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
a.Address, err = mail.ParseAddress(string(text))
return err
}
// Decode it.
func decode() {
blob := `
contacts = [
"Donald Duck <donald@duckburg.com>",
"Scrooge McDuck <scrooge@duckburg.com>",
]
`
var contacts struct {
Contacts []address
}
_, err := toml.Decode(blob, &contacts)
if err != nil {
log.Fatal(err)
}
for _, c := range contacts.Contacts {
fmt.Printf("%#v\n", c.Address)
}
// Output:
// &mail.Address{Name:"Donald Duck", Address:"donald@duckburg.com"}
// &mail.Address{Name:"Scrooge McDuck", Address:"scrooge@duckburg.com"}
}
```
To target TOML specifically you can implement `UnmarshalTOML` TOML interface in
a similar way.
### More complex usage
Here's an example of how to load the example from the official spec page:
```toml
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]
```
And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
```
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_example/example.{go,toml}`.
See the [`_example/`](/_example) directory for a more complex example.

View File

@@ -1,14 +1,18 @@
package toml
import (
"bytes"
"encoding"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"math"
"os"
"reflect"
"strconv"
"strings"
"time"
)
// Unmarshaler is the interface implemented by objects that can unmarshal a
@@ -17,16 +21,35 @@ type Unmarshaler interface {
UnmarshalTOML(interface{}) error
}
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
func Unmarshal(p []byte, v interface{}) error {
_, err := Decode(string(p), v)
// Unmarshal decodes the contents of data in TOML format into a pointer v.
//
// See [Decoder] for a description of the decoding process.
func Unmarshal(data []byte, v interface{}) error {
_, err := NewDecoder(bytes.NewReader(data)).Decode(v)
return err
}
// Decode the TOML data in to the pointer v.
//
// See [Decoder] for a description of the decoding process.
func Decode(data string, v interface{}) (MetaData, error) {
return NewDecoder(strings.NewReader(data)).Decode(v)
}
// DecodeFile reads the contents of a file and decodes it with [Decode].
func DecodeFile(path string, v interface{}) (MetaData, error) {
fp, err := os.Open(path)
if err != nil {
return MetaData{}, err
}
defer fp.Close()
return NewDecoder(fp).Decode(v)
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
//
// This type can be used for any value, which will cause decoding to be delayed.
// You can use the PrimitiveDecode() function to "manually" decode these values.
// You can use [PrimitiveDecode] to "manually" decode these values.
//
// NOTE: The underlying representation of a `Primitive` value is subject to
// change. Do not rely on it.
@@ -42,36 +65,22 @@ type Primitive struct {
// The significand precision for float32 and float64 is 24 and 53 bits; this is
// the range a natural number can be stored in a float without loss of data.
const (
maxSafeFloat32Int = 16777215 // 2^24-1
maxSafeFloat64Int = 9007199254740991 // 2^53-1
maxSafeFloat32Int = 16777215 // 2^24-1
maxSafeFloat64Int = int64(9007199254740991) // 2^53-1
)
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// Decoder decodes TOML data.
//
// TOML tables correspond to Go structs or maps (dealer's choice they can be
// used interchangeably).
// TOML tables correspond to Go structs or maps; they can be used
// interchangeably, but structs offer better type safety.
//
// TOML table arrays correspond to either a slice of structs or a slice of maps.
//
// TOML datetimes correspond to Go time.Time values. Local datetimes are parsed
// in the local timezone.
// TOML datetimes correspond to [time.Time]. Local datetimes are parsed in the
// local timezone.
//
// [time.Duration] types are treated as nanoseconds if the TOML value is an
// integer, or they're parsed with time.ParseDuration() if they're strings.
//
// All other TOML types (float, string, int, bool and array) correspond to the
// obvious Go types.
@@ -80,9 +89,9 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
// interface, in which case any primitive TOML value (floats, strings, integers,
// booleans, datetimes) will be converted to a []byte and given to the value's
// UnmarshalText method. See the Unmarshaler example for a demonstration with
// time duration strings.
// email addresses.
//
// Key mapping
// ### Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go struct.
// The special `toml` struct tag can be used to map TOML keys to struct fields
@@ -109,6 +118,7 @@ func NewDecoder(r io.Reader) *Decoder {
var (
unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
primitiveType = reflect.TypeOf((*Primitive)(nil)).Elem()
)
// Decode TOML data in to the pointer `v`.
@@ -120,10 +130,10 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
s = "%v"
}
return MetaData{}, e("cannot decode to non-pointer "+s, reflect.TypeOf(v))
return MetaData{}, fmt.Errorf("toml: cannot decode to non-pointer "+s, reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("cannot decode to nil value of %q", reflect.TypeOf(v))
return MetaData{}, fmt.Errorf("toml: cannot decode to nil value of %q", reflect.TypeOf(v))
}
// Check if this is a supported type: struct, map, interface{}, or something
@@ -133,7 +143,7 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
!(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) &&
!rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) {
return MetaData{}, e("cannot decode to type %s", rt)
return MetaData{}, fmt.Errorf("toml: cannot decode to type %s", rt)
}
// TODO: parser should read from io.Reader? Or at the very least, make it
@@ -150,30 +160,29 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
md := MetaData{
mapping: p.mapping,
types: p.types,
keyInfo: p.keyInfo,
keys: p.ordered,
decoded: make(map[string]struct{}, len(p.ordered)),
context: nil,
data: data,
}
return md, md.unify(p.mapping, rv)
}
// Decode the TOML data in to the pointer v.
// PrimitiveDecode is just like the other Decode* functions, except it decodes a
// TOML value that has already been parsed. Valid primitive values can *only* be
// obtained from values filled by the decoder functions, including this method.
// (i.e., v may contain more [Primitive] values.)
//
// See the documentation on Decoder for a description of the decoding process.
func Decode(data string, v interface{}) (MetaData, error) {
return NewDecoder(strings.NewReader(data)).Decode(v)
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at path and decode it for you.
func DecodeFile(path string, v interface{}) (MetaData, error) {
fp, err := os.Open(path)
if err != nil {
return MetaData{}, err
}
defer fp.Close()
return NewDecoder(fp).Decode(v)
// Meta data for primitive values is included in the meta data returned by the
// Decode* functions with one exception: keys returned by the Undecoded method
// will only reflect keys that were decoded. Namely, any keys hidden behind a
// Primitive will be considered undecoded. Executing this method will update the
// undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// unify performs a sort of type unification based on the structure of `rv`,
@@ -184,7 +193,7 @@ func DecodeFile(path string, v interface{}) (MetaData, error) {
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
// TODO: #76 would make this superfluous after implemented.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
if rv.Type() == primitiveType {
// Save the undecoded data and the key context into the primitive
// value.
context := make(Key, len(md.context))
@@ -196,17 +205,14 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
return nil
}
// Special case. Unmarshaler Interface support.
if rv.CanAddr() {
if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
return v.UnmarshalTOML(data)
}
rvi := rv.Interface()
if v, ok := rvi.(Unmarshaler); ok {
return v.UnmarshalTOML(data)
}
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
if v, ok := rvi.(encoding.TextUnmarshaler); ok {
return md.unifyText(data, v)
}
// TODO:
// The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML hash or
@@ -217,7 +223,6 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 {
return md.unifyInt(data, rv)
}
@@ -243,15 +248,14 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("unsupported type %s", rv.Type())
if rv.NumMethod() > 0 { // Only support empty interfaces are supported.
return md.e("unsupported type %s", rv.Type())
}
return md.unifyAnything(data, rv)
case reflect.Float32, reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("unsupported type %s", rv.Kind())
return md.e("unsupported type %s", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
@@ -260,7 +264,7 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
if mapping == nil {
return nil
}
return e("type mismatch for %s: expected table but found %T",
return md.e("type mismatch for %s: expected table but found %T",
rv.Type().String(), mapping)
}
@@ -286,13 +290,14 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = struct{}{}
md.context = append(md.context, key)
err := md.unify(datum, subv)
if err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
return e("cannot write unexported field %s.%s", rv.Type().String(), f.name)
return md.e("cannot write unexported field %s.%s", rv.Type().String(), f.name)
}
}
}
@@ -300,10 +305,10 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
if k := rv.Type().Key().Kind(); k != reflect.String {
return fmt.Errorf(
"toml: cannot decode to a map with non-string key type (%s in %q)",
k, rv.Type())
keyType := rv.Type().Key().Kind()
if keyType != reflect.String && keyType != reflect.Interface {
return fmt.Errorf("toml: cannot decode to a map with non-string key type (%s in %q)",
keyType, rv.Type())
}
tmap, ok := mapping.(map[string]interface{})
@@ -321,13 +326,22 @@ func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
md.context = append(md.context, k)
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
err := md.unify(v, indirect(rvval))
if err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
rvkey := indirect(reflect.New(rv.Type().Key()))
rvkey.SetString(k)
switch keyType {
case reflect.Interface:
rvkey.Set(reflect.ValueOf(k))
case reflect.String:
rvkey.SetString(k)
}
rv.SetMapIndex(rvkey, rvval)
}
return nil
@@ -342,7 +356,7 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
return md.badtype("slice", data)
}
if l := datav.Len(); l != rv.Len() {
return e("expected array length %d; got TOML array of length %d", rv.Len(), l)
return md.e("expected array length %d; got TOML array of length %d", rv.Len(), l)
}
return md.unifySliceArray(datav, rv)
}
@@ -375,6 +389,18 @@ func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
_, ok := rv.Interface().(json.Number)
if ok {
if i, ok := data.(int64); ok {
rv.SetString(strconv.FormatInt(i, 10))
} else if f, ok := data.(float64); ok {
rv.SetString(strconv.FormatFloat(f, 'f', -1, 64))
} else {
return md.badtype("string", data)
}
return nil
}
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
@@ -383,11 +409,13 @@ func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
rvk := rv.Kind()
if num, ok := data.(float64); ok {
switch rv.Kind() {
switch rvk {
case reflect.Float32:
if num < -math.MaxFloat32 || num > math.MaxFloat32 {
return e("value %f is out of range for float32", num)
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
fallthrough
case reflect.Float64:
@@ -399,20 +427,11 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
}
if num, ok := data.(int64); ok {
switch rv.Kind() {
case reflect.Float32:
if num < -maxSafeFloat32Int || num > maxSafeFloat32Int {
return e("value %d is out of range for float32", num)
}
fallthrough
case reflect.Float64:
if num < -maxSafeFloat64Int || num > maxSafeFloat64Int {
return e("value %d is out of range for float64", num)
}
rv.SetFloat(float64(num))
default:
panic("bug")
if (rvk == reflect.Float32 && (num < -maxSafeFloat32Int || num > maxSafeFloat32Int)) ||
(rvk == reflect.Float64 && (num < -maxSafeFloat64Int || num > maxSafeFloat64Int)) {
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
rv.SetFloat(float64(num))
return nil
}
@@ -420,50 +439,46 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
switch rv.Kind() {
case reflect.Int, reflect.Int64:
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("value %d is out of range for int8", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("value %d is out of range for int16", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("value %d is out of range for int32", num)
}
_, ok := rv.Interface().(time.Duration)
if ok {
// Parse as string duration, and fall back to regular integer parsing
// (as nanosecond) if this is not a string.
if s, ok := data.(string); ok {
dur, err := time.ParseDuration(s)
if err != nil {
return md.parseErr(errParseDuration{s})
}
rv.SetInt(num)
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("value %d is out of range for uint8", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("value %d is out of range for uint16", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("value %d is out of range for uint32", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
rv.SetInt(int64(dur))
return nil
}
return nil
}
return md.badtype("integer", data)
num, ok := data.(int64)
if !ok {
return md.badtype("integer", data)
}
rvk := rv.Kind()
switch {
case rvk >= reflect.Int && rvk <= reflect.Int64:
if (rvk == reflect.Int8 && (num < math.MinInt8 || num > math.MaxInt8)) ||
(rvk == reflect.Int16 && (num < math.MinInt16 || num > math.MaxInt16)) ||
(rvk == reflect.Int32 && (num < math.MinInt32 || num > math.MaxInt32)) {
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
rv.SetInt(num)
case rvk >= reflect.Uint && rvk <= reflect.Uint64:
unum := uint64(num)
if rvk == reflect.Uint8 && (num < 0 || unum > math.MaxUint8) ||
rvk == reflect.Uint16 && (num < 0 || unum > math.MaxUint16) ||
rvk == reflect.Uint32 && (num < 0 || unum > math.MaxUint32) {
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
rv.SetUint(unum)
default:
panic("unreachable")
}
return nil
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
@@ -488,7 +503,7 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro
return err
}
s = string(text)
case TextMarshaler:
case encoding.TextMarshaler:
text, err := sdata.MarshalText()
if err != nil {
return err
@@ -514,7 +529,30 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro
}
func (md *MetaData) badtype(dst string, data interface{}) error {
return e("incompatible types: TOML key %q has type %T; destination has type %s", md.context, data, dst)
return md.e("incompatible types: TOML value has type %T; destination has type %s", data, dst)
}
func (md *MetaData) parseErr(err error) error {
k := md.context.String()
return ParseError{
LastKey: k,
Position: md.keyInfo[k].pos,
Line: md.keyInfo[k].pos.Line,
err: err,
input: string(md.data),
}
}
func (md *MetaData) e(format string, args ...interface{}) error {
f := "toml: "
if len(md.context) > 0 {
f = fmt.Sprintf("toml: (last key %q): ", md.context)
p := md.keyInfo[md.context.String()].pos
if p.Line > 0 {
f = fmt.Sprintf("toml: line %d (last key %q): ", p.Line, md.context)
}
}
return fmt.Errorf(f+format, args...)
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
@@ -533,7 +571,11 @@ func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanSet() {
pv := v.Addr()
if _, ok := pv.Interface().(encoding.TextUnmarshaler); ok {
pvi := pv.Interface()
if _, ok := pvi.(encoding.TextUnmarshaler); ok {
return pv
}
if _, ok := pvi.(Unmarshaler); ok {
return pv
}
}
@@ -549,12 +591,12 @@ func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() {
return true
}
if _, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
rvi := rv.Interface()
if _, ok := rvi.(encoding.TextUnmarshaler); ok {
return true
}
if _, ok := rvi.(Unmarshaler); ok {
return true
}
return false
}
func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}

View File

@@ -7,8 +7,8 @@ import (
"io/fs"
)
// DecodeFS is just like Decode, except it will automatically read the contents
// of the file at `path` from a fs.FS instance.
// DecodeFS reads the contents of a file from [fs.FS] and decodes it with
// [Decode].
func DecodeFS(fsys fs.FS, path string, v interface{}) (MetaData, error) {
fp, err := fsys.Open(path)
if err != nil {

View File

@@ -1,13 +1,11 @@
/*
Package toml implements decoding and encoding of TOML files.
This package supports TOML v1.0.0, as listed on https://toml.io
There is also support for delaying decoding with the Primitive type, and
querying the set of keys in a TOML document with the MetaData type.
The github.com/BurntSushi/toml/cmd/tomlv package implements a TOML validator,
and can be used to verify if TOML document is valid. It can also be used to
print the type of each key.
*/
// Package toml implements decoding and encoding of TOML files.
//
// This package supports TOML v1.0.0, as specified at https://toml.io
//
// There is also support for delaying decoding with the Primitive type, and
// querying the set of keys in a TOML document with the MetaData type.
//
// The github.com/BurntSushi/toml/cmd/tomlv package implements a TOML validator,
// and can be used to verify if TOML document is valid. It can also be used to
// print the type of each key.
package toml

View File

@@ -3,6 +3,7 @@ package toml
import (
"bufio"
"encoding"
"encoding/json"
"errors"
"fmt"
"io"
@@ -63,6 +64,12 @@ var dblQuotedReplacer = strings.NewReplacer(
"\x7f", `\u007f`,
)
var (
marshalToml = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalText = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
)
// Marshaler is the interface implemented by types that can marshal themselves
// into valid TOML.
type Marshaler interface {
@@ -72,9 +79,12 @@ type Marshaler interface {
// Encoder encodes a Go to a TOML document.
//
// The mapping between Go values and TOML values should be precisely the same as
// for the Decode* functions.
// for [Decode].
//
// The toml.Marshaler and encoder.TextMarshaler interfaces are supported to
// time.Time is encoded as a RFC 3339 string, and time.Duration as its string
// representation.
//
// The [Marshaler] and [encoding.TextMarshaler] interfaces are supported to
// encoding the value as custom TOML.
//
// If you want to write arbitrary binary data then you will need to use
@@ -85,6 +95,17 @@ type Marshaler interface {
//
// Go maps will be sorted alphabetically by key for deterministic output.
//
// The toml struct tag can be used to provide the key name; if omitted the
// struct field name will be used. If the "omitempty" option is present the
// following value will be skipped:
//
// - arrays, slices, maps, and string with len of 0
// - struct with all zero values
// - bool false
//
// If omitzero is given all int and float types with a value of 0 will be
// skipped.
//
// Encoding Go values without a corresponding TOML representation will return an
// error. Examples of this includes maps with non-string keys, slices with nil
// elements, embedded non-struct types, and nested slices containing maps or
@@ -109,7 +130,7 @@ func NewEncoder(w io.Writer) *Encoder {
}
}
// Encode writes a TOML representation of the Go value to the Encoder's writer.
// Encode writes a TOML representation of the Go value to the [Encoder]'s writer.
//
// An error is returned if the value given cannot be encoded to a valid TOML
// document.
@@ -136,18 +157,15 @@ func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
}
func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case: time needs to be in ISO8601 format.
//
// Special case: if we can marshal the type to text, then we used that. This
// prevents the encoder for handling these types as generic structs (or
// whatever the underlying type of a TextMarshaler is).
switch t := rv.Interface().(type) {
case time.Time, encoding.TextMarshaler, Marshaler:
// If we can marshal the type to text, then we use that. This prevents the
// encoder for handling these types as generic structs (or whatever the
// underlying type of a TextMarshaler is).
switch {
case isMarshaler(rv):
enc.writeKeyValue(key, rv, false)
return
// TODO: #76 would make this superfluous after implemented.
case Primitive:
enc.encode(key, reflect.ValueOf(t.undecoded))
case rv.Type() == primitiveType: // TODO: #76 would make this superfluous after implemented.
enc.encode(key, reflect.ValueOf(rv.Interface().(Primitive).undecoded))
return
}
@@ -212,18 +230,44 @@ func (enc *Encoder) eElement(rv reflect.Value) {
if err != nil {
encPanic(err)
}
enc.writeQuoted(string(s))
if s == nil {
encPanic(errors.New("MarshalTOML returned nil and no error"))
}
enc.w.Write(s)
return
case encoding.TextMarshaler:
s, err := v.MarshalText()
if err != nil {
encPanic(err)
}
if s == nil {
encPanic(errors.New("MarshalText returned nil and no error"))
}
enc.writeQuoted(string(s))
return
case time.Duration:
enc.writeQuoted(v.String())
return
case json.Number:
n, _ := rv.Interface().(json.Number)
if n == "" { /// Useful zero value.
enc.w.WriteByte('0')
return
} else if v, err := n.Int64(); err == nil {
enc.eElement(reflect.ValueOf(v))
return
} else if v, err := n.Float64(); err == nil {
enc.eElement(reflect.ValueOf(v))
return
}
encPanic(fmt.Errorf("unable to convert %q to int64 or float64", n))
}
switch rv.Kind() {
case reflect.Ptr:
enc.eElement(rv.Elem())
return
case reflect.String:
enc.writeQuoted(rv.String())
case reflect.Bool:
@@ -259,7 +303,7 @@ func (enc *Encoder) eElement(rv reflect.Value) {
case reflect.Interface:
enc.eElement(rv.Elem())
default:
encPanic(fmt.Errorf("unexpected primitive type: %T", rv.Interface()))
encPanic(fmt.Errorf("unexpected type: %T", rv.Interface()))
}
}
@@ -280,7 +324,7 @@ func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
elem := eindirect(rv.Index(i))
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
@@ -294,7 +338,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
trv := eindirect(rv.Index(i))
if isNil(trv) {
continue
}
@@ -319,7 +363,7 @@ func (enc *Encoder) eTable(key Key, rv reflect.Value) {
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) {
switch rv := eindirect(rv); rv.Kind() {
switch rv.Kind() {
case reflect.Map:
enc.eMap(key, rv, inline)
case reflect.Struct:
@@ -341,7 +385,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsTable(tomlTypeOfGo(rv.MapIndex(mapKey))) {
if typeIsTable(tomlTypeOfGo(eindirect(rv.MapIndex(mapKey)))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
@@ -351,7 +395,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
var writeMapKeys = func(mapKeys []string, trailC bool) {
sort.Strings(mapKeys)
for i, mapKey := range mapKeys {
val := rv.MapIndex(reflect.ValueOf(mapKey))
val := eindirect(rv.MapIndex(reflect.ValueOf(mapKey)))
if isNil(val) {
continue
}
@@ -379,6 +423,13 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
const is32Bit = (32 << (^uint(0) >> 63)) == 32
func pointerTo(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
return pointerTo(t.Elem())
}
return t
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table then all keys under it will be in that
@@ -395,31 +446,25 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields.
isEmbed := f.Anonymous && pointerTo(f.Type).Kind() == reflect.Struct
if f.PkgPath != "" && !isEmbed { /// Skip unexported fields.
continue
}
opts := getOptions(f.Tag)
if opts.skip {
continue
}
frv := rv.Field(i)
frv := eindirect(rv.Field(i))
// Treat anonymous struct fields with tag names as though they are
// not anonymous, like encoding/json does.
//
// Non-struct anonymous fields use the normal encoding logic.
if f.Anonymous {
t := f.Type
switch t.Kind() {
case reflect.Struct:
if getOptions(f.Tag).name == "" {
addFields(t, frv, append(start, f.Index...))
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), append(start, f.Index...))
}
continue
}
if isEmbed {
if getOptions(f.Tag).name == "" && frv.Kind() == reflect.Struct {
addFields(frv.Type(), frv, append(start, f.Index...))
continue
}
}
@@ -445,7 +490,7 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
writeFields := func(fields [][]int) {
for _, fieldIndex := range fields {
fieldType := rt.FieldByIndex(fieldIndex)
fieldVal := rv.FieldByIndex(fieldIndex)
fieldVal := eindirect(rv.FieldByIndex(fieldIndex))
if isNil(fieldVal) { /// Don't write anything for nil fields.
continue
@@ -459,7 +504,8 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
if opts.name != "" {
keyName = opts.name
}
if opts.omitempty && isEmpty(fieldVal) {
if opts.omitempty && enc.isEmpty(fieldVal) {
continue
}
if opts.omitzero && isZero(fieldVal) {
@@ -498,6 +544,21 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}
if rv.Kind() == reflect.Struct {
if rv.Type() == timeType {
return tomlDatetime
}
if isMarshaler(rv) {
return tomlString
}
return tomlHash
}
if isMarshaler(rv) {
return tomlString
}
switch rv.Kind() {
case reflect.Bool:
return tomlBool
@@ -509,7 +570,7 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
case reflect.Float32, reflect.Float64:
return tomlFloat
case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) {
if isTableArray(rv) {
return tomlArrayHash
}
return tomlArray
@@ -519,67 +580,35 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
if _, ok := rv.Interface().(time.Time); ok {
return tomlDatetime
}
if isMarshaler(rv) {
return tomlString
}
return tomlHash
default:
if isMarshaler(rv) {
return tomlString
}
encPanic(errors.New("unsupported type: " + rv.Kind().String()))
panic("unreachable")
}
}
func isMarshaler(rv reflect.Value) bool {
switch rv.Interface().(type) {
case encoding.TextMarshaler:
return true
case Marshaler:
return true
}
// Someone used a pointer receiver: we can make it work for pointer values.
if rv.CanAddr() {
if _, ok := rv.Addr().Interface().(encoding.TextMarshaler); ok {
return true
}
if _, ok := rv.Addr().Interface().(Marshaler); ok {
return true
}
}
return false
return rv.Type().Implements(marshalText) || rv.Type().Implements(marshalToml)
}
// tomlArrayType returns the element type of a TOML array. The type returned
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
// slize). This function may also panic if it finds a type that cannot be
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
// nested arrays of tables).
func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
// isTableArray reports if all entries in the array or slice are a table.
func isTableArray(arr reflect.Value) bool {
if isNil(arr) || !arr.IsValid() || arr.Len() == 0 {
return false
}
/// Don't allow nil.
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
if tomlTypeOfGo(rv.Index(i)) == nil {
ret := true
for i := 0; i < arr.Len(); i++ {
tt := tomlTypeOfGo(eindirect(arr.Index(i)))
// Don't allow nil.
if tt == nil {
encPanic(errArrayNilElement)
}
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
if ret && !typeEqual(tomlHash, tt) {
ret = false
}
}
return firstType
return ret
}
type tagOptions struct {
@@ -620,10 +649,26 @@ func isZero(rv reflect.Value) bool {
return false
}
func isEmpty(rv reflect.Value) bool {
func (enc *Encoder) isEmpty(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
return rv.Len() == 0
case reflect.Struct:
if rv.Type().Comparable() {
return reflect.Zero(rv.Type()).Interface() == rv.Interface()
}
// Need to also check if all the fields are empty, otherwise something
// like this with uncomparable types will always return true:
//
// type a struct{ field b }
// type b struct{ s []string }
// s := a{field: b{s: []string{"AAA"}}}
for i := 0; i < rv.NumField(); i++ {
if !enc.isEmpty(rv.Field(i)) {
return false
}
}
return true
case reflect.Bool:
return !rv.Bool()
}
@@ -638,16 +683,15 @@ func (enc *Encoder) newline() {
// Write a key/value pair:
//
// key = <any value>
// key = <any value>
//
// This is also used for "k = v" in inline tables; so something like this will
// be written in three calls:
//
// ┌────────────────────┐
// │ ┌───┐ ┌────┐│
// v v v v vv
// key = {k = v, k2 = v2}
//
//───────────────────┐
// │ ┌───┐ ┌────┐│
// v v v v vv
// key = {k = 1, k2 = 2}
func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
if len(key) == 0 {
encPanic(errNoKey)
@@ -675,13 +719,25 @@ func encPanic(err error) {
panic(tomlEncodeError{err})
}
// Resolve any level of pointers to the actual value (e.g. **string → string).
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
if v.Kind() != reflect.Ptr && v.Kind() != reflect.Interface {
if isMarshaler(v) {
return v
}
if v.CanAddr() { /// Special case for marshalers; see #358.
if pv := v.Addr(); isMarshaler(pv) {
return pv
}
}
return v
}
if v.IsNil() {
return v
}
return eindirect(v.Elem())
}
func isNil(rv reflect.Value) bool {

View File

@@ -5,57 +5,60 @@ import (
"strings"
)
// ParseError is returned when there is an error parsing the TOML syntax.
//
// For example invalid syntax, duplicate keys, etc.
// ParseError is returned when there is an error parsing the TOML syntax such as
// invalid syntax, duplicate keys, etc.
//
// In addition to the error message itself, you can also print detailed location
// information with context by using ErrorWithLocation():
// information with context by using [ErrorWithPosition]:
//
// toml: error: Key 'fruit' was already created and cannot be used as an array.
// toml: error: Key 'fruit' was already created and cannot be used as an array.
//
// At line 4, column 2-7:
// At line 4, column 2-7:
//
// 2 | fruit = []
// 3 |
// 4 | [[fruit]] # Not allowed
// ^^^^^
// 2 | fruit = []
// 3 |
// 4 | [[fruit]] # Not allowed
// ^^^^^
//
// Furthermore, the ErrorWithUsage() can be used to print the above with some
// more detailed usage guidance:
// [ErrorWithUsage] can be used to print the above with some more detailed usage
// guidance:
//
// toml: error: newlines not allowed within inline tables
// toml: error: newlines not allowed within inline tables
//
// At line 1, column 18:
// At line 1, column 18:
//
// 1 | x = [{ key = 42 #
// ^
// 1 | x = [{ key = 42 #
// ^
//
// Error help:
// Error help:
//
// Inline tables must always be on a single line:
// Inline tables must always be on a single line:
//
// table = {key = 42, second = 43}
// table = {key = 42, second = 43}
//
// It is invalid to split them over multiple lines like so:
// It is invalid to split them over multiple lines like so:
//
// # INVALID
// table = {
// key = 42,
// second = 43
// }
// # INVALID
// table = {
// key = 42,
// second = 43
// }
//
// Use regular for this:
// Use regular for this:
//
// [table]
// key = 42
// second = 43
// [table]
// key = 42
// second = 43
type ParseError struct {
Message string // Short technical message.
Usage string // Longer message with usage guidance; may be blank.
Position Position // Position of the error
LastKey string // Last parsed key, may be blank.
Line int // Line the error occurred. Deprecated: use Position.
// Line the error occurred.
//
// Deprecated: use [Position].
Line int
err error
input string
@@ -83,7 +86,7 @@ func (pe ParseError) Error() string {
// ErrorWithUsage() returns the error with detailed location context.
//
// See the documentation on ParseError.
// See the documentation on [ParseError].
func (pe ParseError) ErrorWithPosition() string {
if pe.input == "" { // Should never happen, but just in case.
return pe.Error()
@@ -124,13 +127,17 @@ func (pe ParseError) ErrorWithPosition() string {
// ErrorWithUsage() returns the error with detailed location context and usage
// guidance.
//
// See the documentation on ParseError.
// See the documentation on [ParseError].
func (pe ParseError) ErrorWithUsage() string {
m := pe.ErrorWithPosition()
if u, ok := pe.err.(interface{ Usage() string }); ok && u.Usage() != "" {
return m + "Error help:\n\n " +
strings.ReplaceAll(strings.TrimSpace(u.Usage()), "\n", "\n ") +
"\n"
lines := strings.Split(strings.TrimSpace(u.Usage()), "\n")
for i := range lines {
if lines[i] != "" {
lines[i] = " " + lines[i]
}
}
return m + "Error help:\n\n" + strings.Join(lines, "\n") + "\n"
}
return m
}
@@ -160,6 +167,11 @@ type (
errLexInvalidDate struct{ v string }
errLexInlineTableNL struct{}
errLexStringNL struct{}
errParseRange struct {
i interface{} // int or float
size string // "int64", "uint16", etc.
}
errParseDuration struct{ d string }
)
func (e errLexControl) Error() string {
@@ -179,6 +191,10 @@ func (e errLexInlineTableNL) Error() string { return "newlines not allowed withi
func (e errLexInlineTableNL) Usage() string { return usageInlineNewline }
func (e errLexStringNL) Error() string { return "strings cannot contain newlines" }
func (e errLexStringNL) Usage() string { return usageStringNewline }
func (e errParseRange) Error() string { return fmt.Sprintf("%v is out of range for %s", e.i, e.size) }
func (e errParseRange) Usage() string { return usageIntOverflow }
func (e errParseDuration) Error() string { return fmt.Sprintf("invalid duration: %q", e.d) }
func (e errParseDuration) Usage() string { return usageDuration }
const usageEscape = `
A '\' inside a "-delimited string is interpreted as an escape character.
@@ -227,3 +243,37 @@ Instead use """ or ''' to split strings over multiple lines:
string = """Hello,
world!"""
`
const usageIntOverflow = `
This number is too large; this may be an error in the TOML, but it can also be a
bug in the program that uses too small of an integer.
The maximum and minimum values are:
size │ lowest │ highest
───────┼────────────────┼──────────
int8 │ -128 │ 127
int16 │ -32,768 │ 32,767
int32 │ -2,147,483,648 │ 2,147,483,647
int64 │ -9.2 × 10¹⁷ │ 9.2 × 10¹⁷
uint8 │ 0 │ 255
uint16 │ 0 │ 65535
uint32 │ 0 │ 4294967295
uint64 │ 0 │ 1.8 × 10¹⁸
int refers to int32 on 32-bit systems and int64 on 64-bit systems.
`
const usageDuration = `
A duration must be as "number<unit>", without any spaces. Valid units are:
ns nanoseconds (billionth of a second)
us, µs microseconds (millionth of a second)
ms milliseconds (thousands of a second)
s seconds
m minutes
h hours
You can combine multiple units; for example "5m10s" for 5 minutes and 10
seconds.
`

View File

@@ -82,7 +82,7 @@ func (lx *lexer) nextItem() item {
return item
default:
lx.state = lx.state(lx)
//fmt.Printf(" STATE %-24s current: %-10q stack: %s\n", lx.state, lx.current(), lx.stack)
//fmt.Printf(" STATE %-24s current: %-10s stack: %s\n", lx.state, lx.current(), lx.stack)
}
}
}
@@ -128,6 +128,11 @@ func (lx lexer) getPos() Position {
}
func (lx *lexer) emit(typ itemType) {
// Needed for multiline strings ending with an incomplete UTF-8 sequence.
if lx.start > lx.pos {
lx.error(errLexUTF8{lx.input[lx.pos]})
return
}
lx.items <- item{typ: typ, pos: lx.getPos(), val: lx.current()}
lx.start = lx.pos
}
@@ -711,7 +716,17 @@ func lexMultilineString(lx *lexer) stateFn {
if lx.peek() == '"' {
/// Check if we already lexed 5 's; if so we have 6 now, and
/// that's just too many man!
if strings.HasSuffix(lx.current(), `"""""`) {
///
/// Second check is for the edge case:
///
/// two quotes allowed.
/// vv
/// """lol \""""""
/// ^^ ^^^---- closing three
/// escaped
///
/// But ugly, but it works
if strings.HasSuffix(lx.current(), `"""""`) && !strings.HasSuffix(lx.current(), `\"""""`) {
return lx.errorf(`unexpected '""""""'`)
}
lx.backup()
@@ -756,7 +771,7 @@ func lexRawString(lx *lexer) stateFn {
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'''" has already been consumed and
// a string. It assumes that the beginning ''' has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
r := lx.next()
@@ -802,8 +817,7 @@ func lexMultilineRawString(lx *lexer) stateFn {
// lexMultilineStringEscape consumes an escaped character. It assumes that the
// preceding '\\' has already been consumed.
func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first:
if isNL(lx.next()) {
if isNL(lx.next()) { /// \ escaping newline.
return lexMultilineString
}
lx.backup()

View File

@@ -12,10 +12,11 @@ import (
type MetaData struct {
context Key // Used only during decoding.
keyInfo map[string]keyInfo
mapping map[string]interface{}
types map[string]tomlType
keys []Key
decoded map[string]struct{}
data []byte // Input file; for errors.
}
// IsDefined reports if the key exists in the TOML data.
@@ -50,8 +51,8 @@ func (md *MetaData) IsDefined(key ...string) bool {
// Type will return the empty string if given an empty key or a key that does
// not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string {
if typ, ok := md.types[Key(key).String()]; ok {
return typ.typeString()
if ki, ok := md.keyInfo[Key(key).String()]; ok {
return ki.tomlType.typeString()
}
return ""
}
@@ -70,7 +71,7 @@ func (md *MetaData) Keys() []Key {
// Undecoded returns all keys that have not been decoded in the order in which
// they appear in the original TOML document.
//
// This includes keys that haven't been decoded because of a Primitive value.
// This includes keys that haven't been decoded because of a [Primitive] value.
// Once the Primitive value is decoded, the keys will be considered decoded.
//
// Also note that decoding into an empty interface will result in no decoding,
@@ -88,7 +89,7 @@ func (md *MetaData) Undecoded() []Key {
return undecoded
}
// Key represents any TOML key, including key groups. Use (MetaData).Keys to get
// Key represents any TOML key, including key groups. Use [MetaData.Keys] to get
// values of this type.
type Key []string

View File

@@ -16,12 +16,18 @@ type parser struct {
currentKey string // Base key name for everything except hashes.
pos Position // Current position in the TOML file.
ordered []Key // List of keys in the order that they appear in the TOML data.
ordered []Key // List of keys in the order that they appear in the TOML data.
keyInfo map[string]keyInfo // Map keyname → info about the TOML key.
mapping map[string]interface{} // Map keyname → key value.
types map[string]tomlType // Map keyname → TOML type.
implicits map[string]struct{} // Record implicit keys (e.g. "key.group.names").
}
type keyInfo struct {
pos Position
tomlType tomlType
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
@@ -57,8 +63,8 @@ func parse(data string) (p *parser, err error) {
}
p = &parser{
keyInfo: make(map[string]keyInfo),
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]struct{}),
@@ -74,6 +80,15 @@ func parse(data string) (p *parser, err error) {
return p, nil
}
func (p *parser) panicErr(it item, err error) {
panic(ParseError{
err: err,
Position: it.pos,
Line: it.pos.Len,
LastKey: p.current(),
})
}
func (p *parser) panicItemf(it item, format string, v ...interface{}) {
panic(ParseError{
Message: fmt.Sprintf(format, v...),
@@ -94,7 +109,7 @@ func (p *parser) panicf(format string, v ...interface{}) {
func (p *parser) next() item {
it := p.lx.nextItem()
//fmt.Printf("ITEM %-18s line %-3d │ %q\n", it.typ, it.line, it.val)
//fmt.Printf("ITEM %-18s line %-3d │ %q\n", it.typ, it.pos.Line, it.val)
if it.typ == itemError {
if it.err != nil {
panic(ParseError{
@@ -146,7 +161,7 @@ func (p *parser) topLevel(item item) {
p.assertEqual(itemTableEnd, name.typ)
p.addContext(key, false)
p.setType("", tomlHash)
p.setType("", tomlHash, item.pos)
p.ordered = append(p.ordered, key)
case itemArrayTableStart: // [[ .. ]]
name := p.nextPos()
@@ -158,7 +173,7 @@ func (p *parser) topLevel(item item) {
p.assertEqual(itemArrayTableEnd, name.typ)
p.addContext(key, true)
p.setType("", tomlArrayHash)
p.setType("", tomlArrayHash, item.pos)
p.ordered = append(p.ordered, key)
case itemKeyStart: // key = ..
outerContext := p.context
@@ -181,8 +196,9 @@ func (p *parser) topLevel(item item) {
}
/// Set value.
val, typ := p.value(p.next(), false)
p.set(p.currentKey, val, typ)
vItem := p.next()
val, typ := p.value(vItem, false)
p.set(p.currentKey, val, typ, vItem.pos)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
/// Remove the context we added (preserving any context from [tbl] lines).
@@ -220,7 +236,7 @@ func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) {
case itemString:
return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it)
case itemMultilineString:
return p.replaceEscapes(it, stripFirstNewline(stripEscapedNewlines(it.val))), p.typeOfPrimitive(it)
return p.replaceEscapes(it, stripFirstNewline(p.stripEscapedNewlines(it.val))), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
@@ -266,7 +282,7 @@ func (p *parser) valueInteger(it item) (interface{}, tomlType) {
// So mark the former as a bug but the latter as a legitimate user
// error.
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
p.panicItemf(it, "Integer '%s' is out of the range of 64-bit signed integers.", it.val)
p.panicErr(it, errParseRange{i: it.val, size: "int64"})
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
@@ -304,7 +320,7 @@ func (p *parser) valueFloat(it item) (interface{}, tomlType) {
num, err := strconv.ParseFloat(val, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
p.panicItemf(it, "Float '%s' is out of the range of 64-bit IEEE-754 floating-point numbers.", it.val)
p.panicErr(it, errParseRange{i: it.val, size: "float64"})
} else {
p.panicItemf(it, "Invalid float value: %q", it.val)
}
@@ -343,9 +359,8 @@ func (p *parser) valueDatetime(it item) (interface{}, tomlType) {
}
func (p *parser) valueArray(it item) (interface{}, tomlType) {
p.setType(p.currentKey, tomlArray)
p.setType(p.currentKey, tomlArray, it.pos)
// p.setType(p.currentKey, typ)
var (
types []tomlType
@@ -414,7 +429,7 @@ func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tom
/// Set the value.
val, typ := p.value(p.next(), false)
p.set(p.currentKey, val, typ)
p.set(p.currentKey, val, typ, it.pos)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
hash[p.currentKey] = val
@@ -533,9 +548,10 @@ func (p *parser) addContext(key Key, array bool) {
}
// set calls setValue and setType.
func (p *parser) set(key string, val interface{}, typ tomlType) {
func (p *parser) set(key string, val interface{}, typ tomlType, pos Position) {
p.setValue(key, val)
p.setType(key, typ)
p.setType(key, typ, pos)
}
// setValue sets the given key to the given value in the current context.
@@ -599,7 +615,7 @@ func (p *parser) setValue(key string, value interface{}) {
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
func (p *parser) setType(key string, typ tomlType, pos Position) {
keyContext := make(Key, 0, len(p.context)+1)
keyContext = append(keyContext, p.context...)
if len(key) > 0 { // allow type setting for hashes
@@ -611,7 +627,7 @@ func (p *parser) setType(key string, typ tomlType) {
if len(keyContext) == 0 {
keyContext = Key{""}
}
p.types[keyContext.String()] = typ
p.keyInfo[keyContext.String()] = keyInfo{tomlType: typ, pos: pos}
}
// Implicit keys need to be created when tables are implied in "a.b.c.d = 1" and
@@ -619,7 +635,7 @@ func (p *parser) setType(key string, typ tomlType) {
func (p *parser) addImplicit(key Key) { p.implicits[key.String()] = struct{}{} }
func (p *parser) removeImplicit(key Key) { delete(p.implicits, key.String()) }
func (p *parser) isImplicit(key Key) bool { _, ok := p.implicits[key.String()]; return ok }
func (p *parser) isArray(key Key) bool { return p.types[key.String()] == tomlArray }
func (p *parser) isArray(key Key) bool { return p.keyInfo[key.String()].tomlType == tomlArray }
func (p *parser) addImplicitContext(key Key) {
p.addImplicit(key)
p.addContext(key, false)
@@ -647,7 +663,7 @@ func stripFirstNewline(s string) string {
}
// Remove newlines inside triple-quoted strings if a line ends with "\".
func stripEscapedNewlines(s string) string {
func (p *parser) stripEscapedNewlines(s string) string {
split := strings.Split(s, "\n")
if len(split) < 1 {
return s
@@ -679,6 +695,10 @@ func stripEscapedNewlines(s string) string {
continue
}
if i == len(split)-1 {
p.panicf("invalid escape: '\\ '")
}
split[i] = line[:len(line)-1] // Remove \
if len(split)-1 > i {
split[i+1] = strings.TrimLeft(split[i+1], " \t\r")
@@ -706,10 +726,8 @@ func (p *parser) replaceEscapes(it item, str string) string {
switch s[r] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case ' ', '\t':
p.panicItemf(it, "invalid escape: '\\%c'", s[r])
return ""
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1

View File

@@ -191,6 +191,8 @@ type Linux struct {
IntelRdt *LinuxIntelRdt `json:"intelRdt,omitempty"`
// Personality contains configuration for the Linux personality syscall
Personality *LinuxPersonality `json:"personality,omitempty"`
// TimeOffsets specifies the offset for supporting time namespaces.
TimeOffsets map[string]LinuxTimeOffset `json:"timeOffsets,omitempty"`
}
// LinuxNamespace is the configuration for a Linux namespace
@@ -220,6 +222,8 @@ const (
UserNamespace LinuxNamespaceType = "user"
// CgroupNamespace for isolating cgroup hierarchies
CgroupNamespace LinuxNamespaceType = "cgroup"
// TimeNamespace for isolating the clocks
TimeNamespace LinuxNamespaceType = "time"
)
// LinuxIDMapping specifies UID/GID mappings
@@ -232,6 +236,14 @@ type LinuxIDMapping struct {
Size uint32 `json:"size"`
}
// LinuxTimeOffset specifies the offset for Time Namespace
type LinuxTimeOffset struct {
// Secs is the offset of clock (in secs) in the container
Secs int64 `json:"secs,omitempty"`
// Nanosecs is the additional offset for Secs (in nanosecs)
Nanosecs uint32 `json:"nanosecs,omitempty"`
}
// POSIXRlimit type and restrictions
type POSIXRlimit struct {
// Type of the rlimit to set
@@ -242,12 +254,13 @@ type POSIXRlimit struct {
Soft uint64 `json:"soft"`
}
// LinuxHugepageLimit structure corresponds to limiting kernel hugepages
// LinuxHugepageLimit structure corresponds to limiting kernel hugepages.
// Default to reservation limits if supported. Otherwise fallback to page fault limits.
type LinuxHugepageLimit struct {
// Pagesize is the hugepage size
// Format: "<size><unit-prefix>B' (e.g. 64KB, 2MB, 1GB, etc.)
// Pagesize is the hugepage size.
// Format: "<size><unit-prefix>B' (e.g. 64KB, 2MB, 1GB, etc.).
Pagesize string `json:"pageSize"`
// Limit is the limit of "hugepagesize" hugetlb usage
// Limit is the limit of "hugepagesize" hugetlb reservations (if supported) or usage.
Limit uint64 `json:"limit"`
}
@@ -319,6 +332,10 @@ type LinuxMemory struct {
DisableOOMKiller *bool `json:"disableOOMKiller,omitempty"`
// Enables hierarchical memory accounting
UseHierarchy *bool `json:"useHierarchy,omitempty"`
// CheckBeforeUpdate enables checking if a new memory limit is lower
// than the current usage during update, and if so, rejecting the new
// limit.
CheckBeforeUpdate *bool `json:"checkBeforeUpdate,omitempty"`
}
// LinuxCPU for Linux cgroup 'cpu' resource management
@@ -327,6 +344,9 @@ type LinuxCPU struct {
Shares *uint64 `json:"shares,omitempty"`
// CPU hardcap limit (in usecs). Allowed cpu time in a given period.
Quota *int64 `json:"quota,omitempty"`
// CPU hardcap burst limit (in usecs). Allowed accumulated cpu time additionally for burst in a
// given period.
Burst *uint64 `json:"burst,omitempty"`
// CPU period to be used for hardcapping (in usecs).
Period *uint64 `json:"period,omitempty"`
// How much time realtime scheduling may use (in usecs).
@@ -375,7 +395,7 @@ type LinuxResources struct {
Pids *LinuxPids `json:"pids,omitempty"`
// BlockIO restriction configuration
BlockIO *LinuxBlockIO `json:"blockIO,omitempty"`
// Hugetlb limit (in bytes)
// Hugetlb limits (in bytes). Default to reservation limits if supported.
HugepageLimits []LinuxHugepageLimit `json:"hugepageLimits,omitempty"`
// Network restriction configuration
Network *LinuxNetwork `json:"network,omitempty"`
@@ -645,6 +665,10 @@ const (
// LinuxSeccompFlagSpecAllow can be used to disable Speculative Store
// Bypass mitigation. (since Linux 4.17)
LinuxSeccompFlagSpecAllow LinuxSeccompFlag = "SECCOMP_FILTER_FLAG_SPEC_ALLOW"
// LinuxSeccompFlagWaitKillableRecv can be used to switch to the wait
// killable semantics. (since Linux 5.19)
LinuxSeccompFlagWaitKillableRecv LinuxSeccompFlag = "SECCOMP_FILTER_FLAG_WAIT_KILLABLE_RECV"
)
// Additional architectures permitted to be used for system calls

View File

@@ -6,12 +6,12 @@ const (
// VersionMajor is for an API incompatible changes
VersionMajor = 1
// VersionMinor is for functionality in a backwards-compatible manner
VersionMinor = 0
VersionMinor = 1
// VersionPatch is for backwards-compatible bug fixes
VersionPatch = 2
VersionPatch = 0
// VersionDev indicates development branch. Releases will be empty string.
VersionDev = "-dev"
VersionDev = "-rc.2"
)
// Version is the specification version that the package types support.

View File

@@ -1,8 +1,10 @@
package assert
import (
"bytes"
"fmt"
"reflect"
"time"
)
type CompareType int
@@ -30,6 +32,9 @@ var (
float64Type = reflect.TypeOf(float64(1))
stringType = reflect.TypeOf("")
timeType = reflect.TypeOf(time.Time{})
bytesType = reflect.TypeOf([]byte{})
)
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
@@ -299,6 +304,47 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return compareLess, true
}
}
// Check for known struct types we can check for compare results.
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) {
break
}
// time.Time can compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
}
timeObj2, ok := obj2.(time.Time)
if !ok {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
case reflect.Slice:
{
// We only care about the []byte type.
if !canConvert(obj1Value, bytesType) {
break
}
// []byte can be compared!
bytesObj1, ok := obj1.([]byte)
if !ok {
bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
}
bytesObj2, ok := obj2.([]byte)
if !ok {
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
}
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
}
}
return compareEqual, false
@@ -310,7 +356,10 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
// assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a")
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// GreaterOrEqual asserts that the first element is greater than or equal to the second
@@ -320,7 +369,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface
// assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b")
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// Less asserts that the first element is less than the second
@@ -329,7 +381,10 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in
// assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b")
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// LessOrEqual asserts that the first element is less than or equal to the second
@@ -339,7 +394,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{})
// assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b")
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}
// Positive asserts that the specified element is positive
@@ -347,8 +405,11 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter
// assert.Positive(t, 1)
// assert.Positive(t, 1.23)
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs)
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
}
// Negative asserts that the specified element is negative
@@ -356,8 +417,11 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
// assert.Negative(t, -1)
// assert.Negative(t, -1.23)
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs)
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
}
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {

View File

@@ -0,0 +1,16 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatibility
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View File

@@ -0,0 +1,16 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View File

@@ -123,6 +123,18 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
}
// ErrorIsf asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
@@ -724,6 +736,16 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim
return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return WithinRange(t, actual, start, end, append([]interface{}{msg}, args...)...)
}
// YAMLEqf asserts that two YAML strings are equivalent.
func YAMLEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {

View File

@@ -222,6 +222,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
return ErrorAsf(a.t, err, target, msg, args...)
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) bool {
@@ -1437,6 +1461,26 @@ func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta
return WithinDurationf(a.t, expected, actual, delta, msg, args...)
}
// WithinRange asserts that a time is within a time range (inclusive).
//
// a.WithinRange(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func (a *Assertions) WithinRange(actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return WithinRange(a.t, actual, start, end, msgAndArgs...)
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// a.WithinRangef(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func (a *Assertions) WithinRangef(actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return WithinRangef(a.t, actual, start, end, msg, args...)
}
// YAMLEq asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {

View File

@@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
// assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"})
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// IsNonIncreasing asserts that the collection is not increasing
@@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"})
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// IsDecreasing asserts that the collection is decreasing
@@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{})
// assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"})
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// IsNonDecreasing asserts that the collection is not decreasing
@@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"})
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"math"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
@@ -144,7 +145,8 @@ func CallerInfo() []string {
if len(parts) > 1 {
dir := parts[len(parts)-2]
if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" {
callers = append(callers, fmt.Sprintf("%s:%d", file, line))
path, _ := filepath.Abs(file)
callers = append(callers, fmt.Sprintf("%s:%d", path, line))
}
}
@@ -563,16 +565,17 @@ func isEmpty(object interface{}) bool {
switch objValue.Kind() {
// collection types are empty when they have no element
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
case reflect.Chan, reflect.Map, reflect.Slice:
return objValue.Len() == 0
// pointers are empty if nil or if the value they point to is empty
// pointers are empty if nil or if the value they point to is empty
case reflect.Ptr:
if objValue.IsNil() {
return true
}
deref := objValue.Elem().Interface()
return isEmpty(deref)
// for all other types, compare against the zero value
// for all other types, compare against the zero value
// array types are empty when they match their zero-initialized state
default:
zero := reflect.Zero(objValue.Type())
return reflect.DeepEqual(object, zero.Interface())
@@ -718,10 +721,14 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
// return (false, false) if impossible.
// return (true, false) if element was not found.
// return (true, true) if element was found.
func includeElement(list interface{}, element interface{}) (ok, found bool) {
func containsElement(list interface{}, element interface{}) (ok, found bool) {
listValue := reflect.ValueOf(list)
listKind := reflect.TypeOf(list).Kind()
listType := reflect.TypeOf(list)
if listType == nil {
return false, false
}
listKind := listType.Kind()
defer func() {
if e := recover(); e != nil {
ok = false
@@ -764,7 +771,7 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
}
@@ -787,7 +794,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...)
}
@@ -811,7 +818,6 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true // we consider nil to be equal to the nil set
}
subsetValue := reflect.ValueOf(subset)
defer func() {
if e := recover(); e != nil {
ok = false
@@ -821,17 +827,35 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind()
if listKind != reflect.Array && listKind != reflect.Slice {
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
}
if subsetKind != reflect.Array && subsetKind != reflect.Slice {
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()
for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()
if !ObjectsAreEqual(subsetElement, listElement) {
return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...)
}
}
return true
}
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -852,10 +876,9 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
h.Helper()
}
if subset == nil {
return Fail(t, fmt.Sprintf("nil is the empty set which is a subset of every set"), msgAndArgs...)
return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
defer func() {
if e := recover(); e != nil {
ok = false
@@ -865,17 +888,35 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind()
if listKind != reflect.Array && listKind != reflect.Slice {
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
}
if subsetKind != reflect.Array && subsetKind != reflect.Slice {
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()
for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()
if !ObjectsAreEqual(subsetElement, listElement) {
return true
}
}
return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...)
}
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -1000,27 +1041,21 @@ func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
type PanicTestFunc func()
// didPanic returns true if the function passed to it panics. Otherwise, it returns false.
func didPanic(f PanicTestFunc) (bool, interface{}, string) {
didPanic := false
var message interface{}
var stack string
func() {
defer func() {
if message = recover(); message != nil {
didPanic = true
stack = string(debug.Stack())
}
}()
// call the target function
f()
func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string) {
didPanic = true
defer func() {
message = recover()
if didPanic {
stack = string(debug.Stack())
}
}()
return didPanic, message, stack
// call the target function
f()
didPanic = false
return
}
// Panics asserts that the code inside the specified PanicTestFunc panics.
@@ -1111,6 +1146,27 @@ func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration,
return true
}
// WithinRange asserts that a time is within a time range (inclusive).
//
// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func WithinRange(t TestingT, actual, start, end time.Time, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if end.Before(start) {
return Fail(t, "Start should be before end", msgAndArgs...)
}
if actual.Before(start) {
return Fail(t, fmt.Sprintf("Time %v expected to be in time range %v to %v, but is before the range", actual, start, end), msgAndArgs...)
} else if actual.After(end) {
return Fail(t, fmt.Sprintf("Time %v expected to be in time range %v to %v, but is after the range", actual, start, end), msgAndArgs...)
}
return true
}
func toFloat(x interface{}) (float64, bool) {
var xf float64
xok := true
@@ -1161,11 +1217,15 @@ func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs
bf, bok := toFloat(actual)
if !aok || !bok {
return Fail(t, fmt.Sprintf("Parameters must be numerical"), msgAndArgs...)
return Fail(t, "Parameters must be numerical", msgAndArgs...)
}
if math.IsNaN(af) && math.IsNaN(bf) {
return true
}
if math.IsNaN(af) {
return Fail(t, fmt.Sprintf("Expected must not be NaN"), msgAndArgs...)
return Fail(t, "Expected must not be NaN", msgAndArgs...)
}
if math.IsNaN(bf) {
@@ -1188,7 +1248,7 @@ func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAn
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1250,8 +1310,12 @@ func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, m
func calcRelativeError(expected, actual interface{}) (float64, error) {
af, aok := toFloat(expected)
if !aok {
return 0, fmt.Errorf("expected value %q cannot be converted to float", expected)
bf, bok := toFloat(actual)
if !aok || !bok {
return 0, fmt.Errorf("Parameters must be numerical")
}
if math.IsNaN(af) && math.IsNaN(bf) {
return 0, nil
}
if math.IsNaN(af) {
return 0, errors.New("expected value must not be NaN")
@@ -1259,10 +1323,6 @@ func calcRelativeError(expected, actual interface{}) (float64, error) {
if af == 0 {
return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error")
}
bf, bok := toFloat(actual)
if !bok {
return 0, fmt.Errorf("actual value %q cannot be converted to float", actual)
}
if math.IsNaN(bf) {
return 0, errors.New("actual value must not be NaN")
}
@@ -1298,7 +1358,7 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1375,6 +1435,27 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte
return true
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !Error(t, theError, msgAndArgs...) {
return false
}
actual := theError.Error()
if !strings.Contains(actual, contains) {
return Fail(t, fmt.Sprintf("Error %#v does not contain %#v", actual, contains), msgAndArgs...)
}
return true
}
// matchRegexp return true if a specified regexp matches a string.
func matchRegexp(rx interface{}, str interface{}) bool {
@@ -1588,12 +1669,17 @@ func diff(expected interface{}, actual interface{}) string {
}
var e, a string
if et != reflect.TypeOf("") {
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
} else {
switch et {
case reflect.TypeOf(""):
e = reflect.ValueOf(expected).String()
a = reflect.ValueOf(actual).String()
case reflect.TypeOf(time.Time{}):
e = spewConfigStringerEnabled.Sdump(expected)
a = spewConfigStringerEnabled.Sdump(actual)
default:
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
}
diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
@@ -1625,6 +1711,14 @@ var spewConfig = spew.ConfigState{
MaxDepth: 10,
}
var spewConfigStringerEnabled = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
MaxDepth: 10,
}
type tHelper interface {
Helper()
}

View File

@@ -280,6 +280,36 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
t.FailNow()
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContains(t, theError, contains, msgAndArgs...) {
return
}
t.FailNow()
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContainsf(t, theError, contains, msg, args...) {
return
}
t.FailNow()
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) {
@@ -1834,6 +1864,32 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim
t.FailNow()
}
// WithinRange asserts that a time is within a time range (inclusive).
//
// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func WithinRange(t TestingT, actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.WithinRange(t, actual, start, end, msgAndArgs...) {
return
}
t.FailNow()
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.WithinRangef(t, actual, start, end, msg, args...) {
return
}
t.FailNow()
}
// YAMLEq asserts that two YAML strings are equivalent.
func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {

View File

@@ -223,6 +223,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
ErrorAsf(a.t, err, target, msg, args...)
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) {
@@ -1438,6 +1462,26 @@ func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta
WithinDurationf(a.t, expected, actual, delta, msg, args...)
}
// WithinRange asserts that a time is within a time range (inclusive).
//
// a.WithinRange(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func (a *Assertions) WithinRange(actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
WithinRange(a.t, actual, start, end, msgAndArgs...)
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// a.WithinRangef(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func (a *Assertions) WithinRangef(actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
WithinRangef(a.t, actual, start, end, msg, args...)
}
// YAMLEq asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {

31
vendor/golang.org/x/sys/unix/asm_bsd_ppc64.s generated vendored Normal file
View File

@@ -0,0 +1,31 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || freebsd || netbsd || openbsd) && gc
// +build darwin freebsd netbsd openbsd
// +build gc
#include "textflag.h"
//
// System call support for ppc64, BSD
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package unix

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gccgo && !aix
// +build gccgo,!aix
//go:build gccgo && !aix && !hurd
// +build gccgo,!aix,!hurd
package unix

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build gccgo
// +build !aix
//go:build gccgo && !aix && !hurd
// +build gccgo,!aix,!hurd
#include <errno.h>
#include <stdint.h>

70
vendor/golang.org/x/sys/unix/ioctl_signed.go generated vendored Normal file
View File

@@ -0,0 +1,70 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || solaris
// +build aix solaris
package unix
import (
"unsafe"
)
// ioctl itself should not be exposed directly, but additional get/set
// functions for specific types are permissible.
// IoctlSetInt performs an ioctl operation which sets an integer value
// on fd, using the specified request number.
func IoctlSetInt(fd int, req int, value int) error {
return ioctl(fd, req, uintptr(value))
}
// IoctlSetPointerInt performs an ioctl operation which sets an
// integer value on fd, using the specified request number. The ioctl
// argument is called with a pointer to the integer value, rather than
// passing the integer value directly.
func IoctlSetPointerInt(fd int, req int, value int) error {
v := int32(value)
return ioctlPtr(fd, req, unsafe.Pointer(&v))
}
// IoctlSetWinsize performs an ioctl on fd with a *Winsize argument.
//
// To change fd's window size, the req argument should be TIOCSWINSZ.
func IoctlSetWinsize(fd int, req int, value *Winsize) error {
// TODO: if we get the chance, remove the req parameter and
// hardcode TIOCSWINSZ.
return ioctlPtr(fd, req, unsafe.Pointer(value))
}
// IoctlSetTermios performs an ioctl on fd with a *Termios.
//
// The req value will usually be TCSETA or TIOCSETA.
func IoctlSetTermios(fd int, req int, value *Termios) error {
// TODO: if we get the chance, remove the req parameter.
return ioctlPtr(fd, req, unsafe.Pointer(value))
}
// IoctlGetInt performs an ioctl operation which gets an integer value
// from fd, using the specified request number.
//
// A few ioctl requests use the return value as an output parameter;
// for those, IoctlRetInt should be used instead of this function.
func IoctlGetInt(fd int, req int) (int, error) {
var value int
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return value, err
}
func IoctlGetWinsize(fd int, req int) (*Winsize, error) {
var value Winsize
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err
}
func IoctlGetTermios(fd int, req int) (*Termios, error) {
var value Termios
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err
}

View File

@@ -2,13 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd
// +build darwin dragonfly freebsd hurd linux netbsd openbsd
package unix
import (
"runtime"
"unsafe"
)
@@ -27,7 +26,7 @@ func IoctlSetInt(fd int, req uint, value int) error {
// passing the integer value directly.
func IoctlSetPointerInt(fd int, req uint, value int) error {
v := int32(value)
return ioctl(fd, req, uintptr(unsafe.Pointer(&v)))
return ioctlPtr(fd, req, unsafe.Pointer(&v))
}
// IoctlSetWinsize performs an ioctl on fd with a *Winsize argument.
@@ -36,9 +35,7 @@ func IoctlSetPointerInt(fd int, req uint, value int) error {
func IoctlSetWinsize(fd int, req uint, value *Winsize) error {
// TODO: if we get the chance, remove the req parameter and
// hardcode TIOCSWINSZ.
err := ioctl(fd, req, uintptr(unsafe.Pointer(value)))
runtime.KeepAlive(value)
return err
return ioctlPtr(fd, req, unsafe.Pointer(value))
}
// IoctlSetTermios performs an ioctl on fd with a *Termios.
@@ -46,9 +43,7 @@ func IoctlSetWinsize(fd int, req uint, value *Winsize) error {
// The req value will usually be TCSETA or TIOCSETA.
func IoctlSetTermios(fd int, req uint, value *Termios) error {
// TODO: if we get the chance, remove the req parameter.
err := ioctl(fd, req, uintptr(unsafe.Pointer(value)))
runtime.KeepAlive(value)
return err
return ioctlPtr(fd, req, unsafe.Pointer(value))
}
// IoctlGetInt performs an ioctl operation which gets an integer value
@@ -58,18 +53,18 @@ func IoctlSetTermios(fd int, req uint, value *Termios) error {
// for those, IoctlRetInt should be used instead of this function.
func IoctlGetInt(fd int, req uint) (int, error) {
var value int
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value)))
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return value, err
}
func IoctlGetWinsize(fd int, req uint) (*Winsize, error) {
var value Winsize
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value)))
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err
}
func IoctlGetTermios(fd int, req uint) (*Termios, error) {
var value Termios
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value)))
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err
}

View File

@@ -17,25 +17,23 @@ import (
// IoctlSetInt performs an ioctl operation which sets an integer value
// on fd, using the specified request number.
func IoctlSetInt(fd int, req uint, value int) error {
func IoctlSetInt(fd int, req int, value int) error {
return ioctl(fd, req, uintptr(value))
}
// IoctlSetWinsize performs an ioctl on fd with a *Winsize argument.
//
// To change fd's window size, the req argument should be TIOCSWINSZ.
func IoctlSetWinsize(fd int, req uint, value *Winsize) error {
func IoctlSetWinsize(fd int, req int, value *Winsize) error {
// TODO: if we get the chance, remove the req parameter and
// hardcode TIOCSWINSZ.
err := ioctl(fd, req, uintptr(unsafe.Pointer(value)))
runtime.KeepAlive(value)
return err
return ioctlPtr(fd, req, unsafe.Pointer(value))
}
// IoctlSetTermios performs an ioctl on fd with a *Termios.
//
// The req value is expected to be TCSETS, TCSETSW, or TCSETSF
func IoctlSetTermios(fd int, req uint, value *Termios) error {
func IoctlSetTermios(fd int, req int, value *Termios) error {
if (req != TCSETS) && (req != TCSETSW) && (req != TCSETSF) {
return ENOSYS
}
@@ -49,22 +47,22 @@ func IoctlSetTermios(fd int, req uint, value *Termios) error {
//
// A few ioctl requests use the return value as an output parameter;
// for those, IoctlRetInt should be used instead of this function.
func IoctlGetInt(fd int, req uint) (int, error) {
func IoctlGetInt(fd int, req int) (int, error) {
var value int
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value)))
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return value, err
}
func IoctlGetWinsize(fd int, req uint) (*Winsize, error) {
func IoctlGetWinsize(fd int, req int) (*Winsize, error) {
var value Winsize
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value)))
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err
}
// IoctlGetTermios performs an ioctl on fd with a *Termios.
//
// The req value is expected to be TCGETS
func IoctlGetTermios(fd int, req uint) (*Termios, error) {
func IoctlGetTermios(fd int, req int) (*Termios, error) {
var value Termios
if req != TCGETS {
return &value, ENOSYS

View File

@@ -174,10 +174,28 @@ openbsd_arm64)
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_mips64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_ppc64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_riscv64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"

View File

@@ -66,6 +66,7 @@ includes_Darwin='
#include <sys/ptrace.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <sys/sockio.h>
#include <sys/sys_domain.h>
@@ -521,6 +522,7 @@ ccflags="$@"
$2 ~ /^NFC_(GENL|PROTO|COMM|RF|SE|DIRECTION|LLCP|SOCKPROTO)_/ ||
$2 ~ /^NFC_.*_(MAX)?SIZE$/ ||
$2 ~ /^RAW_PAYLOAD_/ ||
$2 ~ /^[US]F_/ ||
$2 ~ /^TP_STATUS_/ ||
$2 ~ /^FALLOC_/ ||
$2 ~ /^ICMPV?6?_(FILTER|SEC)/ ||
@@ -642,7 +644,7 @@ errors=$(
signals=$(
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' |
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
sort
)
@@ -652,7 +654,7 @@ echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags |
sort >_error.grep
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' |
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
sort >_signal.grep
echo '// mkerrors.sh' "$@"

View File

@@ -7,6 +7,12 @@
package unix
import "unsafe"
func ptrace(request int, pid int, addr uintptr, data uintptr) error {
return ptrace1(request, pid, addr, data)
}
func ptracePtr(request int, pid int, addr uintptr, data unsafe.Pointer) error {
return ptrace1Ptr(request, pid, addr, data)
}

View File

@@ -7,6 +7,12 @@
package unix
import "unsafe"
func ptrace(request int, pid int, addr uintptr, data uintptr) (err error) {
return ENOTSUP
}
func ptracePtr(request int, pid int, addr uintptr, data unsafe.Pointer) (err error) {
return ENOTSUP
}

View File

@@ -52,6 +52,20 @@ func ParseSocketControlMessage(b []byte) ([]SocketControlMessage, error) {
return msgs, nil
}
// ParseOneSocketControlMessage parses a single socket control message from b, returning the message header,
// message data (a slice of b), and the remainder of b after that single message.
// When there are no remaining messages, len(remainder) == 0.
func ParseOneSocketControlMessage(b []byte) (hdr Cmsghdr, data []byte, remainder []byte, err error) {
h, dbuf, err := socketControlMessageHeaderAndData(b)
if err != nil {
return Cmsghdr{}, nil, nil, err
}
if i := cmsgAlignOf(int(h.Len)); i < len(b) {
remainder = b[i:]
}
return *h, dbuf, remainder, nil
}
func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) {
h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) {

View File

@@ -292,9 +292,7 @@ func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) {
break
}
}
bytes := (*[len(pp.Path)]byte)(unsafe.Pointer(&pp.Path[0]))[0:n]
sa.Name = string(bytes)
sa.Name = string(unsafe.Slice((*byte)(unsafe.Pointer(&pp.Path[0])), n))
return sa, nil
case AF_INET:
@@ -410,7 +408,8 @@ func (w WaitStatus) CoreDump() bool { return w&0x80 == 0x80 }
func (w WaitStatus) TrapCause() int { return -1 }
//sys ioctl(fd int, req uint, arg uintptr) (err error)
//sys ioctl(fd int, req int, arg uintptr) (err error)
//sys ioctlPtr(fd int, req int, arg unsafe.Pointer) (err error) = ioctl
// fcntl must never be called with cmd=F_DUP2FD because it doesn't work on AIX
// There is no way to create a custom fcntl and to keep //sys fcntl easily,

View File

@@ -8,7 +8,6 @@
package unix
//sysnb Getrlimit(resource int, rlim *Rlimit) (err error) = getrlimit64
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error) = setrlimit64
//sys Seek(fd int, offset int64, whence int) (off int64, err error) = lseek64
//sys mmap(addr uintptr, length uintptr, prot int, flags int, fd int, offset int64) (xaddr uintptr, err error)

View File

@@ -8,7 +8,6 @@
package unix
//sysnb Getrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error)
//sys Seek(fd int, offset int64, whence int) (off int64, err error) = lseek
//sys mmap(addr uintptr, length uintptr, prot int, flags int, fd int, offset int64) (xaddr uintptr, err error) = mmap64

View File

@@ -245,8 +245,7 @@ func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) {
break
}
}
bytes := (*[len(pp.Path)]byte)(unsafe.Pointer(&pp.Path[0]))[0:n]
sa.Name = string(bytes)
sa.Name = string(unsafe.Slice((*byte)(unsafe.Pointer(&pp.Path[0])), n))
return sa, nil
case AF_INET:

View File

@@ -14,7 +14,6 @@ package unix
import (
"fmt"
"runtime"
"syscall"
"unsafe"
)
@@ -230,6 +229,7 @@ func direntNamlen(buf []byte) (uint64, bool) {
func PtraceAttach(pid int) (err error) { return ptrace(PT_ATTACH, pid, 0, 0) }
func PtraceDetach(pid int) (err error) { return ptrace(PT_DETACH, pid, 0, 0) }
func PtraceDenyAttach() (err error) { return ptrace(PT_DENY_ATTACH, 0, 0, 0) }
//sysnb pipe(p *[2]int32) (err error)
@@ -375,11 +375,10 @@ func Flistxattr(fd int, dest []byte) (sz int, err error) {
func Kill(pid int, signum syscall.Signal) (err error) { return kill(pid, int(signum), 1) }
//sys ioctl(fd int, req uint, arg uintptr) (err error)
//sys ioctlPtr(fd int, req uint, arg unsafe.Pointer) (err error) = SYS_IOCTL
func IoctlCtlInfo(fd int, ctlInfo *CtlInfo) error {
err := ioctl(fd, CTLIOCGINFO, uintptr(unsafe.Pointer(ctlInfo)))
runtime.KeepAlive(ctlInfo)
return err
return ioctlPtr(fd, CTLIOCGINFO, unsafe.Pointer(ctlInfo))
}
// IfreqMTU is struct ifreq used to get or set a network device's MTU.
@@ -393,16 +392,14 @@ type IfreqMTU struct {
func IoctlGetIfreqMTU(fd int, ifname string) (*IfreqMTU, error) {
var ifreq IfreqMTU
copy(ifreq.Name[:], ifname)
err := ioctl(fd, SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq)))
err := ioctlPtr(fd, SIOCGIFMTU, unsafe.Pointer(&ifreq))
return &ifreq, err
}
// IoctlSetIfreqMTU performs the SIOCSIFMTU ioctl operation on fd to set the MTU
// of the network device specified by ifreq.Name.
func IoctlSetIfreqMTU(fd int, ifreq *IfreqMTU) error {
err := ioctl(fd, SIOCSIFMTU, uintptr(unsafe.Pointer(ifreq)))
runtime.KeepAlive(ifreq)
return err
return ioctlPtr(fd, SIOCSIFMTU, unsafe.Pointer(ifreq))
}
//sys sysctl(mib []_C_int, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) = SYS_SYSCTL
@@ -616,6 +613,7 @@ func SysctlKinfoProcSlice(name string, args ...int) ([]KinfoProc, error) {
//sys Rmdir(path string) (err error)
//sys Seek(fd int, offset int64, whence int) (newoffset int64, err error) = SYS_LSEEK
//sys Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err error)
//sys Setattrlist(path string, attrlist *Attrlist, attrBuf []byte, options int) (err error)
//sys Setegid(egid int) (err error)
//sysnb Seteuid(euid int) (err error)
//sysnb Setgid(gid int) (err error)
@@ -625,7 +623,6 @@ func SysctlKinfoProcSlice(name string, args ...int) ([]KinfoProc, error) {
//sys Setprivexec(flag int) (err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sysnb Setrlimit(which int, lim *Rlimit) (err error)
//sysnb Setsid() (pid int, err error)
//sysnb Settimeofday(tp *Timeval) (err error)
//sysnb Setuid(uid int) (err error)
@@ -679,7 +676,6 @@ func SysctlKinfoProcSlice(name string, args ...int) ([]KinfoProc, error) {
// Kqueue_from_portset_np
// Kqueue_portset
// Getattrlist
// Setattrlist
// Getdirentriesattr
// Searchfs
// Delete

Some files were not shown because too many files have changed in this diff Show More