Compare commits

...

42 Commits

Author SHA1 Message Date
Evan Lezar
206ac20e78 Merge branch 'allow-centos7-aarch64-scan-failure' into 'release-1.13'
Allow failure for centos7-aarch64 scans

See merge request nvidia/container-toolkit/container-toolkit!441
2023-07-12 19:55:58 +00:00
Evan Lezar
c3c23d647f Allow failure for centos7-aarch64 scans
For the release-1.13 branch, we don't build aarch64 images for centos7.

This means that, depending on the docker version, a docker pull fails if
the platform is specified.

As a simple workaround, we allow failure for this scan step.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-07-12 21:54:14 +02:00
Evan Lezar
f2a42dc924 Merge branch 'bump-cuda-12.2.0' into 'release-1.13'
Merge branch 'bump-cuda-12.2.0' into 'main'

See merge request nvidia/container-toolkit/container-toolkit!440
2023-07-11 19:18:28 +00:00
Evan Lezar
752afe8ca9 Bump version to v1.13.4
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-07-11 20:34:38 +02:00
Evan Lezar
78940d0a95 Bump libnvidia-container to v1.13.4
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-07-11 20:34:38 +02:00
Evan Lezar
4e0d6f3934 Merge branch 'bump-cuda-12.2.0' into 'main'
Bump cuda base image to 12.2.0

See merge request nvidia/container-toolkit/container-toolkit!435
2023-07-11 20:34:36 +02:00
Evan Lezar
c5a93b8d70 Merge branch 'cherry-pick-wsl2' into 'release-1.13'
Backport CDI fixes for 1.13.3 release

See merge request nvidia/container-toolkit/container-toolkit!429
2023-06-27 18:40:19 +00:00
Evan Lezar
cc06766f25 Merge branch 'fix-load-kernel-modules' into 'main'
Split internal system package

See merge request nvidia/container-toolkit/container-toolkit!420
2023-06-27 17:36:57 +02:00
Evan Lezar
7c807c2c22 Merge branch 'CNT-4302/cdi-only' into 'main'
Skip additional modifications in CDI mode

See merge request nvidia/container-toolkit/container-toolkit!413
2023-06-27 17:08:14 +02:00
Evan Lezar
89781ad6a3 Merge branch 'use-major-minor-for-cuda-version' into 'main'
Use *.* pattern when locating libcuda.so

See merge request nvidia/container-toolkit/container-toolkit!397
2023-06-27 16:59:35 +02:00
Evan Lezar
f677245d60 Merge branch 'fix-multiple-driver-roots-wsl' into 'main'
Fix bug with multiple driver store paths

See merge request nvidia/container-toolkit/container-toolkit!425
2023-06-27 16:59:33 +02:00
Evan Lezar
9d31bd4cc3 Merge branch 'fix-cdi-permissions' into 'main'
Properly set spec permissions

See merge request nvidia/container-toolkit/container-toolkit!383
2023-06-27 16:27:28 +02:00
Carlos Eduardo Arango Gutierrez
b063fa40b1 Merge branch 'fix-cdi-spec-permissions' into 'main'
Generate CDI specifications with 644 permissions to allow non-root clients to consume them

See merge request nvidia/container-toolkit/container-toolkit!381
2023-06-27 16:27:27 +02:00
Evan Lezar
9b7904e0bb Merge branch 'CNT-4075' into 'release-1.13'
Allow same envars for all runtime configs

See merge request nvidia/container-toolkit/container-toolkit!419
2023-06-27 14:26:29 +00:00
Evan Lezar
a9ccef6090 Ensure common envvars have higher precedence
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:26 +02:00
Carlos Eduardo Arango Gutierrez
a4e13c5197 Add entry to changelog
Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
2023-06-27 16:25:24 +02:00
Evan Lezar
c9a8b7f335 Ensure runtime dir is set for crio cleanup
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
591e610905 Remove unused constants and variables
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
524802df2b Rework restart logic
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
19ca4338f2 Add version info to config CLIs
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
56c533d7d4 Refactor toolking to setup and cleanup configs
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
b9b19494d0 Add runtimeDir as argument
Thsi change adds the --nvidia-runtime-dir as a command line
argument when configuring container runtimes in the toolkit container.
This removes the need to set it via the command line.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
47b6b01f48 Allow same envars for all runtime configs
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 16:25:00 +02:00
Evan Lezar
7a70850679 Merge branch 'bump-version-v1.13.3' into 'release-1.13'
Bump version to v1.13.3

See merge request nvidia/container-toolkit/container-toolkit!428
2023-06-27 14:14:02 +00:00
Evan Lezar
6052b3eba3 Update libnvidia-container
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 15:30:46 +02:00
Evan Lezar
4ae775d683 Bump version to 1.13.3
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-27 15:29:19 +02:00
Evan Lezar
d9cf2682f9 Merge branch 'remove-libnvidia-container-workaround' into 'release-1.13'
Merge branch 'revert-kitmaker-workaround' into 'main'

See merge request nvidia/container-toolkit/container-toolkit!417
2023-06-06 20:19:00 +00:00
Evan Lezar
1e5a6e1fa3 Merge branch 'revert-kitmaker-workaround' into 'main'
Remove workaround to add libnvidia-container0 to kitmaker archive

See merge request nvidia/container-toolkit/container-toolkit!378
2023-06-06 22:17:30 +02:00
Evan Lezar
5b81e30704 Merge branch 'cherry-pick-for-v1.13.2' into 'release-1.13'
Cherry pick changes for 1.13.2

See merge request nvidia/container-toolkit/container-toolkit!407
2023-06-06 19:03:42 +00:00
Evan Lezar
a34b08908e Update libnvidia-container
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-06 20:14:27 +02:00
Evan Lezar
234f7ebd5f Merge branch 'bump-cuda-base-image' into 'main'
Bump CUDA baseimage version to 12.1.1

See merge request nvidia/container-toolkit/container-toolkit!412
2023-06-01 14:49:08 +02:00
Evan Lezar
36d1b7d2a5 Merge branch 'treat-log-errors-as-non-fatal' into 'main'
Ignore errors when creating debug log file

See merge request nvidia/container-toolkit/container-toolkit!404
2023-06-01 14:48:33 +02:00
Evan Lezar
b776bf712e Merge branch 'add-mod-probe' into 'main'
Add option to load NVIDIA kernel modules

See merge request nvidia/container-toolkit/container-toolkit!409
2023-06-01 14:48:33 +02:00
Evan Lezar
90cbe938c3 Update CHANGELOG for cherry-pick
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-01 14:48:33 +02:00
Evan Lezar
8697267e6b Merge branch 'fix-device-symlinks' into 'main'
Fix creation of device symlinks in /dev/char

See merge request nvidia/container-toolkit/container-toolkit!399
2023-06-01 14:48:33 +02:00
Evan Lezar
2d7bb636b9 Merge branch 'CNT-4285/add-runtime-hook-path' into 'main'
Add nvidia-contianer-runtime-hook.path config option

See merge request nvidia/container-toolkit/container-toolkit!401
2023-06-01 14:48:33 +02:00
Evan Lezar
7386f86904 Merge branch 'bump-runc' into 'main'
Bump golang version and update dependencies

See merge request nvidia/container-toolkit/container-toolkit!377
2023-06-01 14:47:53 +02:00
Evan Lezar
2bcb2d633d Bump version to 1.13.2
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-01 08:26:59 +02:00
Evan Lezar
fac0697c93 Update libnvidia-container
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-01 08:26:55 +02:00
Evan Lezar
595692b9bd Set libnvidia-container branch to release-1.13
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-06-01 08:26:06 +02:00
Evan Lezar
f13d4402bd Merge branch 'update-release-1.13' into 'release-1.13'
Skip updating of components on release-1.13 branch

See merge request nvidia/container-toolkit/container-toolkit!405
2023-05-26 09:09:46 +00:00
Evan Lezar
968dce5a70 Skip updating of components
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2023-05-24 10:45:26 +02:00
268 changed files with 19826 additions and 3620 deletions

2
.gitmodules vendored
View File

@@ -1,7 +1,7 @@
[submodule "third_party/libnvidia-container"] [submodule "third_party/libnvidia-container"]
path = third_party/libnvidia-container path = third_party/libnvidia-container
url = https://gitlab.com/nvidia/container-toolkit/libnvidia-container.git url = https://gitlab.com/nvidia/container-toolkit/libnvidia-container.git
branch = main branch = release-1.13
[submodule "third_party/nvidia-container-runtime"] [submodule "third_party/nvidia-container-runtime"]
path = third_party/nvidia-container-runtime path = third_party/nvidia-container-runtime
url = https://gitlab.com/nvidia/container-toolkit/container-runtime.git url = https://gitlab.com/nvidia/container-toolkit/container-runtime.git

View File

@@ -146,6 +146,7 @@ scan-centos7-arm64:
needs: needs:
- image-centos7 - image-centos7
- scan-centos7-amd64 - scan-centos7-amd64
allow_failure: true
scan-ubuntu20.04-amd64: scan-ubuntu20.04-amd64:
extends: extends:

View File

@@ -1,5 +1,27 @@
# NVIDIA Container Toolkit Changelog # NVIDIA Container Toolkit Changelog
## v1.13.4
* [toolkit-container] Bump CUDA base image version to 12.2.0.
## v1.13.3
* Generate CDI specification files with `644` permissions to allow rootless applications (e.g. podman).
* Fix bug causing incorrect nvidia-smi symlink to be created on WSL2 systems with multiple driver roots.
* Fix bug when using driver versions that do not include a patch component in their version number.
* Skip additional modifications in CDI mode.
* Fix loading of kernel modules and creation of device nodes in containerized use cases.
* [toolkit-container] Allow same envars for all runtime configs
## v1.13.2
* Add `nvidia-container-runtime-hook.path` config option to specify NVIDIA Container Runtime Hook path explicitly.
* Fix bug in creation of `/dev/char` symlinks by failing operation if kernel modules are not loaded.
* Add option to load kernel modules when creating device nodes
* Add option to create device nodes when creating `/dev/char` symlinks
* Treat failures to open debug log files as non-fatal.
* Bump CUDA base image version to 12.1.1.
## v1.13.1 ## v1.13.1
* Update `update-ldcache` hook to only update ldcache if it exists. * Update `update-ldcache` hook to only update ldcache if it exists.

View File

@@ -172,7 +172,7 @@ func TestDuplicateHook(t *testing.T) {
// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for // addNVIDIAHook is a basic wrapper for an addHookModifier that is used for
// testing. // testing.
func addNVIDIAHook(spec *specs.Spec) error { func addNVIDIAHook(spec *specs.Spec) error {
m := modifier.NewStableRuntimeModifier(logrus.StandardLogger()) m := modifier.NewStableRuntimeModifier(logrus.StandardLogger(), nvidiaHook)
return m.Modify(spec) return m.Modify(spec)
} }

View File

@@ -251,6 +251,7 @@ func (m command) generateSpec(cfg *config) (spec.Interface, error) {
spec.WithDeviceSpecs(deviceSpecs), spec.WithDeviceSpecs(deviceSpecs),
spec.WithEdits(*commonEdits.ContainerEdits), spec.WithEdits(*commonEdits.ContainerEdits),
spec.WithFormat(cfg.format), spec.WithFormat(cfg.format),
spec.WithPermissions(0644),
) )
} }

View File

@@ -28,29 +28,40 @@ import (
type allPossible struct { type allPossible struct {
logger *logrus.Logger logger *logrus.Logger
driverRoot string devRoot string
deviceMajors devices.Devices deviceMajors devices.Devices
migCaps nvcaps.MigCaps migCaps nvcaps.MigCaps
} }
// newAllPossible returns a new allPossible device node lister. // newAllPossible returns a new allPossible device node lister.
// This lister lists all possible device nodes for NVIDIA GPUs, control devices, and capability devices. // This lister lists all possible device nodes for NVIDIA GPUs, control devices, and capability devices.
func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error) { func newAllPossible(logger *logrus.Logger, devRoot string) (nodeLister, error) {
deviceMajors, err := devices.GetNVIDIADevices() deviceMajors, err := devices.GetNVIDIADevices()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed reading device majors: %v", err) return nil, fmt.Errorf("failed reading device majors: %v", err)
} }
var requiredMajors []devices.Name
migCaps, err := nvcaps.NewMigCaps() migCaps, err := nvcaps.NewMigCaps()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read MIG caps: %v", err) return nil, fmt.Errorf("failed to read MIG caps: %v", err)
} }
if migCaps == nil { if migCaps == nil {
migCaps = make(nvcaps.MigCaps) migCaps = make(nvcaps.MigCaps)
} else {
requiredMajors = append(requiredMajors, devices.NVIDIACaps)
}
requiredMajors = append(requiredMajors, devices.NVIDIAGPU, devices.NVIDIAUVM)
for _, name := range requiredMajors {
if !deviceMajors.Exists(name) {
return nil, fmt.Errorf("missing required device major %s", name)
}
} }
l := allPossible{ l := allPossible{
logger: logger, logger: logger,
driverRoot: driverRoot, devRoot: devRoot,
deviceMajors: deviceMajors, deviceMajors: deviceMajors,
migCaps: migCaps, migCaps: migCaps,
} }
@@ -61,7 +72,7 @@ func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error
// DeviceNodes returns a list of all possible device nodes for NVIDIA GPUs, control devices, and capability devices. // DeviceNodes returns a list of all possible device nodes for NVIDIA GPUs, control devices, and capability devices.
func (m allPossible) DeviceNodes() ([]deviceNode, error) { func (m allPossible) DeviceNodes() ([]deviceNode, error) {
gpus, err := nvpci.NewFrom( gpus, err := nvpci.NewFrom(
filepath.Join(m.driverRoot, nvpci.PCIDevicesRoot), filepath.Join(m.devRoot, nvpci.PCIDevicesRoot),
).GetGPUs() ).GetGPUs()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get GPU information: %v", err) return nil, fmt.Errorf("failed to get GPU information: %v", err)
@@ -69,7 +80,7 @@ func (m allPossible) DeviceNodes() ([]deviceNode, error) {
count := len(gpus) count := len(gpus)
if count == 0 { if count == 0 {
m.logger.Infof("No NVIDIA devices found in %s", m.driverRoot) m.logger.Infof("No NVIDIA devices found in %s", m.devRoot)
return nil, nil return nil, nil
} }
@@ -168,7 +179,7 @@ func (m allPossible) newDeviceNode(deviceName devices.Name, path string, minor i
major, _ := m.deviceMajors.Get(deviceName) major, _ := m.deviceMajors.Get(deviceName)
return deviceNode{ return deviceNode{
path: filepath.Join(m.driverRoot, path), path: filepath.Join(m.devRoot, path),
major: uint32(major), major: uint32(major),
minor: uint32(minor), minor: uint32(minor),
} }

View File

@@ -24,6 +24,8 @@ import (
"strings" "strings"
"syscall" "syscall"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@@ -38,11 +40,13 @@ type command struct {
} }
type config struct { type config struct {
devCharPath string devCharPath string
driverRoot string driverRoot string
dryRun bool dryRun bool
watch bool watch bool
createAll bool createAll bool
createDeviceNodes bool
loadKernelModules bool
} }
// NewCommand constructs a command sub-command with the specified logger // NewCommand constructs a command sub-command with the specified logger
@@ -97,6 +101,18 @@ func (m command) build() *cli.Command {
Destination: &cfg.createAll, Destination: &cfg.createAll,
EnvVars: []string{"CREATE_ALL"}, EnvVars: []string{"CREATE_ALL"},
}, },
&cli.BoolFlag{
Name: "load-kernel-modules",
Usage: "Load the NVIDIA kernel modules before creating symlinks. This is only applicable when --create-all is set.",
Destination: &cfg.loadKernelModules,
EnvVars: []string{"LOAD_KERNEL_MODULES"},
},
&cli.BoolFlag{
Name: "create-device-nodes",
Usage: "Create the NVIDIA control device nodes in the driver root if they do not exist. This is only applicable when --create-all is set",
Destination: &cfg.createDeviceNodes,
EnvVars: []string{"CREATE_DEVICE_NODES"},
},
&cli.BoolFlag{ &cli.BoolFlag{
Name: "dry-run", Name: "dry-run",
Usage: "If set, the command will not create any symlinks.", Usage: "If set, the command will not create any symlinks.",
@@ -114,6 +130,16 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error {
return fmt.Errorf("create-all and watch are mutually exclusive") return fmt.Errorf("create-all and watch are mutually exclusive")
} }
if cfg.loadKernelModules && !cfg.createAll {
m.logger.Warn("load-kernel-modules is only applicable when create-all is set; ignoring")
cfg.loadKernelModules = false
}
if cfg.createDeviceNodes && !cfg.createAll {
m.logger.Warn("create-device-nodes is only applicable when create-all is set; ignoring")
cfg.createDeviceNodes = false
}
return nil return nil
} }
@@ -137,6 +163,8 @@ func (m command) run(c *cli.Context, cfg *config) error {
WithDriverRoot(cfg.driverRoot), WithDriverRoot(cfg.driverRoot),
WithDryRun(cfg.dryRun), WithDryRun(cfg.dryRun),
WithCreateAll(cfg.createAll), WithCreateAll(cfg.createAll),
WithLoadKernelModules(cfg.loadKernelModules),
WithCreateDeviceNodes(cfg.createDeviceNodes),
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to create symlink creator: %v", err) return fmt.Errorf("failed to create symlink creator: %v", err)
@@ -186,12 +214,15 @@ create:
} }
type linkCreator struct { type linkCreator struct {
logger *logrus.Logger logger *logrus.Logger
lister nodeLister lister nodeLister
driverRoot string driverRoot string
devCharPath string devRoot string
dryRun bool devCharPath string
createAll bool dryRun bool
createAll bool
createDeviceNodes bool
loadKernelModules bool
} }
// Creator is an interface for creating symlinks to /dev/nv* devices in /dev/char. // Creator is an interface for creating symlinks to /dev/nv* devices in /dev/char.
@@ -214,29 +245,76 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) {
if c.driverRoot == "" { if c.driverRoot == "" {
c.driverRoot = "/" c.driverRoot = "/"
} }
if c.devRoot == "" {
c.devRoot = "/"
}
if c.devCharPath == "" { if c.devCharPath == "" {
c.devCharPath = defaultDevCharPath c.devCharPath = defaultDevCharPath
} }
if err := c.setup(); err != nil {
return nil, err
}
if c.createAll { if c.createAll {
lister, err := newAllPossible(c.logger, c.driverRoot) lister, err := newAllPossible(c.logger, c.devRoot)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create all possible device lister: %v", err) return nil, fmt.Errorf("failed to create all possible device lister: %v", err)
} }
c.lister = lister c.lister = lister
} else { } else {
c.lister = existing{c.logger, c.driverRoot} c.lister = existing{c.logger, c.devRoot}
} }
return c, nil return c, nil
} }
func (m linkCreator) setup() error {
if !m.loadKernelModules && !m.createDeviceNodes {
return nil
}
if m.loadKernelModules {
modules := nvmodules.New(
nvmodules.WithLogger(m.logger),
nvmodules.WithDryRun(m.dryRun),
nvmodules.WithRoot(m.driverRoot),
)
if err := modules.LoadAll(); err != nil {
return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err)
}
}
if m.createDeviceNodes {
devices, err := nvdevices.New(
nvdevices.WithLogger(m.logger),
nvdevices.WithDryRun(m.dryRun),
nvdevices.WithDevRoot(m.devRoot),
)
if err != nil {
return err
}
if err := devices.CreateNVIDIAControlDevices(); err != nil {
return fmt.Errorf("failed to create NVIDIA device nodes: %v", err)
}
}
return nil
}
// WithDriverRoot sets the driver root path. // WithDriverRoot sets the driver root path.
// This is the path in which kernel modules must be loaded.
func WithDriverRoot(root string) Option { func WithDriverRoot(root string) Option {
return func(c *linkCreator) { return func(c *linkCreator) {
c.driverRoot = root c.driverRoot = root
} }
} }
// WithDevRoot sets the root path for the /dev directory.
func WithDevRoot(root string) Option {
return func(c *linkCreator) {
c.devRoot = root
}
}
// WithDevCharPath sets the path at which the symlinks will be created. // WithDevCharPath sets the path at which the symlinks will be created.
func WithDevCharPath(path string) Option { func WithDevCharPath(path string) Option {
return func(c *linkCreator) { return func(c *linkCreator) {
@@ -265,6 +343,20 @@ func WithCreateAll(createAll bool) Option {
} }
} }
// WithLoadKernelModules sets the loadKernelModules flag for the linkCreator.
func WithLoadKernelModules(loadKernelModules bool) Option {
return func(lc *linkCreator) {
lc.loadKernelModules = loadKernelModules
}
}
// WithCreateDeviceNodes sets the createDeviceNodes flag for the linkCreator.
func WithCreateDeviceNodes(createDeviceNodes bool) Option {
return func(lc *linkCreator) {
lc.createDeviceNodes = createDeviceNodes
}
}
// CreateLinks creates symlinks for all NVIDIA device nodes found in the driver root. // CreateLinks creates symlinks for all NVIDIA device nodes found in the driver root.
func (m linkCreator) CreateLinks() error { func (m linkCreator) CreateLinks() error {
deviceNodes, err := m.lister.DeviceNodes() deviceNodes, err := m.lister.DeviceNodes()

View File

@@ -30,8 +30,8 @@ type nodeLister interface {
} }
type existing struct { type existing struct {
logger *logrus.Logger logger *logrus.Logger
driverRoot string devRoot string
} }
// DeviceNodes returns a list of NVIDIA device nodes in the specified root. // DeviceNodes returns a list of NVIDIA device nodes in the specified root.
@@ -39,7 +39,7 @@ type existing struct {
func (m existing) DeviceNodes() ([]deviceNode, error) { func (m existing) DeviceNodes() ([]deviceNode, error) {
locator := lookup.NewCharDeviceLocator( locator := lookup.NewCharDeviceLocator(
lookup.WithLogger(m.logger), lookup.WithLogger(m.logger),
lookup.WithRoot(m.driverRoot), lookup.WithRoot(m.devRoot),
lookup.WithOptional(true), lookup.WithOptional(true),
) )
@@ -54,7 +54,7 @@ func (m existing) DeviceNodes() ([]deviceNode, error) {
} }
if len(devices) == 0 && len(capDevices) == 0 { if len(devices) == 0 && len(capDevices) == 0 {
m.logger.Infof("No NVIDIA devices found in %s", m.driverRoot) m.logger.Infof("No NVIDIA devices found in %s", m.devRoot)
return nil, nil return nil, nil
} }

View File

@@ -19,7 +19,8 @@ package createdevicenodes
import ( import (
"fmt" "fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system" "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
@@ -34,6 +35,8 @@ type options struct {
dryRun bool dryRun bool
control bool control bool
loadKernelModules bool
} }
// NewCommand constructs a command sub-command with the specified logger // NewCommand constructs a command sub-command with the specified logger
@@ -72,6 +75,11 @@ func (m command) build() *cli.Command {
Usage: "create all control device nodes: nvidiactl, nvidia-modeset, nvidia-uvm, nvidia-uvm-tools", Usage: "create all control device nodes: nvidiactl, nvidia-modeset, nvidia-uvm, nvidia-uvm-tools",
Destination: &opts.control, Destination: &opts.control,
}, },
&cli.BoolFlag{
Name: "load-kernel-modules",
Usage: "load the NVIDIA Kernel Modules before creating devices nodes",
Destination: &opts.loadKernelModules,
},
&cli.BoolFlag{ &cli.BoolFlag{
Name: "dry-run", Name: "dry-run",
Usage: "if set, the command will not create any symlinks.", Usage: "if set, the command will not create any symlinks.",
@@ -89,18 +97,29 @@ func (m command) validateFlags(r *cli.Context, opts *options) error {
} }
func (m command) run(c *cli.Context, opts *options) error { func (m command) run(c *cli.Context, opts *options) error {
s, err := system.New( if opts.loadKernelModules {
system.WithLogger(m.logger), modules := nvmodules.New(
system.WithDryRun(opts.dryRun), nvmodules.WithLogger(m.logger),
) nvmodules.WithDryRun(opts.dryRun),
if err != nil { nvmodules.WithRoot(opts.driverRoot),
return fmt.Errorf("failed to create library: %v", err) )
if err := modules.LoadAll(); err != nil {
return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err)
}
} }
if opts.control { if opts.control {
devices, err := nvdevices.New(
nvdevices.WithLogger(m.logger),
nvdevices.WithDryRun(opts.dryRun),
nvdevices.WithDevRoot(opts.driverRoot),
)
if err != nil {
return err
}
m.logger.Infof("Creating control device nodes at %s", opts.driverRoot) m.logger.Infof("Creating control device nodes at %s", opts.driverRoot)
if err := s.CreateNVIDIAControlDeviceNodesAt(opts.driverRoot); err != nil { if err := devices.CreateNVIDIAControlDevices(); err != nil {
return fmt.Errorf("failed to create control device nodes: %v", err) return fmt.Errorf("failed to create NVIDIA control device nodes: %v", err)
} }
} }
return nil return nil

20
go.mod
View File

@@ -1,31 +1,30 @@
module github.com/NVIDIA/nvidia-container-toolkit module github.com/NVIDIA/nvidia-container-toolkit
go 1.18 go 1.20
require ( require (
github.com/BurntSushi/toml v1.0.0 github.com/BurntSushi/toml v1.2.1
github.com/NVIDIA/go-nvml v0.12.0-0 github.com/NVIDIA/go-nvml v0.12.0-0
github.com/container-orchestrated-devices/container-device-interface v0.5.4-0.20230111111500-5b3b5d81179a github.com/container-orchestrated-devices/container-device-interface v0.5.4-0.20230111111500-5b3b5d81179a
github.com/fsnotify/fsnotify v1.5.4 github.com/fsnotify/fsnotify v1.5.4
github.com/opencontainers/runtime-spec v1.0.3-0.20220825212826-86290f6a00fb github.com/opencontainers/runtime-spec v1.1.0-rc.2
github.com/pelletier/go-toml v1.9.4 github.com/pelletier/go-toml v1.9.4
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.8.1
github.com/urfave/cli/v2 v2.3.0 github.com/urfave/cli/v2 v2.3.0
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230209143738-95328d8c4438 gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230209143738-95328d8c4438
golang.org/x/mod v0.5.0 golang.org/x/mod v0.5.0
golang.org/x/sys v0.0.0-20220927170352-d9d178bc13c6 golang.org/x/sys v0.7.0
sigs.k8s.io/yaml v1.3.0
) )
require ( require (
github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/opencontainers/runc v1.1.4 // indirect github.com/opencontainers/runc v1.1.6 // indirect
github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626 // indirect github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626 // indirect
github.com/opencontainers/selinux v1.10.1 // indirect github.com/opencontainers/selinux v1.11.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 // indirect github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 // indirect
@@ -33,4 +32,5 @@ require (
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
sigs.k8s.io/yaml v1.3.0 // indirect
) )

72
go.sum
View File

@@ -1,86 +1,76 @@
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.0.0 h1:dtDWrepsVPfW9H/4y7dDgFc2MBUSeJhlaDtK13CxFlU= github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak=
github.com/BurntSushi/toml v1.0.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/NVIDIA/go-nvml v0.11.6-0.0.20220823120812-7e2082095e82 h1:x751Xx1tdxkiA/sdkv2J769n21UbYKzVOpe9S/h1M3k=
github.com/NVIDIA/go-nvml v0.11.6-0.0.20220823120812-7e2082095e82/go.mod h1:hy7HYeQy335x6nEss0Ne3PYqleRa6Ct+VKD9RQ4nyFs= github.com/NVIDIA/go-nvml v0.11.6-0.0.20220823120812-7e2082095e82/go.mod h1:hy7HYeQy335x6nEss0Ne3PYqleRa6Ct+VKD9RQ4nyFs=
github.com/NVIDIA/go-nvml v0.12.0-0 h1:eHYNHbzAsMgWYshf6dEmTY66/GCXnORJFnzm3TNH4mc= github.com/NVIDIA/go-nvml v0.12.0-0 h1:eHYNHbzAsMgWYshf6dEmTY66/GCXnORJFnzm3TNH4mc=
github.com/NVIDIA/go-nvml v0.12.0-0/go.mod h1:hy7HYeQy335x6nEss0Ne3PYqleRa6Ct+VKD9RQ4nyFs= github.com/NVIDIA/go-nvml v0.12.0-0/go.mod h1:hy7HYeQy335x6nEss0Ne3PYqleRa6Ct+VKD9RQ4nyFs=
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ=
github.com/checkpoint-restore/go-criu/v5 v5.3.0/go.mod h1:E/eQpaFtUKGOOSEBZgmKAcn+zUUwWxqcaKZlF54wK8E=
github.com/cilium/ebpf v0.7.0/go.mod h1:/oI2+1shJiTGAMgl6/RgJr36Eo1jzrRcAWbcXO2usCA=
github.com/container-orchestrated-devices/container-device-interface v0.5.4-0.20230111111500-5b3b5d81179a h1:sP3PcgyIkRlHqfF3Jfpe/7G8kf/qpzG4C8r94y9hLbE= github.com/container-orchestrated-devices/container-device-interface v0.5.4-0.20230111111500-5b3b5d81179a h1:sP3PcgyIkRlHqfF3Jfpe/7G8kf/qpzG4C8r94y9hLbE=
github.com/container-orchestrated-devices/container-device-interface v0.5.4-0.20230111111500-5b3b5d81179a/go.mod h1:xMRa4fJgXzSDFUCURSimOUgoSc+odohvO3uXT9xjqH0= github.com/container-orchestrated-devices/container-device-interface v0.5.4-0.20230111111500-5b3b5d81179a/go.mod h1:xMRa4fJgXzSDFUCURSimOUgoSc+odohvO3uXT9xjqH0=
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.1 h1:r/myEWzV9lfsM1tFLgDyu0atFtJ1fXn261LKYj/3DxU= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/cyphar/filepath-securejoin v0.2.3/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k=
github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI=
github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.0.6/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mndrix/tap-go v0.0.0-20171203230836-629fa407e90b/go.mod h1:pzzDgJWZ34fGzaAZGFW22KVZDfyrYW+QABMrWnJBnSs= github.com/mndrix/tap-go v0.0.0-20171203230836-629fa407e90b/go.mod h1:pzzDgJWZ34fGzaAZGFW22KVZDfyrYW+QABMrWnJBnSs=
github.com/moby/sys/mountinfo v0.5.0/go.mod h1:3bMD3Rg+zkqx8MRYPi7Pyb0Ie97QEBmdxbhnCLlSvSU=
github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ= github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ=
github.com/opencontainers/runc v1.1.4 h1:nRCz/8sKg6K6jgYAFLDlXzPeITBZJyX28DBVhWD+5dg= github.com/opencontainers/runc v1.1.6 h1:XbhB8IfG/EsnhNvZtNdLB0GBw92GYEFvKlhaJk9jUgA=
github.com/opencontainers/runc v1.1.4/go.mod h1:1J5XiS+vdZ3wCyZybsuxXZWGrgSr8fFJHLXuG2PsnNg= github.com/opencontainers/runc v1.1.6/go.mod h1:CbUumNnWCuTGFukNXahoo/RFBZvDAgRh/smNYNOhA50=
github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.0.3-0.20220825212826-86290f6a00fb h1:1xSVPOd7/UA+39/hXEGnBJ13p6JFB0E1EvQFlrRDOXI=
github.com/opencontainers/runtime-spec v1.0.3-0.20220825212826-86290f6a00fb/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-spec v1.0.3-0.20220825212826-86290f6a00fb/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.1.0-rc.2 h1:ucBtEms2tamYYW/SvGpvq9yUN0NEVL6oyLEwDcTSrk8=
github.com/opencontainers/runtime-spec v1.1.0-rc.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626 h1:DmNGcqH3WDbV5k8OJ+esPWbqUOX5rMLR2PMvziDMJi0= github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626 h1:DmNGcqH3WDbV5k8OJ+esPWbqUOX5rMLR2PMvziDMJi0=
github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626/go.mod h1:BRHJJd0E+cx42OybVYSgUvZmU0B8P9gZuRXlZUP7TKI= github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626/go.mod h1:BRHJJd0E+cx42OybVYSgUvZmU0B8P9gZuRXlZUP7TKI=
github.com/opencontainers/selinux v1.9.1/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI= github.com/opencontainers/selinux v1.9.1/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI=
github.com/opencontainers/selinux v1.10.0/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI= github.com/opencontainers/selinux v1.11.0 h1:+5Zbo97w3Lbmb3PeqQtpmTkMwsW5nRI3YaLpt7tQ7oU=
github.com/opencontainers/selinux v1.10.1 h1:09LIPVRP3uuZGQvgR+SgMSNBd1Eb3vlRbGqQpoHsF8w= github.com/opencontainers/selinux v1.11.0/go.mod h1:E5dMC3VPuVvVHDYmi78qvhJp8+M586T4DlDRYpFkyec=
github.com/opencontainers/selinux v1.10.1/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI=
github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM= github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM=
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/seccomp/libseccomp-golang v0.9.2-0.20220502022130-f33da4d89646/go.mod h1:JA8cRccbGaA1s33RQf7Y1+q9gHmZX1yB/z9WDN1C6fg=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 h1:kdXcSzyDtseVEc4yCz2qF8ZrQvIDBJLl4S1c3GCXmoI= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 h1:kdXcSzyDtseVEc4yCz2qF8ZrQvIDBJLl4S1c3GCXmoI=
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
github.com/urfave/cli v1.19.1/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.19.1/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M= github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
@@ -88,35 +78,19 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230119114711-6fe07bb33342 h1:083n9fJt2dWOpJd/X/q9Xgl5XtQLL22uSFYbzVqJssg=
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230119114711-6fe07bb33342/go.mod h1:GStidGxhaqJhYFW1YpOnLvYCbL2EsM0od7IW4u7+JgU=
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230209143738-95328d8c4438 h1:+qRai7XRl8omFQVCeHcaWzL542Yw64vfmuXG+79ZCIc= gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230209143738-95328d8c4438 h1:+qRai7XRl8omFQVCeHcaWzL542Yw64vfmuXG+79ZCIc=
gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230209143738-95328d8c4438/go.mod h1:GStidGxhaqJhYFW1YpOnLvYCbL2EsM0od7IW4u7+JgU= gitlab.com/nvidia/cloud-native/go-nvlib v0.0.0-20230209143738-95328d8c4438/go.mod h1:GStidGxhaqJhYFW1YpOnLvYCbL2EsM0od7IW4u7+JgU=
golang.org/x/mod v0.5.0 h1:UG21uOlmZabA4fW5i7ZX6bjw1xELEGg/ZLgZq9auk/Q= golang.org/x/mod v0.5.0 h1:UG21uOlmZabA4fW5i7ZX6bjw1xELEGg/ZLgZq9auk/Q=
golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220927170352-d9d178bc13c6 h1:cy1ko5847T/lJ45eyg/7uLprIE/amW5IXxGtEnQdYMI= golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.0.0-20220927170352-d9d178bc13c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=

View File

@@ -21,13 +21,19 @@ import (
"io" "io"
"os" "os"
"path" "path"
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/pelletier/go-toml" "github.com/pelletier/go-toml"
"github.com/sirupsen/logrus"
) )
const ( const (
configOverride = "XDG_CONFIG_HOME" configOverride = "XDG_CONFIG_HOME"
configFilePath = "nvidia-container-runtime/config.toml" configFilePath = "nvidia-container-runtime/config.toml"
nvidiaContainerRuntimeHookExecutable = "nvidia-container-runtime-hook"
nvidiaContainerRuntimeHookDefaultPath = "/usr/bin/nvidia-container-runtime-hook"
) )
var ( var (
@@ -124,3 +130,41 @@ func getDefaultConfig() *Config {
return &c return &c
} }
// ResolveNVIDIAContainerRuntimeHookPath resolves the path the nvidia-container-runtime-hook binary.
func ResolveNVIDIAContainerRuntimeHookPath(logger *logrus.Logger, nvidiaContainerRuntimeHookPath string) string {
return resolveWithDefault(
logger,
"NVIDIA Container Runtime Hook",
nvidiaContainerRuntimeHookPath,
nvidiaContainerRuntimeHookDefaultPath,
)
}
// resolveWithDefault resolves the path to the specified binary.
// If an absolute path is specified, it is used directly without searching for the binary.
// If the binary cannot be found in the path, the specified default is used instead.
func resolveWithDefault(logger *logrus.Logger, label string, path string, defaultPath string) string {
if filepath.IsAbs(path) {
logger.Debugf("Using specified %v path %v", label, path)
return path
}
if path == "" {
path = filepath.Base(defaultPath)
}
logger.Debugf("Locating %v as %v", label, path)
lookup := lookup.NewExecutableLocator(logger, "")
resolvedPath := defaultPath
targets, err := lookup.Locate(path)
if err != nil {
logger.Warnf("Failed to locate %v: %v", path, err)
} else {
logger.Debugf("Found %v candidates: %v", path, targets)
resolvedPath = targets[0]
}
logger.Debugf("Using %v path %v", label, path)
return resolvedPath
}

View File

@@ -76,6 +76,9 @@ func TestGetConfig(t *testing.T) {
}, },
}, },
}, },
NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{
Path: "nvidia-container-runtime-hook",
},
NVIDIACTKConfig: CTKConfig{ NVIDIACTKConfig: CTKConfig{
Path: "nvidia-ctk", Path: "nvidia-ctk",
}, },
@@ -95,6 +98,7 @@ func TestGetConfig(t *testing.T) {
"nvidia-container-runtime.modes.cdi.default-kind = \"example.vendor.com/device\"", "nvidia-container-runtime.modes.cdi.default-kind = \"example.vendor.com/device\"",
"nvidia-container-runtime.modes.cdi.annotation-prefixes = [\"cdi.k8s.io/\", \"example.vendor.com/\",]", "nvidia-container-runtime.modes.cdi.annotation-prefixes = [\"cdi.k8s.io/\", \"example.vendor.com/\",]",
"nvidia-container-runtime.modes.csv.mount-spec-path = \"/not/etc/nvidia-container-runtime/host-files-for-container.d\"", "nvidia-container-runtime.modes.csv.mount-spec-path = \"/not/etc/nvidia-container-runtime/host-files-for-container.d\"",
"nvidia-container-runtime-hook.path = \"/foo/bar/nvidia-container-runtime-hook\"",
"nvidia-ctk.path = \"/foo/bar/nvidia-ctk\"", "nvidia-ctk.path = \"/foo/bar/nvidia-ctk\"",
}, },
expectedConfig: &Config{ expectedConfig: &Config{
@@ -120,6 +124,9 @@ func TestGetConfig(t *testing.T) {
}, },
}, },
}, },
NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{
Path: "/foo/bar/nvidia-container-runtime-hook",
},
NVIDIACTKConfig: CTKConfig{ NVIDIACTKConfig: CTKConfig{
Path: "/foo/bar/nvidia-ctk", Path: "/foo/bar/nvidia-ctk",
}, },
@@ -143,6 +150,8 @@ func TestGetConfig(t *testing.T) {
"annotation-prefixes = [\"cdi.k8s.io/\", \"example.vendor.com/\",]", "annotation-prefixes = [\"cdi.k8s.io/\", \"example.vendor.com/\",]",
"[nvidia-container-runtime.modes.csv]", "[nvidia-container-runtime.modes.csv]",
"mount-spec-path = \"/not/etc/nvidia-container-runtime/host-files-for-container.d\"", "mount-spec-path = \"/not/etc/nvidia-container-runtime/host-files-for-container.d\"",
"[nvidia-container-runtime-hook]",
"path = \"/foo/bar/nvidia-container-runtime-hook\"",
"[nvidia-ctk]", "[nvidia-ctk]",
"path = \"/foo/bar/nvidia-ctk\"", "path = \"/foo/bar/nvidia-ctk\"",
}, },
@@ -169,6 +178,9 @@ func TestGetConfig(t *testing.T) {
}, },
}, },
}, },
NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{
Path: "/foo/bar/nvidia-container-runtime-hook",
},
NVIDIACTKConfig: CTKConfig{ NVIDIACTKConfig: CTKConfig{
Path: "/foo/bar/nvidia-ctk", Path: "/foo/bar/nvidia-ctk",
}, },

View File

@@ -24,6 +24,9 @@ import (
// RuntimeHookConfig stores the config options for the NVIDIA Container Runtime // RuntimeHookConfig stores the config options for the NVIDIA Container Runtime
type RuntimeHookConfig struct { type RuntimeHookConfig struct {
// Path specifies the path to the NVIDIA Container Runtime hook binary.
// If an executable name is specified, this will be resolved in the path.
Path string `toml:"path"`
// SkipModeDetection disables the mode check for the runtime hook. // SkipModeDetection disables the mode check for the runtime hook.
SkipModeDetection bool `toml:"skip-mode-detection"` SkipModeDetection bool `toml:"skip-mode-detection"`
} }
@@ -55,6 +58,7 @@ func getRuntimeHookConfigFrom(toml *toml.Tree) (*RuntimeHookConfig, error) {
// GetDefaultRuntimeHookConfig defines the default values for the config // GetDefaultRuntimeHookConfig defines the default values for the config
func GetDefaultRuntimeHookConfig() *RuntimeHookConfig { func GetDefaultRuntimeHookConfig() *RuntimeHookConfig {
c := RuntimeHookConfig{ c := RuntimeHookConfig{
Path: NVIDIAContainerRuntimeHookExecutable,
SkipModeDetection: false, SkipModeDetection: false,
} }

View File

@@ -13,25 +13,25 @@ var _ Discover = &DiscoverMock{}
// DiscoverMock is a mock implementation of Discover. // DiscoverMock is a mock implementation of Discover.
// //
// func TestSomethingThatUsesDiscover(t *testing.T) { // func TestSomethingThatUsesDiscover(t *testing.T) {
// //
// // make and configure a mocked Discover // // make and configure a mocked Discover
// mockedDiscover := &DiscoverMock{ // mockedDiscover := &DiscoverMock{
// DevicesFunc: func() ([]Device, error) { // DevicesFunc: func() ([]Device, error) {
// panic("mock out the Devices method") // panic("mock out the Devices method")
// }, // },
// HooksFunc: func() ([]Hook, error) { // HooksFunc: func() ([]Hook, error) {
// panic("mock out the Hooks method") // panic("mock out the Hooks method")
// }, // },
// MountsFunc: func() ([]Mount, error) { // MountsFunc: func() ([]Mount, error) {
// panic("mock out the Mounts method") // panic("mock out the Mounts method")
// }, // },
// } // }
// //
// // use mockedDiscover in code that requires Discover // // use mockedDiscover in code that requires Discover
// // and then make assertions. // // and then make assertions.
// //
// } // }
type DiscoverMock struct { type DiscoverMock struct {
// DevicesFunc mocks the Devices method. // DevicesFunc mocks the Devices method.
DevicesFunc func() ([]Device, error) DevicesFunc func() ([]Device, error)
@@ -78,7 +78,8 @@ func (mock *DiscoverMock) Devices() ([]Device, error) {
// DevicesCalls gets all the calls that were made to Devices. // DevicesCalls gets all the calls that were made to Devices.
// Check the length with: // Check the length with:
// len(mockedDiscover.DevicesCalls()) //
// len(mockedDiscover.DevicesCalls())
func (mock *DiscoverMock) DevicesCalls() []struct { func (mock *DiscoverMock) DevicesCalls() []struct {
} { } {
var calls []struct { var calls []struct {
@@ -108,7 +109,8 @@ func (mock *DiscoverMock) Hooks() ([]Hook, error) {
// HooksCalls gets all the calls that were made to Hooks. // HooksCalls gets all the calls that were made to Hooks.
// Check the length with: // Check the length with:
// len(mockedDiscover.HooksCalls()) //
// len(mockedDiscover.HooksCalls())
func (mock *DiscoverMock) HooksCalls() []struct { func (mock *DiscoverMock) HooksCalls() []struct {
} { } {
var calls []struct { var calls []struct {
@@ -138,7 +140,8 @@ func (mock *DiscoverMock) Mounts() ([]Mount, error) {
// MountsCalls gets all the calls that were made to Mounts. // MountsCalls gets all the calls that were made to Mounts.
// Check the length with: // Check the length with:
// len(mockedDiscover.MountsCalls()) //
// len(mockedDiscover.MountsCalls())
func (mock *DiscoverMock) MountsCalls() []struct { func (mock *DiscoverMock) MountsCalls() []struct {
} { } {
var calls []struct { var calls []struct {

View File

@@ -271,7 +271,7 @@ func newXorgDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath s
libCudaPaths, err := cuda.New( libCudaPaths, err := cuda.New(
cuda.WithLogger(logger), cuda.WithLogger(logger),
cuda.WithDriverRoot(driverRoot), cuda.WithDriverRoot(driverRoot),
).Locate(".*.*.*") ).Locate(".*.*")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to locate libcuda.so: %v", err) return nil, fmt.Errorf("failed to locate libcuda.so: %v", err)
} }

View File

@@ -18,6 +18,7 @@ package devices
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -72,7 +73,14 @@ func (d devices) Get(name Name) (Major, bool) {
// GetNVIDIADevices returns the set of NVIDIA Devices on the machine // GetNVIDIADevices returns the set of NVIDIA Devices on the machine
func GetNVIDIADevices() (Devices, error) { func GetNVIDIADevices() (Devices, error) {
devicesFile, err := os.Open(procDevicesPath) return nvidiaDevices(procDevicesPath)
}
// nvidiaDevices returns the set of NVIDIA Devices from the specified devices file.
// This is useful for testing since we may be testing on a system where `/proc/devices` does
// contain a reference to NVIDIA devices.
func nvidiaDevices(devicesPath string) (Devices, error) {
devicesFile, err := os.Open(devicesPath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil, nil return nil, nil
} }
@@ -81,20 +89,28 @@ func GetNVIDIADevices() (Devices, error) {
} }
defer devicesFile.Close() defer devicesFile.Close()
return nvidiaDeviceFrom(devicesFile), nil return nvidiaDeviceFrom(devicesFile)
} }
func nvidiaDeviceFrom(reader io.Reader) devices { var errNoNvidiaDevices = errors.New("no NVIDIA devices found")
func nvidiaDeviceFrom(reader io.Reader) (devices, error) {
allDevices := devicesFrom(reader) allDevices := devicesFrom(reader)
nvidiaDevices := make(devices) nvidiaDevices := make(devices)
var hasNvidiaDevices bool
for n, d := range allDevices { for n, d := range allDevices {
if !strings.HasPrefix(string(n), nvidiaDevicePrefix) { if !strings.HasPrefix(string(n), nvidiaDevicePrefix) {
continue continue
} }
nvidiaDevices[n] = d nvidiaDevices[n] = d
hasNvidiaDevices = true
} }
return nvidiaDevices if !hasNvidiaDevices {
return nil, errNoNvidiaDevices
}
return nvidiaDevices, nil
} }
func devicesFrom(reader io.Reader) devices { func devicesFrom(reader io.Reader) devices {

View File

@@ -45,21 +45,23 @@ func TestNvidiaDevices(t *testing.T) {
func TestProcessDeviceFile(t *testing.T) { func TestProcessDeviceFile(t *testing.T) {
testCases := []struct { testCases := []struct {
lines []string lines []string
expected devices expected devices
expectedError error
}{ }{
{[]string{}, make(devices)}, {lines: []string{}, expectedError: errNoNvidiaDevices},
{[]string{"Not a valid line:"}, make(devices)}, {lines: []string{"Not a valid line:"}, expectedError: errNoNvidiaDevices},
{[]string{"195 nvidia-frontend"}, devices{"nvidia-frontend": 195}}, {lines: []string{"195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{[]string{"195 nvidia-frontend", "235 nvidia-caps"}, devices{"nvidia-frontend": 195, "nvidia-caps": 235}}, {lines: []string{"195 nvidia-frontend", "235 nvidia-caps"}, expected: devices{"nvidia-frontend": 195, "nvidia-caps": 235}},
{[]string{" 195 nvidia-frontend"}, devices{"nvidia-frontend": 195}}, {lines: []string{" 195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{[]string{"Not a valid line:", "", "195 nvidia-frontend"}, devices{"nvidia-frontend": 195}}, {lines: []string{"Not a valid line:", "", "195 nvidia-frontend"}, expected: devices{"nvidia-frontend": 195}},
{[]string{"195 not-nvidia-frontend"}, make(devices)}, {lines: []string{"195 not-nvidia-frontend"}, expectedError: errNoNvidiaDevices},
} }
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("testcase %d", i), func(t *testing.T) { t.Run(fmt.Sprintf("testcase %d", i), func(t *testing.T) {
contents := strings.NewReader(strings.Join(tc.lines, "\n")) contents := strings.NewReader(strings.Join(tc.lines, "\n"))
d := nvidiaDeviceFrom(contents) d, err := nvidiaDeviceFrom(contents)
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expected, d) require.EqualValues(t, tc.expected, d)
}) })

View File

@@ -13,29 +13,23 @@ var _ Locator = &LocatorMock{}
// LocatorMock is a mock implementation of Locator. // LocatorMock is a mock implementation of Locator.
// //
// func TestSomethingThatUsesLocator(t *testing.T) { // func TestSomethingThatUsesLocator(t *testing.T) {
// //
// // make and configure a mocked Locator // // make and configure a mocked Locator
// mockedLocator := &LocatorMock{ // mockedLocator := &LocatorMock{
// LocateFunc: func(s string) ([]string, error) { // LocateFunc: func(s string) ([]string, error) {
// panic("mock out the Locate method") // panic("mock out the Locate method")
// }, // },
// RelativeFunc: func(s string) (string, error) { // }
// panic("mock out the Relative method")
// },
// }
// //
// // use mockedLocator in code that requires Locator // // use mockedLocator in code that requires Locator
// // and then make assertions. // // and then make assertions.
// //
// } // }
type LocatorMock struct { type LocatorMock struct {
// LocateFunc mocks the Locate method. // LocateFunc mocks the Locate method.
LocateFunc func(s string) ([]string, error) LocateFunc func(s string) ([]string, error)
// RelativeFunc mocks the Relative method.
RelativeFunc func(s string) (string, error)
// calls tracks calls to the methods. // calls tracks calls to the methods.
calls struct { calls struct {
// Locate holds details about calls to the Locate method. // Locate holds details about calls to the Locate method.
@@ -43,14 +37,8 @@ type LocatorMock struct {
// S is the s argument value. // S is the s argument value.
S string S string
} }
// Relative holds details about calls to the Relative method.
Relative []struct {
// S is the s argument value.
S string
}
} }
lockLocate sync.RWMutex lockLocate sync.RWMutex
lockRelative sync.RWMutex
} }
// Locate calls LocateFunc. // Locate calls LocateFunc.
@@ -75,7 +63,8 @@ func (mock *LocatorMock) Locate(s string) ([]string, error) {
// LocateCalls gets all the calls that were made to Locate. // LocateCalls gets all the calls that were made to Locate.
// Check the length with: // Check the length with:
// len(mockedLocator.LocateCalls()) //
// len(mockedLocator.LocateCalls())
func (mock *LocatorMock) LocateCalls() []struct { func (mock *LocatorMock) LocateCalls() []struct {
S string S string
} { } {
@@ -87,38 +76,3 @@ func (mock *LocatorMock) LocateCalls() []struct {
mock.lockLocate.RUnlock() mock.lockLocate.RUnlock()
return calls return calls
} }
// Relative calls RelativeFunc.
func (mock *LocatorMock) Relative(s string) (string, error) {
callInfo := struct {
S string
}{
S: s,
}
mock.lockRelative.Lock()
mock.calls.Relative = append(mock.calls.Relative, callInfo)
mock.lockRelative.Unlock()
if mock.RelativeFunc == nil {
var (
sOut string
errOut error
)
return sOut, errOut
}
return mock.RelativeFunc(s)
}
// RelativeCalls gets all the calls that were made to Relative.
// Check the length with:
// len(mockedLocator.RelativeCalls())
func (mock *LocatorMock) RelativeCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockRelative.RLock()
calls = mock.calls.Relative
mock.lockRelative.RUnlock()
return calls
}

View File

@@ -17,10 +17,8 @@
package modifier package modifier
import ( import (
"fmt" "path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -28,8 +26,11 @@ import (
// NewStableRuntimeModifier creates an OCI spec modifier that inserts the NVIDIA Container Runtime Hook into an OCI // NewStableRuntimeModifier creates an OCI spec modifier that inserts the NVIDIA Container Runtime Hook into an OCI
// spec. The specified logger is used to capture log output. // spec. The specified logger is used to capture log output.
func NewStableRuntimeModifier(logger *logrus.Logger) oci.SpecModifier { func NewStableRuntimeModifier(logger *logrus.Logger, nvidiaContainerRuntimeHookPath string) oci.SpecModifier {
m := stableRuntimeModifier{logger: logger} m := stableRuntimeModifier{
logger: logger,
nvidiaContainerRuntimeHookPath: nvidiaContainerRuntimeHookPath,
}
return &m return &m
} }
@@ -37,7 +38,8 @@ func NewStableRuntimeModifier(logger *logrus.Logger) oci.SpecModifier {
// stableRuntimeModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a // stableRuntimeModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a
// prestart hook. If the hook is already present, no modification is made. // prestart hook. If the hook is already present, no modification is made.
type stableRuntimeModifier struct { type stableRuntimeModifier struct {
logger *logrus.Logger logger *logrus.Logger
nvidiaContainerRuntimeHookPath string
} }
// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook // Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook
@@ -53,18 +55,9 @@ func (m stableRuntimeModifier) Modify(spec *specs.Spec) error {
} }
} }
// We create a locator and look for the NVIDIA Container Runtime Hook in the path. path := m.nvidiaContainerRuntimeHookPath
candidates, err := lookup.NewExecutableLocator(m.logger, "").Locate(config.NVIDIAContainerRuntimeHookExecutable)
if err != nil {
return fmt.Errorf("failed to locate NVIDIA Container Runtime Hook: %v", err)
}
path := candidates[0]
if len(candidates) > 1 {
m.logger.Debugf("Using %v from multiple NVIDIA Container Runtime Hook candidates: %v", path, candidates)
}
m.logger.Infof("Using prestart hook path: %v", path) m.logger.Infof("Using prestart hook path: %v", path)
args := []string{path} args := []string{filepath.Base(path)}
if spec.Hooks == nil { if spec.Hooks == nil {
spec.Hooks = &specs.Hooks{} spec.Hooks = &specs.Hooks{}
} }

View File

@@ -79,7 +79,7 @@ func TestAddHookModifier(t *testing.T) {
Prestart: []specs.Hook{ Prestart: []specs.Hook{
{ {
Path: testHookPath, Path: testHookPath,
Args: []string{testHookPath, "prestart"}, Args: []string{"nvidia-container-runtime-hook", "prestart"},
}, },
}, },
}, },
@@ -95,7 +95,7 @@ func TestAddHookModifier(t *testing.T) {
Prestart: []specs.Hook{ Prestart: []specs.Hook{
{ {
Path: testHookPath, Path: testHookPath,
Args: []string{testHookPath, "prestart"}, Args: []string{"nvidia-container-runtime-hook", "prestart"},
}, },
}, },
}, },
@@ -141,7 +141,7 @@ func TestAddHookModifier(t *testing.T) {
}, },
{ {
Path: testHookPath, Path: testHookPath,
Args: []string{testHookPath, "prestart"}, Args: []string{"nvidia-container-runtime-hook", "prestart"},
}, },
}, },
}, },
@@ -154,7 +154,7 @@ func TestAddHookModifier(t *testing.T) {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
m := NewStableRuntimeModifier(logger) m := NewStableRuntimeModifier(logger, testHookPath)
err := m.Modify(&tc.spec) err := m.Modify(&tc.spec)
if tc.expectedError != nil { if tc.expectedError != nil {

View File

@@ -13,19 +13,19 @@ var _ Runtime = &RuntimeMock{}
// RuntimeMock is a mock implementation of Runtime. // RuntimeMock is a mock implementation of Runtime.
// //
// func TestSomethingThatUsesRuntime(t *testing.T) { // func TestSomethingThatUsesRuntime(t *testing.T) {
// //
// // make and configure a mocked Runtime // // make and configure a mocked Runtime
// mockedRuntime := &RuntimeMock{ // mockedRuntime := &RuntimeMock{
// ExecFunc: func(strings []string) error { // ExecFunc: func(strings []string) error {
// panic("mock out the Exec method") // panic("mock out the Exec method")
// }, // },
// } // }
// //
// // use mockedRuntime in code that requires Runtime // // use mockedRuntime in code that requires Runtime
// // and then make assertions. // // and then make assertions.
// //
// } // }
type RuntimeMock struct { type RuntimeMock struct {
// ExecFunc mocks the Exec method. // ExecFunc mocks the Exec method.
ExecFunc func(strings []string) error ExecFunc func(strings []string) error
@@ -62,7 +62,8 @@ func (mock *RuntimeMock) Exec(strings []string) error {
// ExecCalls gets all the calls that were made to Exec. // ExecCalls gets all the calls that were made to Exec.
// Check the length with: // Check the length with:
// len(mockedRuntime.ExecCalls()) //
// len(mockedRuntime.ExecCalls())
func (mock *RuntimeMock) ExecCalls() []struct { func (mock *RuntimeMock) ExecCalls() []struct {
Strings []string Strings []string
} { } {

View File

@@ -30,8 +30,9 @@ type SpecModifier interface {
Modify(*specs.Spec) error Modify(*specs.Spec) error
} }
//go:generate moq -stub -out spec_mock.go . Spec
// Spec defines the operations to be performed on an OCI specification // Spec defines the operations to be performed on an OCI specification
//
//go:generate moq -stub -out spec_mock.go . Spec
type Spec interface { type Spec interface {
Load() (*specs.Spec, error) Load() (*specs.Spec, error)
Flush() error Flush() error

View File

@@ -14,28 +14,28 @@ var _ Spec = &SpecMock{}
// SpecMock is a mock implementation of Spec. // SpecMock is a mock implementation of Spec.
// //
// func TestSomethingThatUsesSpec(t *testing.T) { // func TestSomethingThatUsesSpec(t *testing.T) {
// //
// // make and configure a mocked Spec // // make and configure a mocked Spec
// mockedSpec := &SpecMock{ // mockedSpec := &SpecMock{
// FlushFunc: func() error { // FlushFunc: func() error {
// panic("mock out the Flush method") // panic("mock out the Flush method")
// }, // },
// LoadFunc: func() (*specs.Spec, error) { // LoadFunc: func() (*specs.Spec, error) {
// panic("mock out the Load method") // panic("mock out the Load method")
// }, // },
// LookupEnvFunc: func(s string) (string, bool) { // LookupEnvFunc: func(s string) (string, bool) {
// panic("mock out the LookupEnv method") // panic("mock out the LookupEnv method")
// }, // },
// ModifyFunc: func(specModifier SpecModifier) error { // ModifyFunc: func(specModifier SpecModifier) error {
// panic("mock out the Modify method") // panic("mock out the Modify method")
// }, // },
// } // }
// //
// // use mockedSpec in code that requires Spec // // use mockedSpec in code that requires Spec
// // and then make assertions. // // and then make assertions.
// //
// } // }
type SpecMock struct { type SpecMock struct {
// FlushFunc mocks the Flush method. // FlushFunc mocks the Flush method.
FlushFunc func() error FlushFunc func() error
@@ -92,7 +92,8 @@ func (mock *SpecMock) Flush() error {
// FlushCalls gets all the calls that were made to Flush. // FlushCalls gets all the calls that were made to Flush.
// Check the length with: // Check the length with:
// len(mockedSpec.FlushCalls()) //
// len(mockedSpec.FlushCalls())
func (mock *SpecMock) FlushCalls() []struct { func (mock *SpecMock) FlushCalls() []struct {
} { } {
var calls []struct { var calls []struct {
@@ -122,7 +123,8 @@ func (mock *SpecMock) Load() (*specs.Spec, error) {
// LoadCalls gets all the calls that were made to Load. // LoadCalls gets all the calls that were made to Load.
// Check the length with: // Check the length with:
// len(mockedSpec.LoadCalls()) //
// len(mockedSpec.LoadCalls())
func (mock *SpecMock) LoadCalls() []struct { func (mock *SpecMock) LoadCalls() []struct {
} { } {
var calls []struct { var calls []struct {
@@ -155,7 +157,8 @@ func (mock *SpecMock) LookupEnv(s string) (string, bool) {
// LookupEnvCalls gets all the calls that were made to LookupEnv. // LookupEnvCalls gets all the calls that were made to LookupEnv.
// Check the length with: // Check the length with:
// len(mockedSpec.LookupEnvCalls()) //
// len(mockedSpec.LookupEnvCalls())
func (mock *SpecMock) LookupEnvCalls() []struct { func (mock *SpecMock) LookupEnvCalls() []struct {
S string S string
} { } {
@@ -189,7 +192,8 @@ func (mock *SpecMock) Modify(specModifier SpecModifier) error {
// ModifyCalls gets all the calls that were made to Modify. // ModifyCalls gets all the calls that were made to Modify.
// Check the length with: // Check the length with:
// len(mockedSpec.ModifyCalls()) //
// len(mockedSpec.ModifyCalls())
func (mock *SpecMock) ModifyCalls() []struct { func (mock *SpecMock) ModifyCalls() []struct {
SpecModifier SpecModifier SpecModifier SpecModifier
} { } {

View File

@@ -13,22 +13,22 @@ var _ Constraint = &ConstraintMock{}
// ConstraintMock is a mock implementation of Constraint. // ConstraintMock is a mock implementation of Constraint.
// //
// func TestSomethingThatUsesConstraint(t *testing.T) { // func TestSomethingThatUsesConstraint(t *testing.T) {
// //
// // make and configure a mocked Constraint // // make and configure a mocked Constraint
// mockedConstraint := &ConstraintMock{ // mockedConstraint := &ConstraintMock{
// AssertFunc: func() error { // AssertFunc: func() error {
// panic("mock out the Assert method") // panic("mock out the Assert method")
// }, // },
// StringFunc: func() string { // StringFunc: func() string {
// panic("mock out the String method") // panic("mock out the String method")
// }, // },
// } // }
// //
// // use mockedConstraint in code that requires Constraint // // use mockedConstraint in code that requires Constraint
// // and then make assertions. // // and then make assertions.
// //
// } // }
type ConstraintMock struct { type ConstraintMock struct {
// AssertFunc mocks the Assert method. // AssertFunc mocks the Assert method.
AssertFunc func() error AssertFunc func() error
@@ -67,7 +67,8 @@ func (mock *ConstraintMock) Assert() error {
// AssertCalls gets all the calls that were made to Assert. // AssertCalls gets all the calls that were made to Assert.
// Check the length with: // Check the length with:
// len(mockedConstraint.AssertCalls()) //
// len(mockedConstraint.AssertCalls())
func (mock *ConstraintMock) AssertCalls() []struct { func (mock *ConstraintMock) AssertCalls() []struct {
} { } {
var calls []struct { var calls []struct {
@@ -96,7 +97,8 @@ func (mock *ConstraintMock) String() string {
// StringCalls gets all the calls that were made to String. // StringCalls gets all the calls that were made to String.
// Check the length with: // Check the length with:
// len(mockedConstraint.StringCalls()) //
// len(mockedConstraint.StringCalls())
func (mock *ConstraintMock) StringCalls() []struct { func (mock *ConstraintMock) StringCalls() []struct {
} { } {
var calls []struct { var calls []struct {

View File

@@ -16,8 +16,9 @@
package constraints package constraints
//go:generate moq -stub -out constraint_mock.go . Constraint
// Constraint represents a constraint that is to be evaluated // Constraint represents a constraint that is to be evaluated
//
//go:generate moq -stub -out constraint_mock.go . Constraint
type Constraint interface { type Constraint interface {
String() string String() string
Assert() error Assert() error

View File

@@ -23,8 +23,9 @@ import (
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
) )
//go:generate moq -stub -out property_mock.go . Property
// Property represents a property that is used to check requirements // Property represents a property that is used to check requirements
//
//go:generate moq -stub -out property_mock.go . Property
type Property interface { type Property interface {
Name() string Name() string
Value() (string, error) Value() (string, error)

View File

@@ -13,31 +13,31 @@ var _ Property = &PropertyMock{}
// PropertyMock is a mock implementation of Property. // PropertyMock is a mock implementation of Property.
// //
// func TestSomethingThatUsesProperty(t *testing.T) { // func TestSomethingThatUsesProperty(t *testing.T) {
// //
// // make and configure a mocked Property // // make and configure a mocked Property
// mockedProperty := &PropertyMock{ // mockedProperty := &PropertyMock{
// CompareToFunc: func(s string) (int, error) { // CompareToFunc: func(s string) (int, error) {
// panic("mock out the CompareTo method") // panic("mock out the CompareTo method")
// }, // },
// NameFunc: func() string { // NameFunc: func() string {
// panic("mock out the Name method") // panic("mock out the Name method")
// }, // },
// StringFunc: func() string { // StringFunc: func() string {
// panic("mock out the String method") // panic("mock out the String method")
// }, // },
// ValidateFunc: func(s string) error { // ValidateFunc: func(s string) error {
// panic("mock out the Validate method") // panic("mock out the Validate method")
// }, // },
// ValueFunc: func() (string, error) { // ValueFunc: func() (string, error) {
// panic("mock out the Value method") // panic("mock out the Value method")
// }, // },
// } // }
// //
// // use mockedProperty in code that requires Property // // use mockedProperty in code that requires Property
// // and then make assertions. // // and then make assertions.
// //
// } // }
type PropertyMock struct { type PropertyMock struct {
// CompareToFunc mocks the CompareTo method. // CompareToFunc mocks the CompareTo method.
CompareToFunc func(s string) (int, error) CompareToFunc func(s string) (int, error)
@@ -105,7 +105,8 @@ func (mock *PropertyMock) CompareTo(s string) (int, error) {
// CompareToCalls gets all the calls that were made to CompareTo. // CompareToCalls gets all the calls that were made to CompareTo.
// Check the length with: // Check the length with:
// len(mockedProperty.CompareToCalls()) //
// len(mockedProperty.CompareToCalls())
func (mock *PropertyMock) CompareToCalls() []struct { func (mock *PropertyMock) CompareToCalls() []struct {
S string S string
} { } {
@@ -136,7 +137,8 @@ func (mock *PropertyMock) Name() string {
// NameCalls gets all the calls that were made to Name. // NameCalls gets all the calls that were made to Name.
// Check the length with: // Check the length with:
// len(mockedProperty.NameCalls()) //
// len(mockedProperty.NameCalls())
func (mock *PropertyMock) NameCalls() []struct { func (mock *PropertyMock) NameCalls() []struct {
} { } {
var calls []struct { var calls []struct {
@@ -165,7 +167,8 @@ func (mock *PropertyMock) String() string {
// StringCalls gets all the calls that were made to String. // StringCalls gets all the calls that were made to String.
// Check the length with: // Check the length with:
// len(mockedProperty.StringCalls()) //
// len(mockedProperty.StringCalls())
func (mock *PropertyMock) StringCalls() []struct { func (mock *PropertyMock) StringCalls() []struct {
} { } {
var calls []struct { var calls []struct {
@@ -197,7 +200,8 @@ func (mock *PropertyMock) Validate(s string) error {
// ValidateCalls gets all the calls that were made to Validate. // ValidateCalls gets all the calls that were made to Validate.
// Check the length with: // Check the length with:
// len(mockedProperty.ValidateCalls()) //
// len(mockedProperty.ValidateCalls())
func (mock *PropertyMock) ValidateCalls() []struct { func (mock *PropertyMock) ValidateCalls() []struct {
S string S string
} { } {
@@ -229,7 +233,8 @@ func (mock *PropertyMock) Value() (string, error) {
// ValueCalls gets all the calls that were made to Value. // ValueCalls gets all the calls that were made to Value.
// Check the length with: // Check the length with:
// len(mockedProperty.ValueCalls()) //
// len(mockedProperty.ValueCalls())
func (mock *PropertyMock) ValueCalls() []struct { func (mock *PropertyMock) ValueCalls() []struct {
} { } {
var calls []struct { var calls []struct {

View File

@@ -17,6 +17,7 @@
package runtime package runtime
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -43,7 +44,7 @@ func NewLogger() *Logger {
} }
// Update constructs a Logger with a preddefined formatter // Update constructs a Logger with a preddefined formatter
func (l *Logger) Update(filename string, logLevel string, argv []string) error { func (l *Logger) Update(filename string, logLevel string, argv []string) {
configFromArgs := parseArgs(argv) configFromArgs := parseArgs(argv)
@@ -61,7 +62,7 @@ func (l *Logger) Update(filename string, logLevel string, argv []string) error {
if !configFromArgs.version { if !configFromArgs.version {
configLogFile, err := createLogFile(filename) configLogFile, err := createLogFile(filename)
if err != nil { if err != nil {
return fmt.Errorf("error opening debug log file: %v", err) argLogFileError = errors.Join(argLogFileError, err)
} }
if configLogFile != nil { if configLogFile != nil {
logFiles = append(logFiles, configLogFile) logFiles = append(logFiles, configLogFile)
@@ -71,7 +72,7 @@ func (l *Logger) Update(filename string, logLevel string, argv []string) error {
if argLogFile != nil { if argLogFile != nil {
logFiles = append(logFiles, argLogFile) logFiles = append(logFiles, argLogFile)
} }
argLogFileError = err argLogFileError = errors.Join(argLogFileError, err)
} }
defer func() { defer func() {
if argLogFileError != nil { if argLogFileError != nil {
@@ -119,8 +120,6 @@ func (l *Logger) Update(filename string, logLevel string, argv []string) error {
previousLogger: l.Logger, previousLogger: l.Logger,
logFiles: logFiles, logFiles: logFiles,
} }
return nil
} }
// Reset closes the log file (if any) and resets the logger output to what it // Reset closes the log file (if any) and resets the logger output to what it
@@ -157,11 +156,16 @@ func (l *Logger) Reset() error {
} }
func createLogFile(filename string) (*os.File, error) { func createLogFile(filename string) (*os.File, error) {
if filename != "" && filename != os.DevNull { if filename == "" || filename == os.DevNull {
return os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) return nil, nil
} }
if dir := filepath.Dir(filepath.Clean(filename)); dir != "." {
return nil, nil err := os.MkdirAll(dir, 0755)
if err != nil {
return nil, err
}
}
return os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
} }
type loggerConfig struct { type loggerConfig struct {

View File

@@ -44,18 +44,11 @@ func (r rt) Run(argv []string) (rerr error) {
if err != nil { if err != nil {
return fmt.Errorf("error loading config: %v", err) return fmt.Errorf("error loading config: %v", err)
} }
if r.modeOverride != "" { r.logger.Update(
cfg.NVIDIAContainerRuntimeConfig.Mode = r.modeOverride
}
err = r.logger.Update(
cfg.NVIDIAContainerRuntimeConfig.DebugFilePath, cfg.NVIDIAContainerRuntimeConfig.DebugFilePath,
cfg.NVIDIAContainerRuntimeConfig.LogLevel, cfg.NVIDIAContainerRuntimeConfig.LogLevel,
argv, argv,
) )
if err != nil {
return fmt.Errorf("failed to set up logger: %v", err)
}
defer func() { defer func() {
if rerr != nil { if rerr != nil {
r.logger.Errorf("%v", rerr) r.logger.Errorf("%v", rerr)
@@ -63,6 +56,13 @@ func (r rt) Run(argv []string) (rerr error) {
r.logger.Reset() r.logger.Reset()
}() }()
// We apply some config updates here to ensure that the config is valid in
// all cases.
if r.modeOverride != "" {
cfg.NVIDIAContainerRuntimeConfig.Mode = r.modeOverride
}
cfg.NVIDIAContainerRuntimeHookConfig.Path = config.ResolveNVIDIAContainerRuntimeHookPath(r.logger.Logger, cfg.NVIDIAContainerRuntimeHookConfig.Path)
// Print the config to the output. // Print the config to the output.
configJSON, err := json.MarshalIndent(cfg, "", " ") configJSON, err := json.MarshalIndent(cfg, "", " ")
if err == nil { if err == nil {

View File

@@ -61,10 +61,15 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) { func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) {
modeModifier, err := newModeModifier(logger, cfg, ociSpec, argv) mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode)
modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, argv)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// For CDI mode we make no additional modifications.
if mode == "cdi" {
return modeModifier, nil
}
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, ociSpec) graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, ociSpec)
if err != nil { if err != nil {
@@ -96,10 +101,10 @@ func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec
return modifiers, nil return modifiers, nil
} }
func newModeModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) { func newModeModifier(logger *logrus.Logger, mode string, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) {
switch info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode) { switch mode {
case "legacy": case "legacy":
return modifier.NewStableRuntimeModifier(logger), nil return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil
case "csv": case "csv":
return modifier.NewCSVModifier(logger, cfg, ociSpec) return modifier.NewCSVModifier(logger, cfg, ociSpec)
case "cdi": case "cdi":

View File

@@ -0,0 +1,154 @@
/**
# Copyright (c) NVIDIA CORPORATIOm. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvdevices
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
"github.com/sirupsen/logrus"
)
var errInvalidDeviceNode = errors.New("invalid device node")
// Interface provides a set of utilities for interacting with NVIDIA devices on the system.
type Interface struct {
devices.Devices
logger *logrus.Logger
dryRun bool
// devRoot is the root directory where device nodes are expected to exist.
devRoot string
mknoder
}
// New constructs a new Interface struct with the specified options.
func New(opts ...Option) (*Interface, error) {
i := &Interface{}
for _, opt := range opts {
opt(i)
}
if i.logger == nil {
i.logger = logrus.StandardLogger()
}
if i.devRoot == "" {
i.devRoot = "/"
}
if i.Devices == nil {
devices, err := devices.GetNVIDIADevices()
if err != nil {
return nil, fmt.Errorf("failed to create devices info: %v", err)
}
i.Devices = devices
}
if i.dryRun {
i.mknoder = &mknodLogger{i.logger}
} else {
i.mknoder = &mknodUnix{}
}
return i, nil
}
// CreateNVIDIAControlDevices creates the NVIDIA control device nodes at the configured devRoot.
func (m *Interface) CreateNVIDIAControlDevices() error {
controlNodes := []string{"nvidiactl", "nvidia-modeset", "nvidia-uvm", "nvidia-uvm-tools"}
for _, node := range controlNodes {
err := m.CreateNVIDIADevice(node)
if err != nil {
return fmt.Errorf("failed to create device node %s: %w", node, err)
}
}
return nil
}
// CreateNVIDIADevice creates the specified NVIDIA device node at the configured devRoot.
func (m *Interface) CreateNVIDIADevice(node string) error {
node = filepath.Base(node)
if !strings.HasPrefix(node, "nvidia") {
return fmt.Errorf("invalid device node %q: %w", node, errInvalidDeviceNode)
}
major, err := m.Major(node)
if err != nil {
return fmt.Errorf("failed to determine major: %w", err)
}
minor, err := m.Minor(node)
if err != nil {
return fmt.Errorf("failed to determine minor: %w", err)
}
return m.createDeviceNode(filepath.Join("dev", node), int(major), int(minor))
}
// createDeviceNode creates the specified device node with the require major and minor numbers.
// If a devRoot is configured, this is prepended to the path.
func (m *Interface) createDeviceNode(path string, major int, minor int) error {
path = filepath.Join(m.devRoot, path)
if _, err := os.Stat(path); err == nil {
m.logger.Infof("Skipping: %s already exists", path)
return nil
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to stat %s: %v", path, err)
}
return m.Mknode(path, major, minor)
}
// Major returns the major number for the specified NVIDIA device node.
// If the device node is not supported, an error is returned.
func (m *Interface) Major(node string) (int64, error) {
var valid bool
var major devices.Major
switch node {
case "nvidia-uvm", "nvidia-uvm-tools":
major, valid = m.Get(devices.NVIDIAUVM)
case "nvidia-modeset", "nvidiactl":
major, valid = m.Get(devices.NVIDIAGPU)
}
if valid {
return int64(major), nil
}
return 0, errInvalidDeviceNode
}
// Minor returns the minor number for the specified NVIDIA device node.
// If the device node is not supported, an error is returned.
func (m *Interface) Minor(node string) (int64, error) {
switch node {
case "nvidia-modeset":
return devices.NVIDIAModesetMinor, nil
case "nvidia-uvm-tools":
return devices.NVIDIAUVMToolsMinor, nil
case "nvidia-uvm":
return devices.NVIDIAUVMMinor, nil
case "nvidiactl":
return devices.NVIDIACTLMinor, nil
}
return 0, errInvalidDeviceNode
}

View File

@@ -0,0 +1,133 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvdevices
import (
"errors"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestCreateControlDevices(t *testing.T) {
logger, _ := testlog.NewNullLogger()
nvidiaDevices := &devices.DevicesMock{
GetFunc: func(name devices.Name) (devices.Major, bool) {
devices := map[devices.Name]devices.Major{
"nvidia-frontend": 195,
"nvidia-uvm": 243,
}
return devices[name], true
},
}
mknodeError := errors.New("mknode error")
testCases := []struct {
description string
root string
devices devices.Devices
mknodeError error
expectedError error
expectedCalls []struct {
S string
N1 int
N2 int
}
}{
{
description: "no root specified",
root: "",
devices: nvidiaDevices,
mknodeError: nil,
expectedCalls: []struct {
S string
N1 int
N2 int
}{
{"/dev/nvidiactl", 195, 255},
{"/dev/nvidia-modeset", 195, 254},
{"/dev/nvidia-uvm", 243, 0},
{"/dev/nvidia-uvm-tools", 243, 1},
},
},
{
description: "some root specified",
root: "/some/root",
devices: nvidiaDevices,
mknodeError: nil,
expectedCalls: []struct {
S string
N1 int
N2 int
}{
{"/some/root/dev/nvidiactl", 195, 255},
{"/some/root/dev/nvidia-modeset", 195, 254},
{"/some/root/dev/nvidia-uvm", 243, 0},
{"/some/root/dev/nvidia-uvm-tools", 243, 1},
},
},
{
description: "mknod error returns error",
devices: nvidiaDevices,
mknodeError: mknodeError,
expectedError: mknodeError,
// We expect the first call to this to fail, and the rest to be skipped
expectedCalls: []struct {
S string
N1 int
N2 int
}{
{"/dev/nvidiactl", 195, 255},
},
},
{
description: "missing major returns error",
devices: &devices.DevicesMock{
GetFunc: func(name devices.Name) (devices.Major, bool) {
return 0, false
},
},
expectedError: errInvalidDeviceNode,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
mknode := &mknoderMock{
MknodeFunc: func(string, int, int) error {
return tc.mknodeError
},
}
d, _ := New(
WithLogger(logger),
WithDevRoot(tc.root),
WithDevices(tc.devices),
)
d.mknoder = mknode
err := d.CreateNVIDIAControlDevices()
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expectedCalls, mknode.MknodeCalls())
})
}
}

View File

@@ -0,0 +1,46 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvdevices
import (
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
//go:generate moq -stub -out mknod_mock.go . mknoder
type mknoder interface {
Mknode(string, int, int) error
}
type mknodLogger struct {
*logrus.Logger
}
func (m *mknodLogger) Mknode(path string, major, minor int) error {
m.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor)
return nil
}
type mknodUnix struct{}
func (m *mknodUnix) Mknode(path string, major, minor int) error {
err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(uint32(major), uint32(minor))))
if err != nil {
return err
}
return unix.Chmod(path, 0666)
}

View File

@@ -0,0 +1,89 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package nvdevices
import (
"sync"
)
// Ensure, that mknoderMock does implement mknoder.
// If this is not the case, regenerate this file with moq.
var _ mknoder = &mknoderMock{}
// mknoderMock is a mock implementation of mknoder.
//
// func TestSomethingThatUsesmknoder(t *testing.T) {
//
// // make and configure a mocked mknoder
// mockedmknoder := &mknoderMock{
// MknodeFunc: func(s string, n1 int, n2 int) error {
// panic("mock out the Mknode method")
// },
// }
//
// // use mockedmknoder in code that requires mknoder
// // and then make assertions.
//
// }
type mknoderMock struct {
// MknodeFunc mocks the Mknode method.
MknodeFunc func(s string, n1 int, n2 int) error
// calls tracks calls to the methods.
calls struct {
// Mknode holds details about calls to the Mknode method.
Mknode []struct {
// S is the s argument value.
S string
// N1 is the n1 argument value.
N1 int
// N2 is the n2 argument value.
N2 int
}
}
lockMknode sync.RWMutex
}
// Mknode calls MknodeFunc.
func (mock *mknoderMock) Mknode(s string, n1 int, n2 int) error {
callInfo := struct {
S string
N1 int
N2 int
}{
S: s,
N1: n1,
N2: n2,
}
mock.lockMknode.Lock()
mock.calls.Mknode = append(mock.calls.Mknode, callInfo)
mock.lockMknode.Unlock()
if mock.MknodeFunc == nil {
var (
errOut error
)
return errOut
}
return mock.MknodeFunc(s, n1, n2)
}
// MknodeCalls gets all the calls that were made to Mknode.
// Check the length with:
//
// len(mockedmknoder.MknodeCalls())
func (mock *mknoderMock) MknodeCalls() []struct {
S string
N1 int
N2 int
} {
var calls []struct {
S string
N1 int
N2 int
}
mock.lockMknode.RLock()
calls = mock.calls.Mknode
mock.lockMknode.RUnlock()
return calls
}

View File

@@ -0,0 +1,53 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvdevices
import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
"github.com/sirupsen/logrus"
)
// Option is a function that sets an option on the Interface struct.
type Option func(*Interface)
// WithDryRun sets the dry run option for the Interface struct.
func WithDryRun(dryRun bool) Option {
return func(i *Interface) {
i.dryRun = dryRun
}
}
// WithLogger sets the logger for the Interface struct.
func WithLogger(logger *logrus.Logger) Option {
return func(i *Interface) {
i.logger = logger
}
}
// WithDevRoot sets the root directory for the NVIDIA device nodes.
func WithDevRoot(devRoot string) Option {
return func(i *Interface) {
i.devRoot = devRoot
}
}
// WithDevices sets the devices for the Interface struct.
func WithDevices(devices devices.Devices) Option {
return func(i *Interface) {
i.Devices = devices
}
}

View File

@@ -0,0 +1,49 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvmodules
import (
"fmt"
"os/exec"
"strings"
"github.com/sirupsen/logrus"
)
//go:generate moq -stub -out cmd_mock.go . cmder
type cmder interface {
// Run executes the command and returns the stdout, stderr, and an error if any
Run(string, ...string) error
}
type cmderLogger struct {
*logrus.Logger
}
func (c *cmderLogger) Run(cmd string, args ...string) error {
c.Infof("Running: %v %v", cmd, strings.Join(args, " "))
return nil
}
type cmderExec struct{}
func (c *cmderExec) Run(cmd string, args ...string) error {
if output, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
return fmt.Errorf("%w; output=%v", err, string(output))
}
return nil
}

View File

@@ -0,0 +1,83 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package nvmodules
import (
"sync"
)
// Ensure, that cmderMock does implement cmder.
// If this is not the case, regenerate this file with moq.
var _ cmder = &cmderMock{}
// cmderMock is a mock implementation of cmder.
//
// func TestSomethingThatUsescmder(t *testing.T) {
//
// // make and configure a mocked cmder
// mockedcmder := &cmderMock{
// RunFunc: func(s string, strings ...string) error {
// panic("mock out the Run method")
// },
// }
//
// // use mockedcmder in code that requires cmder
// // and then make assertions.
//
// }
type cmderMock struct {
// RunFunc mocks the Run method.
RunFunc func(s string, strings ...string) error
// calls tracks calls to the methods.
calls struct {
// Run holds details about calls to the Run method.
Run []struct {
// S is the s argument value.
S string
// Strings is the strings argument value.
Strings []string
}
}
lockRun sync.RWMutex
}
// Run calls RunFunc.
func (mock *cmderMock) Run(s string, strings ...string) error {
callInfo := struct {
S string
Strings []string
}{
S: s,
Strings: strings,
}
mock.lockRun.Lock()
mock.calls.Run = append(mock.calls.Run, callInfo)
mock.lockRun.Unlock()
if mock.RunFunc == nil {
var (
errOut error
)
return errOut
}
return mock.RunFunc(s, strings...)
}
// RunCalls gets all the calls that were made to Run.
// Check the length with:
//
// len(mockedcmder.RunCalls())
func (mock *cmderMock) RunCalls() []struct {
S string
Strings []string
} {
var calls []struct {
S string
Strings []string
}
mock.lockRun.RLock()
calls = mock.calls.Run
mock.lockRun.RUnlock()
return calls
}

View File

@@ -0,0 +1,93 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvmodules
import (
"fmt"
"strings"
"github.com/sirupsen/logrus"
)
// Interface provides a set of utilities for interacting with NVIDIA modules on the system.
type Interface struct {
logger *logrus.Logger
dryRun bool
root string
cmder
}
// New constructs a new Interface struct with the specified options.
func New(opts ...Option) *Interface {
m := &Interface{}
for _, opt := range opts {
opt(m)
}
if m.logger == nil {
m.logger = logrus.StandardLogger()
}
if m.root == "" {
m.root = "/"
}
if m.dryRun {
m.cmder = &cmderLogger{m.logger}
} else {
m.cmder = &cmderExec{}
}
return m
}
// LoadAll loads all the NVIDIA kernel modules.
func (m *Interface) LoadAll() error {
modules := []string{"nvidia", "nvidia-uvm", "nvidia-modeset"}
for _, module := range modules {
if err := m.Load(module); err != nil {
return fmt.Errorf("failed to load module %s: %w", module, err)
}
}
return nil
}
var errInvalidModule = fmt.Errorf("invalid module")
// Load loads the specified NVIDIA kernel module.
// If the root is specified we first chroot into this root.
func (m *Interface) Load(module string) error {
if !strings.HasPrefix(module, "nvidia") {
return errInvalidModule
}
var args []string
if m.root != "/" {
args = append(args, "chroot", m.root)
}
args = append(args, "/sbin/modprobe", module)
m.logger.Debugf("Loading kernel module %s: %v", module, args)
err := m.Run(args[0], args[1:]...)
if err != nil {
m.logger.Debugf("Failed to load kernel module %s: %v", module, err)
return err
}
return nil
}

View File

@@ -0,0 +1,178 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvmodules
import (
"errors"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestLoadAll(t *testing.T) {
logger, _ := testlog.NewNullLogger()
runError := errors.New("run error")
testCases := []struct {
description string
root string
runError error
expectedError error
expectedCalls []struct {
S string
Strings []string
}
}{
{
description: "no root specified",
root: "",
expectedCalls: []struct {
S string
Strings []string
}{
{"/sbin/modprobe", []string{"nvidia"}},
{"/sbin/modprobe", []string{"nvidia-uvm"}},
{"/sbin/modprobe", []string{"nvidia-modeset"}},
},
},
{
description: "root causes chroot",
root: "/some/root",
expectedCalls: []struct {
S string
Strings []string
}{
{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia"}},
{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia-uvm"}},
{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia-modeset"}},
},
},
{
description: "run failure is returned",
root: "",
runError: runError,
expectedError: runError,
expectedCalls: []struct {
S string
Strings []string
}{
{"/sbin/modprobe", []string{"nvidia"}},
},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
cmder := &cmderMock{
RunFunc: func(cmd string, args ...string) error {
return tc.runError
},
}
m := New(
WithLogger(logger),
WithRoot(tc.root),
)
m.cmder = cmder
err := m.LoadAll()
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expectedCalls, cmder.RunCalls())
})
}
}
func TestLoad(t *testing.T) {
logger, _ := testlog.NewNullLogger()
runError := errors.New("run error")
testCases := []struct {
description string
root string
module string
runError error
expectedError error
expectedCalls []struct {
S string
Strings []string
}
}{
{
description: "no root specified",
root: "",
module: "nvidia",
expectedCalls: []struct {
S string
Strings []string
}{
{"/sbin/modprobe", []string{"nvidia"}},
},
},
{
description: "root causes chroot",
root: "/some/root",
module: "nvidia",
expectedCalls: []struct {
S string
Strings []string
}{
{"chroot", []string{"/some/root", "/sbin/modprobe", "nvidia"}},
},
},
{
description: "run failure is returned",
root: "",
module: "nvidia",
runError: runError,
expectedError: runError,
expectedCalls: []struct {
S string
Strings []string
}{
{"/sbin/modprobe", []string{"nvidia"}},
},
},
{
description: "module prefis is checked",
module: "not-nvidia",
expectedError: errInvalidModule,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
cmder := &cmderMock{
RunFunc: func(cmd string, args ...string) error {
return tc.runError
},
}
m := New(
WithLogger(logger),
WithRoot(tc.root),
)
m.cmder = cmder
err := m.Load(tc.module)
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expectedCalls, cmder.RunCalls())
})
}
}

View File

@@ -14,23 +14,30 @@
# limitations under the License. # limitations under the License.
**/ **/
package system package nvmodules
import "github.com/sirupsen/logrus" import "github.com/sirupsen/logrus"
// Option is a functional option for the system command // Option is a function that sets an option on the Interface struct.
type Option func(*Interface) type Option func(*Interface)
// WithLogger sets the logger for the system command // WithDryRun sets the dry run option for the Interface struct.
func WithDryRun(dryRun bool) Option {
return func(i *Interface) {
i.dryRun = dryRun
}
}
// WithLogger sets the logger for the Interface struct.
func WithLogger(logger *logrus.Logger) Option { func WithLogger(logger *logrus.Logger) Option {
return func(i *Interface) { return func(i *Interface) {
i.logger = logger i.logger = logger
} }
} }
// WithDryRun sets the dry run flag // WithRoot sets the root directory for the NVIDIA device nodes.
func WithDryRun(dryRun bool) Option { func WithRoot(root string) Option {
return func(i *Interface) { return func(i *Interface) {
i.dryRun = dryRun i.root = root
} }
} }

View File

@@ -1,149 +0,0 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package system
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// Interface is the interface for the system command
type Interface struct {
logger *logrus.Logger
dryRun bool
nvidiaDevices nvidiaDevices
}
// New constructs a system command with the specified options
func New(opts ...Option) (*Interface, error) {
i := &Interface{
logger: logrus.StandardLogger(),
}
for _, opt := range opts {
opt(i)
}
devices, err := devices.GetNVIDIADevices()
if err != nil {
return nil, fmt.Errorf("failed to create devices info: %v", err)
}
i.nvidiaDevices = nvidiaDevices{devices}
return i, nil
}
// CreateNVIDIAControlDeviceNodesAt creates the NVIDIA control device nodes associated with the NVIDIA driver at the specified root.
func (m *Interface) CreateNVIDIAControlDeviceNodesAt(root string) error {
controlNodes := []string{"/dev/nvidiactl", "/dev/nvidia-modeset", "/dev/nvidia-uvm", "/dev/nvidia-uvm-tools"}
for _, node := range controlNodes {
path := filepath.Join(root, node)
err := m.CreateNVIDIADeviceNode(path)
if err != nil {
return fmt.Errorf("failed to create device node %s: %v", path, err)
}
}
return nil
}
// CreateNVIDIADeviceNode creates a specified device node associated with the NVIDIA driver.
func (m *Interface) CreateNVIDIADeviceNode(path string) error {
node := filepath.Base(path)
if !strings.HasPrefix(node, "nvidia") {
return fmt.Errorf("invalid device node %q", node)
}
major, err := m.nvidiaDevices.Major(node)
if err != nil {
return fmt.Errorf("failed to determine major: %v", err)
}
minor, err := m.nvidiaDevices.Minor(node)
if err != nil {
return fmt.Errorf("failed to determine minor: %v", err)
}
return m.createDeviceNode(path, int(major), int(minor))
}
func (m *Interface) createDeviceNode(path string, major int, minor int) error {
if m.dryRun {
m.logger.Infof("Running: mknod --mode=0666 %s c %d %d", path, major, minor)
return nil
}
if _, err := os.Stat(path); err == nil {
m.logger.Infof("Skipping: %s already exists", path)
return nil
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to stat %s: %v", path, err)
}
err := unix.Mknod(path, unix.S_IFCHR, int(unix.Mkdev(uint32(major), uint32(minor))))
if err != nil {
return err
}
return unix.Chmod(path, 0666)
}
type nvidiaDevices struct {
devices.Devices
}
// Major returns the major number for the specified NVIDIA device node.
// If the device node is not supported, an error is returned.
func (n *nvidiaDevices) Major(node string) (int64, error) {
var valid bool
var major devices.Major
switch node {
case "nvidia-uvm", "nvidia-uvm-tools":
major, valid = n.Get(devices.NVIDIAUVM)
case "nvidia-modeset", "nvidiactl":
major, valid = n.Get(devices.NVIDIAGPU)
}
if !valid {
return 0, fmt.Errorf("invalid device node %q", node)
}
return int64(major), nil
}
// Minor returns the minor number for the specified NVIDIA device node.
// If the device node is not supported, an error is returned.
func (n *nvidiaDevices) Minor(node string) (int64, error) {
switch node {
case "nvidia-modeset":
return devices.NVIDIAModesetMinor, nil
case "nvidia-uvm-tools":
return devices.NVIDIAUVMToolsMinor, nil
case "nvidia-uvm":
return devices.NVIDIAUVMMinor, nil
case "nvidiactl":
return devices.NVIDIACTLMinor, nil
}
return 0, fmt.Errorf("invalid device node %q", node)
}

View File

@@ -67,7 +67,6 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi
if len(searchPaths) > 1 { if len(searchPaths) > 1 {
logger.Warnf("Found multiple driver store paths: %v", searchPaths) logger.Warnf("Found multiple driver store paths: %v", searchPaths)
} }
driverStorePath := searchPaths[0]
searchPaths = append(searchPaths, "/usr/lib/wsl/lib") searchPaths = append(searchPaths, "/usr/lib/wsl/lib")
libraries := discover.NewMounts( libraries := discover.NewMounts(
@@ -83,12 +82,11 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi
requiredDriverStoreFiles, requiredDriverStoreFiles,
) )
// On WSL2 the driver store location is used unchanged. symlinkHook := nvidiaSMISimlinkHook{
// For this reason we need to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the driver store. logger: logger,
target := filepath.Join(driverStorePath, "nvidia-smi") mountsFrom: libraries,
link := "/usr/bin/nvidia-smi" nvidiaCTKPath: nvidiaCTKPath,
links := []string{fmt.Sprintf("%s::%s", target, link)} }
symlinkHook := discover.CreateCreateSymlinkHook(nvidiaCTKPath, links)
cfg := &discover.Config{ cfg := &discover.Config{
DriverRoot: driverRoot, DriverRoot: driverRoot,
@@ -104,3 +102,39 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi
return d, nil return d, nil
} }
type nvidiaSMISimlinkHook struct {
discover.None
logger *logrus.Logger
mountsFrom discover.Discover
nvidiaCTKPath string
}
// Hooks returns a hook that creates a symlink to nvidia-smi in the driver store.
// On WSL2 the driver store location is used unchanged, for this reason we need
// to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the
// driver store.
func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) {
mounts, err := m.mountsFrom.Mounts()
if err != nil {
return nil, fmt.Errorf("failed to discover mounts: %w", err)
}
var target string
for _, mount := range mounts {
if filepath.Base(mount.Path) == "nvidia-smi" {
target = mount.Path
break
}
}
if target == "" {
m.logger.Warningf("Failed to find nvidia-smi in mounts: %v", mounts)
return nil, nil
}
link := "/usr/bin/nvidia-smi"
links := []string{fmt.Sprintf("%s::%s", target, link)}
symlinkHook := discover.CreateCreateSymlinkHook(m.nvidiaCTKPath, links)
return symlinkHook.Hooks()
}

View File

@@ -0,0 +1,163 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvcdi
import (
"errors"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/stretchr/testify/require"
testlog "github.com/sirupsen/logrus/hooks/test"
)
func TestNvidiaSMISymlinkHook(t *testing.T) {
logger, _ := testlog.NewNullLogger()
errMounts := errors.New("mounts error")
testCases := []struct {
description string
mounts discover.Discover
expectedError error
expectedHooks []discover.Hook
}{
{
description: "mounts error is returned",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
return nil, errMounts
},
},
expectedError: errMounts,
},
{
description: "no mounts returns no hooks",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
return nil, nil
},
},
},
{
description: "no nvidia-smi returns no hooks",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/not-nvidia-smi"},
{Path: "/also-not-nvidia-smi"},
}
return mounts, nil
},
},
},
{
description: "nvidia-smi must be in path",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/not-nvidia-smi", HostPath: "nvidia-smi"},
{Path: "/also-not-nvidia-smi", HostPath: "not-nvidia-smi"},
}
return mounts, nil
},
},
},
{
description: "nvidia-smi returns hook",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "nvidia-smi"},
}
return mounts, nil
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
"--link", "nvidia-smi::/usr/bin/nvidia-smi"},
},
},
},
{
description: "checks basename",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/some/path/nvidia-smi"},
{Path: "/nvidia-smi/but-not"},
}
return mounts, nil
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
},
},
},
{
description: "returns first match",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/some/path/nvidia-smi"},
{Path: "/another/path/nvidia-smi"},
}
return mounts, nil
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
m := nvidiaSMISimlinkHook{
logger: logger,
mountsFrom: tc.mounts,
nvidiaCTKPath: "nvidia-ctk",
}
devices, err := m.Devices()
require.NoError(t, err)
require.Empty(t, devices)
mounts, err := m.Mounts()
require.NoError(t, err)
require.Empty(t, mounts)
hooks, err := m.Hooks()
require.ErrorIs(t, err, tc.expectedError)
require.Equal(t, tc.expectedHooks, hooks)
})
}
}

View File

@@ -88,7 +88,7 @@ func (m *managementlib) getCudaVersion() (string, error) {
libCudaPaths, err := cuda.New( libCudaPaths, err := cuda.New(
cuda.WithLogger(m.logger), cuda.WithLogger(m.logger),
cuda.WithDriverRoot(m.driverRoot), cuda.WithDriverRoot(m.driverRoot),
).Locate(".*.*.*") ).Locate(".*.*")
if err != nil { if err != nil {
return "", fmt.Errorf("failed to locate libcuda.so: %v", err) return "", fmt.Errorf("failed to locate libcuda.so: %v", err)
} }

View File

@@ -18,6 +18,7 @@ package spec
import ( import (
"fmt" "fmt"
"os"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
@@ -33,6 +34,7 @@ type builder struct {
edits specs.ContainerEdits edits specs.ContainerEdits
format string format string
noSimplify bool noSimplify bool
permissions os.FileMode
} }
// newBuilder creates a new spec builder with the supplied options // newBuilder creates a new spec builder with the supplied options
@@ -60,7 +62,9 @@ func newBuilder(opts ...Option) *builder {
if s.format == "" { if s.format == "" {
s.format = FormatYAML s.format = FormatYAML
} }
if s.permissions == 0 {
s.permissions = 0600
}
return s return s
} }
@@ -92,8 +96,9 @@ func (o *builder) Build() (*spec, error) {
} }
s := spec{ s := spec{
Spec: raw, Spec: raw,
format: o.format, format: o.format,
permissions: o.permissions,
} }
return &s, nil return &s, nil
@@ -157,3 +162,10 @@ func WithRawSpec(raw *specs.Spec) Option {
o.raw = raw o.raw = raw
} }
} }
// WithPermissions sets the permissions for the generated spec file
func WithPermissions(permissions os.FileMode) Option {
return func(o *builder) {
o.permissions = permissions
}
}

View File

@@ -28,7 +28,8 @@ import (
type spec struct { type spec struct {
*specs.Spec *specs.Spec
format string format string
permissions os.FileMode
} }
var _ Interface = (*spec)(nil) var _ Interface = (*spec)(nil)
@@ -51,7 +52,15 @@ func (s *spec) Save(path string) error {
cdi.WithSpecDirs(specDir), cdi.WithSpecDirs(specDir),
) )
return registry.SpecDB().WriteSpec(s.Raw(), filepath.Base(path)) if err := registry.SpecDB().WriteSpec(s.Raw(), filepath.Base(path)); err != nil {
return fmt.Errorf("failed to write spec: %w", err)
}
if err := os.Chmod(path, s.permissions); err != nil {
return fmt.Errorf("failed to set permissions on spec file: %w", err)
}
return nil
} }
// WriteTo writes the spec to the specified writer. // WriteTo writes the spec to the specified writer.

View File

@@ -46,14 +46,6 @@ else
targets=${all[@]} targets=${all[@]}
fi fi
echo "Updating components"
"${SCRIPTS_DIR}/update-components.sh"
if [[ -n $(git status -s third_party) && ${ALLOW_LOCAL_COMPONENT_CHANGES} != "true" ]]; then
echo "ERROR: Building with local component changes."
echo "Commit pending changes or rerun with ALLOW_LOCAL_COMPONENT_CHANGES='true'"
exit 1
fi
eval $(${SCRIPTS_DIR}/get-component-versions.sh) eval $(${SCRIPTS_DIR}/get-component-versions.sh)

View File

@@ -94,19 +94,6 @@ function extract-all() {
local dist=$1 local dist=$1
echo "Extracting packages for ${dist} from ${PACKAGE_IMAGE}" echo "Extracting packages for ${dist} from ${PACKAGE_IMAGE}"
if [ $dist == "ubuntu18.04" ]; then
set -x
# We need to publish the libnvidia-container0 packages to the kitmaker repository as a once off operation.
# We include the packages here so that these will be added to the archive for the ubuntu18.04 arm64 packages.
mkdir -p "${ARTIFACTS_DIR}/packages/ubuntu18.04/arm64/"
curl -L "https://nvidia.github.io/libnvidia-container/ubuntu18.04/arm64/libnvidia-container0_0.10.0+jetpack_arm64.deb" \
--output "${ARTIFACTS_DIR}/packages/ubuntu18.04/arm64/libnvidia-container0_0.10.0+jetpack_arm64.deb"
curl -L "https://nvidia.github.io/libnvidia-container/ubuntu18.04/arm64/libnvidia-container0_0.11.0+jetpack_arm64.deb" \
--output "${ARTIFACTS_DIR}/packages/ubuntu18.04/arm64/libnvidia-container0_0.11.0+jetpack_arm64.deb"
set +x
fi
# Extract every file for the specified dist-arch combiniation in MANIFEST.txt # Extract every file for the specified dist-arch combiniation in MANIFEST.txt
grep "/${dist}/" "${ARTIFACTS_DIR}/manifest.txt" | while read -r f ; do grep "/${dist}/" "${ARTIFACTS_DIR}/manifest.txt" | while read -r f ; do
package_name="$(basename "$f")" package_name="$(basename "$f")"

View File

@@ -27,7 +27,7 @@ testing::crio::hook_created() {
} }
testing::crio::hook_cleanup() { testing::crio::hook_cleanup() {
testing::docker_run::toolkit::shell 'crio cleanup' testing::docker_run::toolkit::shell 'crio cleanup --nvidia-runtime-dir /run/nvidia/toolkit/'
test -z "$(ls -A "${shared_dir}${CRIO_HOOKS_DIR}")" test -z "$(ls -A "${shared_dir}${CRIO_HOOKS_DIR}")"
} }

View File

@@ -0,0 +1,169 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package container
import (
"fmt"
"os"
"os/exec"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
)
const (
restartModeNone = "none"
restartModeSignal = "signal"
restartModeSystemd = "systemd"
)
// Options defines the shared options for the CLIs to configure containers runtimes.
type Options struct {
Config string
Socket string
RuntimeName string
RuntimeDir string
SetAsDefault bool
RestartMode string
HostRootMount string
}
// ParseArgs parses the command line arguments to the CLI
func ParseArgs(c *cli.Context, o *Options) error {
if o.RuntimeDir != "" {
logrus.Debug("Runtime directory already set; ignoring arguments")
return nil
}
args := c.Args()
logrus.Infof("Parsing arguments: %v", args.Slice())
if c.NArg() != 1 {
return fmt.Errorf("incorrect number of arguments")
}
o.RuntimeDir = args.Get(0)
logrus.Infof("Successfully parsed arguments")
return nil
}
// Configure applies the options to the specified config
func (o Options) Configure(cfg engine.Interface) error {
err := o.UpdateConfig(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
return o.flush(cfg)
}
// Unconfigure removes the options from the specified config
func (o Options) Unconfigure(cfg engine.Interface) error {
err := o.RevertConfig(cfg)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
return o.flush(cfg)
}
// flush flushes the specified config to disk
func (o Options) flush(cfg engine.Interface) error {
logrus.Infof("Flushing config to %v", o.Config)
n, err := cfg.Save(o.Config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
logrus.Infof("Config file is empty, removed")
}
return nil
}
// UpdateConfig updates the specified config to include the nvidia runtimes
func (o Options) UpdateConfig(cfg engine.Interface) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.RuntimeName),
operator.WithSetAsDefault(o.SetAsDefault),
operator.WithRoot(o.RuntimeDir),
)
for name, runtime := range runtimes {
err := cfg.AddRuntime(name, runtime.Path, runtime.SetAsDefault)
if err != nil {
return fmt.Errorf("failed to update runtime %q: %v", name, err)
}
}
return nil
}
// RevertConfig reverts the specified config to remove the nvidia runtimes
func (o Options) RevertConfig(cfg engine.Interface) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.RuntimeName),
operator.WithSetAsDefault(o.SetAsDefault),
operator.WithRoot(o.RuntimeDir),
)
for name := range runtimes {
err := cfg.RemoveRuntime(name)
if err != nil {
return fmt.Errorf("failed to remove runtime %q: %v", name, err)
}
}
return nil
}
// Restart restarts the specified service
func (o Options) Restart(service string, withSignal func(string) error) error {
switch o.RestartMode {
case restartModeNone:
logrus.Warnf("Skipping restart of %v due to --restart-mode=%v", service, o.RestartMode)
return nil
case restartModeSignal:
return withSignal(o.Socket)
case restartModeSystemd:
return o.SystemdRestart(service)
}
return fmt.Errorf("invalid restart mode specified: %v", o.RestartMode)
}
// SystemdRestart restarts the specified service using systemd
func (o Options) SystemdRestart(service string) error {
var args []string
var msg string
if o.HostRootMount != "" {
msg = " on host"
args = append(args, "chroot", o.HostRootMount)
}
args = append(args, "systemctl", "restart", service)
logrus.Infof("Restarting %v%v using systemd: %v", service, msg, args)
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("error restarting %v using systemd: %v", service, err)
}
return nil
}

View File

@@ -21,6 +21,7 @@ import (
"testing" "testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/containerd" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/containerd"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
"github.com/pelletier/go-toml" "github.com/pelletier/go-toml"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -31,7 +32,7 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
testCases := []struct { testCases := []struct {
legacyConfig bool legacyConfig bool
setAsDefault bool setAsDefault bool
runtimeClass string runtimeName string
expectedDefaultRuntimeName interface{} expectedDefaultRuntimeName interface{}
expectedDefaultRuntimeBinary interface{} expectedDefaultRuntimeBinary interface{}
}{ }{
@@ -51,14 +52,14 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
{ {
legacyConfig: true, legacyConfig: true,
setAsDefault: true, setAsDefault: true,
runtimeClass: "NAME", runtimeName: "NAME",
expectedDefaultRuntimeName: nil, expectedDefaultRuntimeName: nil,
expectedDefaultRuntimeBinary: "/test/runtime/dir/nvidia-container-runtime", expectedDefaultRuntimeBinary: "/test/runtime/dir/nvidia-container-runtime",
}, },
{ {
legacyConfig: true, legacyConfig: true,
setAsDefault: true, setAsDefault: true,
runtimeClass: "nvidia-experimental", runtimeName: "nvidia-experimental",
expectedDefaultRuntimeName: nil, expectedDefaultRuntimeName: nil,
expectedDefaultRuntimeBinary: "/test/runtime/dir/nvidia-container-runtime.experimental", expectedDefaultRuntimeBinary: "/test/runtime/dir/nvidia-container-runtime.experimental",
}, },
@@ -77,14 +78,14 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
{ {
legacyConfig: false, legacyConfig: false,
setAsDefault: true, setAsDefault: true,
runtimeClass: "NAME", runtimeName: "NAME",
expectedDefaultRuntimeName: "NAME", expectedDefaultRuntimeName: "NAME",
expectedDefaultRuntimeBinary: nil, expectedDefaultRuntimeBinary: nil,
}, },
{ {
legacyConfig: false, legacyConfig: false,
setAsDefault: true, setAsDefault: true,
runtimeClass: "nvidia-experimental", runtimeName: "nvidia-experimental",
expectedDefaultRuntimeName: "nvidia-experimental", expectedDefaultRuntimeName: "nvidia-experimental",
expectedDefaultRuntimeBinary: nil, expectedDefaultRuntimeBinary: nil,
}, },
@@ -93,11 +94,13 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{ o := &options{
useLegacyConfig: tc.legacyConfig, Options: container.Options{
setAsDefault: tc.setAsDefault, RuntimeName: tc.runtimeName,
runtimeClass: tc.runtimeClass, RuntimeDir: runtimeDir,
SetAsDefault: tc.setAsDefault,
},
runtimeType: runtimeType, runtimeType: runtimeType,
runtimeDir: runtimeDir, useLegacyConfig: tc.legacyConfig,
} }
config, err := toml.TreeFromMap(map[string]interface{}{}) config, err := toml.TreeFromMap(map[string]interface{}{})
@@ -109,7 +112,7 @@ func TestUpdateV1ConfigDefaultRuntime(t *testing.T) {
RuntimeType: runtimeType, RuntimeType: runtimeType,
} }
err = UpdateConfig(v1, o) err = o.UpdateConfig(v1)
require.NoError(t, err, "%d: %v", i, tc) require.NoError(t, err, "%d: %v", i, tc)
defaultRuntimeName := v1.GetPath([]string{"plugins", "cri", "containerd", "default_runtime_name"}) defaultRuntimeName := v1.GetPath([]string{"plugins", "cri", "containerd", "default_runtime_name"})
@@ -138,11 +141,11 @@ func TestUpdateV1Config(t *testing.T) {
const runtimeDir = "/test/runtime/dir" const runtimeDir = "/test/runtime/dir"
testCases := []struct { testCases := []struct {
runtimeClass string runtimeName string
expectedConfig map[string]interface{} expectedConfig map[string]interface{}
}{ }{
{ {
runtimeClass: "nvidia", runtimeName: "nvidia",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(1), "version": int64(1),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -200,7 +203,7 @@ func TestUpdateV1Config(t *testing.T) {
}, },
}, },
{ {
runtimeClass: "NAME", runtimeName: "NAME",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(1), "version": int64(1),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -258,7 +261,7 @@ func TestUpdateV1Config(t *testing.T) {
}, },
}, },
{ {
runtimeClass: "nvidia-experimental", runtimeName: "nvidia-experimental",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(1), "version": int64(1),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -320,9 +323,10 @@ func TestUpdateV1Config(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{ o := &options{
runtimeClass: tc.runtimeClass, Options: container.Options{
runtimeType: runtimeType, RuntimeName: tc.runtimeName,
runtimeDir: runtimeDir, RuntimeDir: runtimeDir,
},
} }
config, err := toml.TreeFromMap(map[string]interface{}{}) config, err := toml.TreeFromMap(map[string]interface{}{})
@@ -335,7 +339,7 @@ func TestUpdateV1Config(t *testing.T) {
ContainerAnnotations: []string{"cdi.k8s.io/*"}, ContainerAnnotations: []string{"cdi.k8s.io/*"},
} }
err = UpdateConfig(v1, o) err = o.UpdateConfig(v1)
require.NoError(t, err) require.NoError(t, err)
expected, err := toml.TreeFromMap(tc.expectedConfig) expected, err := toml.TreeFromMap(tc.expectedConfig)
@@ -350,11 +354,11 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
const runtimeDir = "/test/runtime/dir" const runtimeDir = "/test/runtime/dir"
testCases := []struct { testCases := []struct {
runtimeClass string runtimeName string
expectedConfig map[string]interface{} expectedConfig map[string]interface{}
}{ }{
{ {
runtimeClass: "nvidia", runtimeName: "nvidia",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(1), "version": int64(1),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -426,7 +430,7 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
}, },
}, },
{ {
runtimeClass: "NAME", runtimeName: "NAME",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(1), "version": int64(1),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -498,7 +502,7 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
}, },
}, },
{ {
runtimeClass: "nvidia-experimental", runtimeName: "nvidia-experimental",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(1), "version": int64(1),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -574,9 +578,10 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{ o := &options{
runtimeClass: tc.runtimeClass, Options: container.Options{
runtimeType: runtimeType, RuntimeName: tc.runtimeName,
runtimeDir: runtimeDir, RuntimeDir: runtimeDir,
},
} }
config, err := toml.TreeFromMap(runcConfigMapV1("/runc-binary")) config, err := toml.TreeFromMap(runcConfigMapV1("/runc-binary"))
@@ -589,7 +594,7 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
ContainerAnnotations: []string{"cdi.k8s.io/*"}, ContainerAnnotations: []string{"cdi.k8s.io/*"},
} }
err = UpdateConfig(v1, o) err = o.UpdateConfig(v1)
require.NoError(t, err) require.NoError(t, err)
expected, err := toml.TreeFromMap(tc.expectedConfig) expected, err := toml.TreeFromMap(tc.expectedConfig)
@@ -653,7 +658,9 @@ func TestRevertV1Config(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{ o := &options{
runtimeClass: "nvidia", Options: container.Options{
RuntimeName: "nvidia",
},
} }
config, err := toml.TreeFromMap(tc.config) config, err := toml.TreeFromMap(tc.config)
@@ -668,7 +675,7 @@ func TestRevertV1Config(t *testing.T) {
RuntimeType: runtimeType, RuntimeType: runtimeType,
} }
err = RevertConfig(v1, o) err = o.RevertConfig(v1)
require.NoError(t, err, "%d: %v", i, tc) require.NoError(t, err, "%d: %v", i, tc)
configContents, _ := toml.Marshal(config) configContents, _ := toml.Marshal(config)

View File

@@ -21,6 +21,7 @@ import (
"testing" "testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/containerd" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/containerd"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
"github.com/pelletier/go-toml" "github.com/pelletier/go-toml"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -34,38 +35,38 @@ func TestUpdateV2ConfigDefaultRuntime(t *testing.T) {
testCases := []struct { testCases := []struct {
setAsDefault bool setAsDefault bool
runtimeClass string runtimeName string
expectedDefaultRuntimeName interface{} expectedDefaultRuntimeName interface{}
}{ }{
{}, {},
{ {
setAsDefault: false, setAsDefault: false,
runtimeClass: "nvidia", runtimeName: "nvidia",
expectedDefaultRuntimeName: nil, expectedDefaultRuntimeName: nil,
}, },
{ {
setAsDefault: false, setAsDefault: false,
runtimeClass: "NAME", runtimeName: "NAME",
expectedDefaultRuntimeName: nil, expectedDefaultRuntimeName: nil,
}, },
{ {
setAsDefault: false, setAsDefault: false,
runtimeClass: "nvidia-experimental", runtimeName: "nvidia-experimental",
expectedDefaultRuntimeName: nil, expectedDefaultRuntimeName: nil,
}, },
{ {
setAsDefault: true, setAsDefault: true,
runtimeClass: "nvidia", runtimeName: "nvidia",
expectedDefaultRuntimeName: "nvidia", expectedDefaultRuntimeName: "nvidia",
}, },
{ {
setAsDefault: true, setAsDefault: true,
runtimeClass: "NAME", runtimeName: "NAME",
expectedDefaultRuntimeName: "NAME", expectedDefaultRuntimeName: "NAME",
}, },
{ {
setAsDefault: true, setAsDefault: true,
runtimeClass: "nvidia-experimental", runtimeName: "nvidia-experimental",
expectedDefaultRuntimeName: "nvidia-experimental", expectedDefaultRuntimeName: "nvidia-experimental",
}, },
} }
@@ -73,9 +74,11 @@ func TestUpdateV2ConfigDefaultRuntime(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{ o := &options{
setAsDefault: tc.setAsDefault, Options: container.Options{
runtimeClass: tc.runtimeClass, RuntimeName: tc.runtimeName,
runtimeDir: runtimeDir, RuntimeDir: runtimeDir,
SetAsDefault: tc.setAsDefault,
},
} }
config, err := toml.TreeFromMap(map[string]interface{}{}) config, err := toml.TreeFromMap(map[string]interface{}{})
@@ -86,7 +89,7 @@ func TestUpdateV2ConfigDefaultRuntime(t *testing.T) {
RuntimeType: runtimeType, RuntimeType: runtimeType,
} }
err = UpdateConfig(v2, o) err = o.UpdateConfig(v2)
require.NoError(t, err) require.NoError(t, err)
defaultRuntimeName := config.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "containerd", "default_runtime_name"}) defaultRuntimeName := config.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "containerd", "default_runtime_name"})
@@ -100,11 +103,11 @@ func TestUpdateV2Config(t *testing.T) {
const expectedVersion = int64(2) const expectedVersion = int64(2)
testCases := []struct { testCases := []struct {
runtimeClass string runtimeName string
expectedConfig map[string]interface{} expectedConfig map[string]interface{}
}{ }{
{ {
runtimeClass: "nvidia", runtimeName: "nvidia",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(2), "version": int64(2),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -158,7 +161,7 @@ func TestUpdateV2Config(t *testing.T) {
}, },
}, },
{ {
runtimeClass: "NAME", runtimeName: "NAME",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(2), "version": int64(2),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -212,7 +215,7 @@ func TestUpdateV2Config(t *testing.T) {
}, },
}, },
{ {
runtimeClass: "nvidia-experimental", runtimeName: "nvidia-experimental",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(2), "version": int64(2),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -270,9 +273,11 @@ func TestUpdateV2Config(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{ o := &options{
runtimeClass: tc.runtimeClass, Options: container.Options{
runtimeType: runtimeType, RuntimeName: tc.runtimeName,
runtimeDir: runtimeDir, RuntimeDir: runtimeDir,
},
runtimeType: runtimeType,
} }
config, err := toml.TreeFromMap(map[string]interface{}{}) config, err := toml.TreeFromMap(map[string]interface{}{})
@@ -284,7 +289,7 @@ func TestUpdateV2Config(t *testing.T) {
ContainerAnnotations: []string{"cdi.k8s.io/*"}, ContainerAnnotations: []string{"cdi.k8s.io/*"},
} }
err = UpdateConfig(v2, o) err = o.UpdateConfig(v2)
require.NoError(t, err) require.NoError(t, err)
expected, err := toml.TreeFromMap(tc.expectedConfig) expected, err := toml.TreeFromMap(tc.expectedConfig)
@@ -300,11 +305,11 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
const runtimeDir = "/test/runtime/dir" const runtimeDir = "/test/runtime/dir"
testCases := []struct { testCases := []struct {
runtimeClass string runtimeName string
expectedConfig map[string]interface{} expectedConfig map[string]interface{}
}{ }{
{ {
runtimeClass: "nvidia", runtimeName: "nvidia",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(2), "version": int64(2),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -372,7 +377,7 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
}, },
}, },
{ {
runtimeClass: "NAME", runtimeName: "NAME",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(2), "version": int64(2),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -440,7 +445,7 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
}, },
}, },
{ {
runtimeClass: "nvidia-experimental", runtimeName: "nvidia-experimental",
expectedConfig: map[string]interface{}{ expectedConfig: map[string]interface{}{
"version": int64(2), "version": int64(2),
"plugins": map[string]interface{}{ "plugins": map[string]interface{}{
@@ -512,9 +517,10 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{ o := &options{
runtimeClass: tc.runtimeClass, Options: container.Options{
runtimeType: runtimeType, RuntimeName: tc.runtimeName,
runtimeDir: runtimeDir, RuntimeDir: runtimeDir,
},
} }
config, err := toml.TreeFromMap(runcConfigMapV2("/runc-binary")) config, err := toml.TreeFromMap(runcConfigMapV2("/runc-binary"))
@@ -526,7 +532,7 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
ContainerAnnotations: []string{"cdi.k8s.io/*"}, ContainerAnnotations: []string{"cdi.k8s.io/*"},
} }
err = UpdateConfig(v2, o) err = o.UpdateConfig(v2)
require.NoError(t, err) require.NoError(t, err)
expected, err := toml.TreeFromMap(tc.expectedConfig) expected, err := toml.TreeFromMap(tc.expectedConfig)
@@ -585,7 +591,9 @@ func TestRevertV2Config(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
o := &options{ o := &options{
runtimeClass: "nvidia", Options: container.Options{
RuntimeName: "nvidia",
},
} }
config, err := toml.TreeFromMap(tc.config) config, err := toml.TreeFromMap(tc.config)
@@ -599,7 +607,7 @@ func TestRevertV2Config(t *testing.T) {
RuntimeType: runtimeType, RuntimeType: runtimeType,
} }
err = RevertConfig(v2, o) err = o.RevertConfig(v2)
require.NoError(t, err) require.NoError(t, err)
configContents, _ := toml.Marshal(config) configContents, _ := toml.Marshal(config)

View File

@@ -20,33 +20,23 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"os/exec"
"syscall" "syscall"
"time" "time"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/containerd" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/containerd"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator" "github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cli "github.com/urfave/cli/v2" cli "github.com/urfave/cli/v2"
) )
const ( const (
restartModeSignal = "signal"
restartModeSystemd = "systemd"
restartModeNone = "none"
nvidiaRuntimeName = "nvidia"
nvidiaRuntimeBinary = "nvidia-container-runtime"
nvidiaExperimentalRuntimeName = "nvidia-experimental"
nvidiaExperimentalRuntimeBinary = "nvidia-container-runtime.experimental"
defaultConfig = "/etc/containerd/config.toml" defaultConfig = "/etc/containerd/config.toml"
defaultSocket = "/run/containerd/containerd.sock" defaultSocket = "/run/containerd/containerd.sock"
defaultRuntimeClass = "nvidia" defaultRuntimeClass = "nvidia"
defaultRuntmeType = "io.containerd.runc.v2" defaultRuntmeType = "io.containerd.runc.v2"
defaultSetAsDefault = true defaultSetAsDefault = true
defaultRestartMode = restartModeSignal defaultRestartMode = "signal"
defaultHostRootMount = "/host" defaultHostRootMount = "/host"
reloadBackoff = 5 * time.Second reloadBackoff = 5 * time.Second
@@ -55,23 +45,13 @@ const (
socketMessageToGetPID = "" socketMessageToGetPID = ""
) )
// nvidiaRuntimeBinaries defines a map of runtime names to binary names
var nvidiaRuntimeBinaries = map[string]string{
nvidiaRuntimeName: nvidiaRuntimeBinary,
nvidiaExperimentalRuntimeName: nvidiaExperimentalRuntimeBinary,
}
// options stores the configuration from the command line or environment variables // options stores the configuration from the command line or environment variables
type options struct { type options struct {
config string container.Options
socket string
runtimeClass string // containerd-specific options
runtimeType string
setAsDefault bool
restartMode string
hostRootMount string
runtimeDir string
useLegacyConfig bool useLegacyConfig bool
runtimeType string
ContainerRuntimeModesCDIAnnotationPrefixes cli.StringSlice ContainerRuntimeModesCDIAnnotationPrefixes cli.StringSlice
} }
@@ -83,7 +63,7 @@ func main() {
c := cli.NewApp() c := cli.NewApp()
c.Name = "containerd" c.Name = "containerd"
c.Usage = "Update a containerd config with the nvidia-container-runtime" c.Usage = "Update a containerd config with the nvidia-container-runtime"
c.Version = "0.1.0" c.Version = info.GetVersionString()
// Create the 'setup' subcommand // Create the 'setup' subcommand
setup := cli.Command{} setup := cli.Command{}
@@ -93,6 +73,9 @@ func main() {
setup.Action = func(c *cli.Context) error { setup.Action = func(c *cli.Context) error {
return Setup(c, &options) return Setup(c, &options)
} }
setup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Create the 'cleanup' subcommand // Create the 'cleanup' subcommand
cleanup := cli.Command{} cleanup := cli.Command{}
@@ -102,6 +85,9 @@ func main() {
cleanup.Action = func(c *cli.Context) error { cleanup.Action = func(c *cli.Context) error {
return Cleanup(c, &options) return Cleanup(c, &options)
} }
cleanup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Register the subcommands with the top-level CLI // Register the subcommands with the top-level CLI
c.Commands = []*cli.Command{ c.Commands = []*cli.Command{
@@ -116,57 +102,53 @@ func main() {
commonFlags := []cli.Flag{ commonFlags := []cli.Flag{
&cli.StringFlag{ &cli.StringFlag{
Name: "config", Name: "config",
Aliases: []string{"c"},
Usage: "Path to the containerd config file", Usage: "Path to the containerd config file",
Value: defaultConfig, Value: defaultConfig,
Destination: &options.config, Destination: &options.Config,
EnvVars: []string{"CONTAINERD_CONFIG"}, EnvVars: []string{"RUNTIME_CONFIG", "CONTAINERD_CONFIG"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "socket", Name: "socket",
Aliases: []string{"s"},
Usage: "Path to the containerd socket file", Usage: "Path to the containerd socket file",
Value: defaultSocket, Value: defaultSocket,
Destination: &options.socket, Destination: &options.Socket,
EnvVars: []string{"CONTAINERD_SOCKET"}, EnvVars: []string{"RUNTIME_SOCKET", "CONTAINERD_SOCKET"},
},
&cli.StringFlag{
Name: "runtime-class",
Aliases: []string{"r"},
Usage: "The name of the runtime class to set for the nvidia-container-runtime",
Value: defaultRuntimeClass,
Destination: &options.runtimeClass,
EnvVars: []string{"CONTAINERD_RUNTIME_CLASS"},
},
&cli.StringFlag{
Name: "runtime-type",
Usage: "The runtime_type to use for the configured runtime classes",
Value: defaultRuntmeType,
Destination: &options.runtimeType,
EnvVars: []string{"CONTAINERD_RUNTIME_TYPE"},
},
// The flags below are only used by the 'setup' command.
&cli.BoolFlag{
Name: "set-as-default",
Aliases: []string{"d"},
Usage: "Set nvidia-container-runtime as the default runtime",
Value: defaultSetAsDefault,
Destination: &options.setAsDefault,
EnvVars: []string{"CONTAINERD_SET_AS_DEFAULT"},
Hidden: true,
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "restart-mode", Name: "restart-mode",
Usage: "Specify how containerd should be restarted; If 'none' is selected, it will not be restarted [signal | systemd | none]", Usage: "Specify how containerd should be restarted; If 'none' is selected, it will not be restarted [signal | systemd | none]",
Value: defaultRestartMode, Value: defaultRestartMode,
Destination: &options.restartMode, Destination: &options.RestartMode,
EnvVars: []string{"CONTAINERD_RESTART_MODE"}, EnvVars: []string{"RUNTIME_RESTART_MODE", "CONTAINERD_RESTART_MODE"},
},
&cli.StringFlag{
Name: "runtime-name",
Aliases: []string{"nvidia-runtime-name", "runtime-class"},
Usage: "The name of the runtime class to set for the nvidia-container-runtime",
Value: defaultRuntimeClass,
Destination: &options.RuntimeName,
EnvVars: []string{"NVIDIA_RUNTIME_NAME", "CONTAINERD_RUNTIME_CLASS"},
},
&cli.StringFlag{
Name: "nvidia-runtime-dir",
Aliases: []string{"runtime-dir"},
Usage: "The path where the nvidia-container-runtime binaries are located. If this is not specified, the first argument will be used instead",
Destination: &options.RuntimeDir,
EnvVars: []string{"NVIDIA_RUNTIME_DIR"},
},
&cli.BoolFlag{
Name: "set-as-default",
Usage: "Set nvidia-container-runtime as the default runtime",
Value: defaultSetAsDefault,
Destination: &options.SetAsDefault,
EnvVars: []string{"NVIDIA_RUNTIME_SET_AS_DEFAULT", "CONTAINERD_SET_AS_DEFAULT"},
Hidden: true,
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "host-root", Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting containerd using systemd", Usage: "Specify the path to the host root to be used when restarting containerd using systemd",
Value: defaultHostRootMount, Value: defaultHostRootMount,
Destination: &options.hostRootMount, Destination: &options.HostRootMount,
EnvVars: []string{"HOST_ROOT_MOUNT"}, EnvVars: []string{"HOST_ROOT_MOUNT"},
}, },
&cli.BoolFlag{ &cli.BoolFlag{
@@ -175,6 +157,13 @@ func main() {
Destination: &options.useLegacyConfig, Destination: &options.useLegacyConfig,
EnvVars: []string{"CONTAINERD_USE_LEGACY_CONFIG"}, EnvVars: []string{"CONTAINERD_USE_LEGACY_CONFIG"},
}, },
&cli.StringFlag{
Name: "runtime-type",
Usage: "The runtime_type to use for the configured runtime classes",
Value: defaultRuntmeType,
Destination: &options.runtimeType,
EnvVars: []string{"CONTAINERD_RUNTIME_TYPE"},
},
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "nvidia-container-runtime-modes.cdi.annotation-prefixes", Name: "nvidia-container-runtime-modes.cdi.annotation-prefixes",
Destination: &options.ContainerRuntimeModesCDIAnnotationPrefixes, Destination: &options.ContainerRuntimeModesCDIAnnotationPrefixes,
@@ -196,14 +185,8 @@ func main() {
func Setup(c *cli.Context, o *options) error { func Setup(c *cli.Context, o *options) error {
log.Infof("Starting 'setup' for %v", c.App.Name) log.Infof("Starting 'setup' for %v", c.App.Name)
runtimeDir, err := ParseArgs(c)
if err != nil {
return fmt.Errorf("unable to parse args: %v", err)
}
o.runtimeDir = runtimeDir
cfg, err := containerd.New( cfg, err := containerd.New(
containerd.WithPath(o.config), containerd.WithPath(o.Config),
containerd.WithRuntimeType(o.runtimeType), containerd.WithRuntimeType(o.runtimeType),
containerd.WithUseLegacyConfig(o.useLegacyConfig), containerd.WithUseLegacyConfig(o.useLegacyConfig),
containerd.WithContainerAnnotations(o.containerAnnotationsFromCDIPrefixes()...), containerd.WithContainerAnnotations(o.containerAnnotationsFromCDIPrefixes()...),
@@ -212,18 +195,9 @@ func Setup(c *cli.Context, o *options) error {
return fmt.Errorf("unable to load config: %v", err) return fmt.Errorf("unable to load config: %v", err)
} }
err = UpdateConfig(cfg, o) err = o.Configure(cfg)
if err != nil { if err != nil {
return fmt.Errorf("unable to update config: %v", err) return fmt.Errorf("unable to configure containerd: %v", err)
}
log.Infof("Flushing containerd config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
} }
err = RestartContainerd(o) err = RestartContainerd(o)
@@ -240,13 +214,8 @@ func Setup(c *cli.Context, o *options) error {
func Cleanup(c *cli.Context, o *options) error { func Cleanup(c *cli.Context, o *options) error {
log.Infof("Starting 'cleanup' for %v", c.App.Name) log.Infof("Starting 'cleanup' for %v", c.App.Name)
_, err := ParseArgs(c)
if err != nil {
return fmt.Errorf("unable to parse args: %v", err)
}
cfg, err := containerd.New( cfg, err := containerd.New(
containerd.WithPath(o.config), containerd.WithPath(o.Config),
containerd.WithRuntimeType(o.runtimeType), containerd.WithRuntimeType(o.runtimeType),
containerd.WithUseLegacyConfig(o.useLegacyConfig), containerd.WithUseLegacyConfig(o.useLegacyConfig),
containerd.WithContainerAnnotations(o.containerAnnotationsFromCDIPrefixes()...), containerd.WithContainerAnnotations(o.containerAnnotationsFromCDIPrefixes()...),
@@ -255,18 +224,9 @@ func Cleanup(c *cli.Context, o *options) error {
return fmt.Errorf("unable to load config: %v", err) return fmt.Errorf("unable to load config: %v", err)
} }
err = RevertConfig(cfg, o) err = o.Unconfigure(cfg)
if err != nil { if err != nil {
return fmt.Errorf("unable to update config: %v", err) return fmt.Errorf("unable to unconfigure containerd: %v", err)
}
log.Infof("Flushing containerd config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
} }
err = RestartContainerd(o) err = RestartContainerd(o)
@@ -279,80 +239,18 @@ func Cleanup(c *cli.Context, o *options) error {
return nil return nil
} }
// ParseArgs parses the command line arguments to the CLI
func ParseArgs(c *cli.Context) (string, error) {
args := c.Args()
log.Infof("Parsing arguments: %v", args.Slice())
if args.Len() != 1 {
return "", fmt.Errorf("incorrect number of arguments")
}
runtimeDir := args.Get(0)
log.Infof("Successfully parsed arguments")
return runtimeDir, nil
}
// UpdateConfig updates the containerd config to include the nvidia-container-runtime
func UpdateConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeClass),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for class, runtime := range runtimes {
err := cfg.AddRuntime(class, runtime.Path, runtime.SetAsDefault)
if err != nil {
return fmt.Errorf("unable to update config for runtime class '%v': %v", class, err)
}
}
return nil
}
// RevertConfig reverts the containerd config to remove the nvidia-container-runtime
func RevertConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeClass),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for class := range runtimes {
err := cfg.RemoveRuntime(class)
if err != nil {
return fmt.Errorf("unable to revert config for runtime class '%v': %v", class, err)
}
}
return nil
}
// RestartContainerd restarts containerd depending on the value of restartModeFlag // RestartContainerd restarts containerd depending on the value of restartModeFlag
func RestartContainerd(o *options) error { func RestartContainerd(o *options) error {
switch o.restartMode { return o.Restart("containerd", SignalContainerd)
case restartModeNone:
log.Warnf("Skipping sending signal to containerd due to --restart-mode=%v", o.restartMode)
return nil
case restartModeSignal:
err := SignalContainerd(o)
if err != nil {
return fmt.Errorf("unable to signal containerd: %v", err)
}
case restartModeSystemd:
return RestartContainerdSystemd(o.hostRootMount)
default:
return fmt.Errorf("Invalid restart mode specified: %v", o.restartMode)
}
return nil
} }
// SignalContainerd sends a SIGHUP signal to the containerd daemon // SignalContainerd sends a SIGHUP signal to the containerd daemon
func SignalContainerd(o *options) error { func SignalContainerd(socket string) error {
log.Infof("Sending SIGHUP signal to containerd") log.Infof("Sending SIGHUP signal to containerd")
// Wrap the logic to perform the SIGHUP in a function so we can retry it on failure // Wrap the logic to perform the SIGHUP in a function so we can retry it on failure
retriable := func() error { retriable := func() error {
conn, err := net.Dial("unix", o.socket) conn, err := net.Dial("unix", socket)
if err != nil { if err != nil {
return fmt.Errorf("unable to dial: %v", err) return fmt.Errorf("unable to dial: %v", err)
} }
@@ -426,24 +324,6 @@ func SignalContainerd(o *options) error {
return nil return nil
} }
// RestartContainerdSystemd restarts containerd using systemctl
func RestartContainerdSystemd(hostRootMount string) error {
log.Infof("Restarting containerd using systemd and host root mounted at %v", hostRootMount)
command := "chroot"
args := []string{hostRootMount, "systemctl", "restart", "containerd"}
cmd := exec.Command(command, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("error restarting containerd using systemd: %v", err)
}
return nil
}
// containerAnnotationsFromCDIPrefixes returns the container annotations to set for the given CDI prefixes. // containerAnnotationsFromCDIPrefixes returns the container annotations to set for the given CDI prefixes.
func (o *options) containerAnnotationsFromCDIPrefixes() []string { func (o *options) containerAnnotationsFromCDIPrefixes() []string {
var annotations []string var annotations []string

View File

@@ -20,21 +20,17 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/crio" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/crio"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator" "github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cli "github.com/urfave/cli/v2" cli "github.com/urfave/cli/v2"
) )
const ( const (
restartModeSystemd = "systemd"
restartModeNone = "none"
defaultConfigMode = "hook" defaultConfigMode = "hook"
// Hook-based settings // Hook-based settings
@@ -43,25 +39,22 @@ const (
// Config-based settings // Config-based settings
defaultConfig = "/etc/crio/crio.conf" defaultConfig = "/etc/crio/crio.conf"
defaultSocket = "/var/run/crio/crio.sock"
defaultRuntimeClass = "nvidia" defaultRuntimeClass = "nvidia"
defaultSetAsDefault = true defaultSetAsDefault = true
defaultRestartMode = restartModeSystemd defaultRestartMode = "systemd"
defaultHostRootMount = "/host" defaultHostRootMount = "/host"
) )
// options stores the configuration from the command linek or environment variables // options stores the configuration from the command linek or environment variables
type options struct { type options struct {
container.Options
configMode string configMode string
// hook-specific options
hooksDir string hooksDir string
hookFilename string hookFilename string
runtimeDir string
config string
runtimeClass string
setAsDefault bool
restartMode string
hostRootMount string
} }
func main() { func main() {
@@ -71,8 +64,7 @@ func main() {
c := cli.NewApp() c := cli.NewApp()
c.Name = "crio" c.Name = "crio"
c.Usage = "Update cri-o hooks to include the NVIDIA runtime hook" c.Usage = "Update cri-o hooks to include the NVIDIA runtime hook"
c.ArgsUsage = "<toolkit_dirname>" c.Version = info.GetVersionString()
c.Version = "0.1.0"
// Create the 'setup' subcommand // Create the 'setup' subcommand
setup := cli.Command{} setup := cli.Command{}
@@ -83,7 +75,7 @@ func main() {
return Setup(c, &options) return Setup(c, &options)
} }
setup.Before = func(c *cli.Context) error { setup.Before = func(c *cli.Context) error {
return ParseArgs(c, &options) return container.ParseArgs(c, &options.Options)
} }
// Create the 'cleanup' subcommand // Create the 'cleanup' subcommand
@@ -93,6 +85,10 @@ func main() {
cleanup.Action = func(c *cli.Context) error { cleanup.Action = func(c *cli.Context) error {
return Cleanup(c, &options) return Cleanup(c, &options)
} }
cleanup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Register the subcommands with the top-level CLI // Register the subcommands with the top-level CLI
c.Commands = []*cli.Command{ c.Commands = []*cli.Command{
&setup, &setup,
@@ -104,9 +100,61 @@ func main() {
// only require the user to specify one set of flags for both 'startup' // only require the user to specify one set of flags for both 'startup'
// and 'cleanup' to simplify things. // and 'cleanup' to simplify things.
commonFlags := []cli.Flag{ commonFlags := []cli.Flag{
&cli.StringFlag{
Name: "config",
Usage: "Path to the cri-o config file",
Value: defaultConfig,
Destination: &options.Config,
EnvVars: []string{"RUNTIME_CONFIG", "CRIO_CONFIG"},
},
&cli.StringFlag{
Name: "socket",
Usage: "Path to the crio socket file",
Value: "",
Destination: &options.Socket,
EnvVars: []string{"RUNTIME_SOCKET", "CRIO_SOCKET"},
// Note: We hide this option since restarting cri-o via a socket is not supported.
Hidden: true,
},
&cli.StringFlag{
Name: "restart-mode",
Usage: "Specify how cri-o should be restarted; If 'none' is selected, it will not be restarted [systemd | none]",
Value: defaultRestartMode,
Destination: &options.RestartMode,
EnvVars: []string{"RUNTIME_RESTART_MODE", "CRIO_RESTART_MODE"},
},
&cli.StringFlag{
Name: "runtime-name",
Aliases: []string{"nvidia-runtime-name", "runtime-class"},
Usage: "The name of the runtime class to set for the nvidia-container-runtime",
Value: defaultRuntimeClass,
Destination: &options.RuntimeName,
EnvVars: []string{"NVIDIA_RUNTIME_NAME", "CRIO_RUNTIME_CLASS"},
},
&cli.StringFlag{
Name: "nvidia-runtime-dir",
Aliases: []string{"runtime-dir"},
Usage: "The path where the nvidia-container-runtime binaries are located. If this is not specified, the first argument will be used instead",
Destination: &options.RuntimeDir,
EnvVars: []string{"NVIDIA_RUNTIME_DIR"},
},
&cli.BoolFlag{
Name: "set-as-default",
Usage: "Set nvidia-container-runtime as the default runtime",
Value: defaultSetAsDefault,
Destination: &options.SetAsDefault,
EnvVars: []string{"NVIDIA_RUNTIME_SET_AS_DEFAULT", "CRIO_SET_AS_DEFAULT"},
Hidden: true,
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting crio using systemd",
Value: defaultHostRootMount,
Destination: &options.HostRootMount,
EnvVars: []string{"HOST_ROOT_MOUNT"},
},
&cli.StringFlag{ &cli.StringFlag{
Name: "hooks-dir", Name: "hooks-dir",
Aliases: []string{"d"},
Usage: "path to the cri-o hooks directory", Usage: "path to the cri-o hooks directory",
Value: defaultHooksDir, Value: defaultHooksDir,
Destination: &options.hooksDir, Destination: &options.hooksDir,
@@ -115,7 +163,6 @@ func main() {
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "hook-filename", Name: "hook-filename",
Aliases: []string{"f"},
Usage: "filename of the cri-o hook that will be created / removed in the hooks directory", Usage: "filename of the cri-o hook that will be created / removed in the hooks directory",
Value: defaultHookFilename, Value: defaultHookFilename,
Destination: &options.hookFilename, Destination: &options.hookFilename,
@@ -129,43 +176,6 @@ func main() {
Destination: &options.configMode, Destination: &options.configMode,
EnvVars: []string{"CRIO_CONFIG_MODE"}, EnvVars: []string{"CRIO_CONFIG_MODE"},
}, },
&cli.StringFlag{
Name: "config",
Usage: "Path to the cri-o config file",
Value: defaultConfig,
Destination: &options.config,
EnvVars: []string{"CRIO_CONFIG"},
},
&cli.StringFlag{
Name: "runtime-class",
Usage: "The name of the runtime class to set for the nvidia-container-runtime",
Value: defaultRuntimeClass,
Destination: &options.runtimeClass,
EnvVars: []string{"CRIO_RUNTIME_CLASS"},
},
// The flags below are only used by the 'setup' command.
&cli.BoolFlag{
Name: "set-as-default",
Usage: "Set nvidia-container-runtime as the default runtime",
Value: defaultSetAsDefault,
Destination: &options.setAsDefault,
EnvVars: []string{"CRIO_SET_AS_DEFAULT"},
Hidden: true,
},
&cli.StringFlag{
Name: "restart-mode",
Usage: "Specify how cri-o should be restarted; If 'none' is selected, it will not be restarted [systemd | none]",
Value: defaultRestartMode,
Destination: &options.restartMode,
EnvVars: []string{"CRIO_RESTART_MODE"},
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting crio using systemd",
Value: defaultHostRootMount,
Destination: &options.hostRootMount,
EnvVars: []string{"HOST_ROOT_MOUNT"},
},
} }
// Update the subcommand flags with the common subcommand flags // Update the subcommand flags with the common subcommand flags
@@ -202,7 +212,7 @@ func setupHook(o *options) error {
} }
hookPath := getHookPath(o.hooksDir, o.hookFilename) hookPath := getHookPath(o.hooksDir, o.hookFilename)
err = createHook(o.runtimeDir, hookPath) err = createHook(o.RuntimeDir, hookPath)
if err != nil { if err != nil {
return fmt.Errorf("error creating hook: %v", err) return fmt.Errorf("error creating hook: %v", err)
} }
@@ -215,24 +225,15 @@ func setupConfig(o *options) error {
log.Infof("Updating config file") log.Infof("Updating config file")
cfg, err := crio.New( cfg, err := crio.New(
crio.WithPath(o.config), crio.WithPath(o.Config),
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to load config: %v", err) return fmt.Errorf("unable to load config: %v", err)
} }
err = UpdateConfig(cfg, o) err = o.Configure(cfg)
if err != nil { if err != nil {
return fmt.Errorf("unable to update config: %v", err) return fmt.Errorf("unable to configure cri-o: %v", err)
}
log.Infof("Flushing cri-o config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
} }
err = RestartCrio(o) err = RestartCrio(o)
@@ -275,24 +276,15 @@ func cleanupConfig(o *options) error {
log.Infof("Reverting config file modifications") log.Infof("Reverting config file modifications")
cfg, err := crio.New( cfg, err := crio.New(
crio.WithPath(o.config), crio.WithPath(o.Config),
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to load config: %v", err) return fmt.Errorf("unable to load config: %v", err)
} }
err = RevertConfig(cfg, o) err = o.Unconfigure(cfg)
if err != nil { if err != nil {
return fmt.Errorf("unable to update config: %v", err) return fmt.Errorf("unable to unconfigure cri-o: %v", err)
}
log.Infof("Flushing cri-o config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
} }
err = RestartCrio(o) err = RestartCrio(o)
@@ -303,20 +295,6 @@ func cleanupConfig(o *options) error {
return nil return nil
} }
// ParseArgs parses the command line arguments to the CLI
func ParseArgs(c *cli.Context, o *options) error {
args := c.Args()
log.Infof("Parsing arguments: %v", args.Slice())
if c.NArg() != 1 {
return fmt.Errorf("incorrect number of arguments")
}
o.runtimeDir = args.Get(0)
log.Infof("Successfully parsed arguments")
return nil
}
func createHook(toolkitDir string, hookPath string) error { func createHook(toolkitDir string, hookPath string) error {
hook, err := os.Create(hookPath) hook, err := os.Create(hookPath)
if err != nil { if err != nil {
@@ -357,66 +335,7 @@ func generateOciHook(toolkitDir string) podmanHook {
return hook return hook
} }
// UpdateConfig updates the cri-o config to include the NVIDIA Container Runtime
func UpdateConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeClass),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for class, runtime := range runtimes {
err := cfg.AddRuntime(class, runtime.Path, runtime.SetAsDefault)
if err != nil {
return fmt.Errorf("unable to update config for runtime class '%v': %v", class, err)
}
}
return nil
}
// RevertConfig reverts the cri-o config to remove the NVIDIA Container Runtime
func RevertConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeClass),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for class := range runtimes {
err := cfg.RemoveRuntime(class)
if err != nil {
return fmt.Errorf("unable to revert config for runtime class '%v': %v", class, err)
}
}
return nil
}
// RestartCrio restarts crio depending on the value of restartModeFlag // RestartCrio restarts crio depending on the value of restartModeFlag
func RestartCrio(o *options) error { func RestartCrio(o *options) error {
switch o.restartMode { return o.Restart("crio", func(string) error { return fmt.Errorf("supporting crio via signal is unsupported") })
case restartModeNone:
log.Warnf("Skipping restart of crio due to --restart-mode=%v", o.restartMode)
return nil
case restartModeSystemd:
return RestartCrioSystemd(o.hostRootMount)
default:
return fmt.Errorf("invalid restart mode specified: %v", o.restartMode)
}
}
// RestartCrioSystemd restarts cri-o using systemctl
func RestartCrioSystemd(hostRootMount string) error {
log.Infof("Restarting cri-o using systemd and host root mounted at %v", hostRootMount)
command := "chroot"
args := []string{hostRootMount, "systemctl", "restart", "crio"}
cmd := exec.Command(command, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("error restarting crio using systemd: %v", err)
}
return nil
} }

View File

@@ -23,50 +23,31 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/docker" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/docker"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator" "github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cli "github.com/urfave/cli/v2" cli "github.com/urfave/cli/v2"
) )
const ( const (
restartModeSignal = "signal"
restartModeNone = "none"
nvidiaRuntimeName = "nvidia"
nvidiaRuntimeBinary = "nvidia-container-runtime"
nvidiaExperimentalRuntimeName = "nvidia-experimental"
nvidiaExperimentalRuntimeBinary = "nvidia-container-runtime.experimental"
defaultConfig = "/etc/docker/daemon.json" defaultConfig = "/etc/docker/daemon.json"
defaultSocket = "/var/run/docker.sock" defaultSocket = "/var/run/docker.sock"
defaultSetAsDefault = true defaultSetAsDefault = true
// defaultRuntimeName specifies the NVIDIA runtime to be use as the default runtime if setting the default runtime is enabled // defaultRuntimeName specifies the NVIDIA runtime to be use as the default runtime if setting the default runtime is enabled
defaultRuntimeName = nvidiaRuntimeName defaultRuntimeName = "nvidia"
defaultRestartMode = restartModeSignal defaultRestartMode = "signal"
defaultHostRootMount = "/host"
reloadBackoff = 5 * time.Second reloadBackoff = 5 * time.Second
maxReloadAttempts = 6 maxReloadAttempts = 6
defaultDockerRuntime = "runc"
socketMessageToGetPID = "GET /info HTTP/1.0\r\n\r\n" socketMessageToGetPID = "GET /info HTTP/1.0\r\n\r\n"
) )
// nvidiaRuntimeBinaries defines a map of runtime names to binary names
var nvidiaRuntimeBinaries = map[string]string{
nvidiaRuntimeName: nvidiaRuntimeBinary,
nvidiaExperimentalRuntimeName: nvidiaExperimentalRuntimeBinary,
}
// options stores the configuration from the command line or environment variables // options stores the configuration from the command line or environment variables
type options struct { type options struct {
config string container.Options
socket string
runtimeName string
setAsDefault bool
runtimeDir string
restartMode string
} }
func main() { func main() {
@@ -76,7 +57,7 @@ func main() {
c := cli.NewApp() c := cli.NewApp()
c.Name = "docker" c.Name = "docker"
c.Usage = "Update docker config with the nvidia runtime" c.Usage = "Update docker config with the nvidia runtime"
c.Version = "0.1.0" c.Version = info.GetVersionString()
// Create the 'setup' subcommand // Create the 'setup' subcommand
setup := cli.Command{} setup := cli.Command{}
@@ -86,6 +67,9 @@ func main() {
setup.Action = func(c *cli.Context) error { setup.Action = func(c *cli.Context) error {
return Setup(c, &options) return Setup(c, &options)
} }
setup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Create the 'cleanup' subcommand // Create the 'cleanup' subcommand
cleanup := cli.Command{} cleanup := cli.Command{}
@@ -95,6 +79,9 @@ func main() {
cleanup.Action = func(c *cli.Context) error { cleanup.Action = func(c *cli.Context) error {
return Cleanup(c, &options) return Cleanup(c, &options)
} }
cleanup.Before = func(c *cli.Context) error {
return container.ParseArgs(c, &options.Options)
}
// Register the subcommands with the top-level CLI // Register the subcommands with the top-level CLI
c.Commands = []*cli.Command{ c.Commands = []*cli.Command{
@@ -109,44 +96,57 @@ func main() {
commonFlags := []cli.Flag{ commonFlags := []cli.Flag{
&cli.StringFlag{ &cli.StringFlag{
Name: "config", Name: "config",
Aliases: []string{"c"},
Usage: "Path to docker config file", Usage: "Path to docker config file",
Value: defaultConfig, Value: defaultConfig,
Destination: &options.config, Destination: &options.Config,
EnvVars: []string{"DOCKER_CONFIG"}, EnvVars: []string{"RUNTIME_CONFIG", "DOCKER_CONFIG"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "socket", Name: "socket",
Aliases: []string{"s"},
Usage: "Path to the docker socket file", Usage: "Path to the docker socket file",
Value: defaultSocket, Value: defaultSocket,
Destination: &options.socket, Destination: &options.Socket,
EnvVars: []string{"DOCKER_SOCKET"}, EnvVars: []string{"RUNTIME_SOCKET", "DOCKER_SOCKET"},
},
// The flags below are only used by the 'setup' command.
&cli.StringFlag{
Name: "runtime-name",
Aliases: []string{"r"},
Usage: "Specify the name of the `nvidia` runtime. If set-as-default is selected, the runtime is used as the default runtime.",
Value: defaultRuntimeName,
Destination: &options.runtimeName,
EnvVars: []string{"DOCKER_RUNTIME_NAME"},
},
&cli.BoolFlag{
Name: "set-as-default",
Aliases: []string{"d"},
Usage: "Set the `nvidia` runtime as the default runtime. If --runtime-name is specified as `nvidia-experimental` the experimental runtime is set as the default runtime instead",
Value: defaultSetAsDefault,
Destination: &options.setAsDefault,
EnvVars: []string{"DOCKER_SET_AS_DEFAULT"},
Hidden: true,
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "restart-mode", Name: "restart-mode",
Usage: "Specify how docker should be restarted; If 'none' is selected it will not be restarted [signal | none]", Usage: "Specify how docker should be restarted; If 'none' is selected it will not be restarted [signal | systemd | none ]",
Value: defaultRestartMode, Value: defaultRestartMode,
Destination: &options.restartMode, Destination: &options.RestartMode,
EnvVars: []string{"DOCKER_RESTART_MODE"}, EnvVars: []string{"RUNTIME_RESTART_MODE", "DOCKER_RESTART_MODE"},
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting docker using systemd",
Value: defaultHostRootMount,
Destination: &options.HostRootMount,
EnvVars: []string{"HOST_ROOT_MOUNT"},
// Restart using systemd is currently not supported.
// We hide this option for the time being.
Hidden: true,
},
&cli.StringFlag{
Name: "runtime-name",
Aliases: []string{"nvidia-runtime-name", "runtime-class"},
Usage: "Specify the name of the `nvidia` runtime. If set-as-default is selected, the runtime is used as the default runtime.",
Value: defaultRuntimeName,
Destination: &options.RuntimeName,
EnvVars: []string{"NVIDIA_RUNTIME_NAME", "DOCKER_RUNTIME_NAME"},
},
&cli.StringFlag{
Name: "nvidia-runtime-dir",
Aliases: []string{"runtime-dir"},
Usage: "The path where the nvidia-container-runtime binaries are located. If this is not specified, the first argument will be used instead",
Destination: &options.RuntimeDir,
EnvVars: []string{"NVIDIA_RUNTIME_DIR"},
},
&cli.BoolFlag{
Name: "set-as-default",
Usage: "Set the `nvidia` runtime as the default runtime. If --runtime-name is specified as `nvidia-experimental` the experimental runtime is set as the default runtime instead",
Value: defaultSetAsDefault,
Destination: &options.SetAsDefault,
EnvVars: []string{"NVIDIA_RUNTIME_SET_AS_DEFAULT", "DOCKER_SET_AS_DEFAULT"},
Hidden: true,
}, },
} }
@@ -165,28 +165,16 @@ func main() {
func Setup(c *cli.Context, o *options) error { func Setup(c *cli.Context, o *options) error {
log.Infof("Starting 'setup' for %v", c.App.Name) log.Infof("Starting 'setup' for %v", c.App.Name)
runtimeDir, err := ParseArgs(c)
if err != nil {
return fmt.Errorf("unable to parse args: %v", err)
}
o.runtimeDir = runtimeDir
cfg, err := docker.New( cfg, err := docker.New(
docker.WithPath(o.config), docker.WithPath(o.Config),
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to load config: %v", err) return fmt.Errorf("unable to load config: %v", err)
} }
err = UpdateConfig(cfg, o) err = o.Configure(cfg)
if err != nil { if err != nil {
return fmt.Errorf("unable to update config: %v", err) return fmt.Errorf("unable to configure docker: %v", err)
}
log.Infof("Flushing docker config to %v", o.config)
_, err = cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
} }
err = RestartDocker(o) err = RestartDocker(o)
@@ -203,30 +191,16 @@ func Setup(c *cli.Context, o *options) error {
func Cleanup(c *cli.Context, o *options) error { func Cleanup(c *cli.Context, o *options) error {
log.Infof("Starting 'cleanup' for %v", c.App.Name) log.Infof("Starting 'cleanup' for %v", c.App.Name)
_, err := ParseArgs(c)
if err != nil {
return fmt.Errorf("unable to parse args: %v", err)
}
cfg, err := docker.New( cfg, err := docker.New(
docker.WithPath(o.config), docker.WithPath(o.Config),
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to load config: %v", err) return fmt.Errorf("unable to load config: %v", err)
} }
err = RevertConfig(cfg, o) err = o.Unconfigure(cfg)
if err != nil { if err != nil {
return fmt.Errorf("unable to update config: %v", err) return fmt.Errorf("unable to unconfigure docker: %v", err)
}
log.Infof("Flushing docker config to %v", o.config)
n, err := cfg.Save(o.config)
if err != nil {
return fmt.Errorf("unable to flush config: %v", err)
}
if n == 0 {
log.Infof("Config file is empty, removed")
} }
err = RestartDocker(o) err = RestartDocker(o)
@@ -239,69 +213,9 @@ func Cleanup(c *cli.Context, o *options) error {
return nil return nil
} }
// ParseArgs parses the command line arguments to the CLI
func ParseArgs(c *cli.Context) (string, error) {
args := c.Args()
log.Infof("Parsing arguments: %v", args.Slice())
if args.Len() != 1 {
return "", fmt.Errorf("incorrect number of arguments")
}
runtimeDir := args.Get(0)
log.Infof("Successfully parsed arguments")
return runtimeDir, nil
}
// UpdateConfig updates the docker config to include the nvidia runtimes
func UpdateConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeName),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for name, runtime := range runtimes {
err := cfg.AddRuntime(name, runtime.Path, runtime.SetAsDefault)
if err != nil {
return fmt.Errorf("failed to update runtime %q: %v", name, err)
}
}
return nil
}
// RevertConfig reverts the docker config to remove the nvidia runtime
func RevertConfig(cfg engine.Interface, o *options) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.runtimeName),
operator.WithSetAsDefault(o.setAsDefault),
operator.WithRoot(o.runtimeDir),
)
for name := range runtimes {
err := cfg.RemoveRuntime(name)
if err != nil {
return fmt.Errorf("failed to remove runtime %q: %v", name, err)
}
}
return nil
}
// RestartDocker restarts docker depending on the value of restartModeFlag // RestartDocker restarts docker depending on the value of restartModeFlag
func RestartDocker(o *options) error { func RestartDocker(o *options) error {
switch o.restartMode { return o.Restart("docker", SignalDocker)
case restartModeNone:
log.Warnf("Skipping sending signal to docker due to --restart-mode=%v", o.restartMode)
case restartModeSignal:
err := SignalDocker(o.socket)
if err != nil {
return fmt.Errorf("unable to signal docker: %v", err)
}
default:
return fmt.Errorf("invalid restart mode specified: %v", o.restartMode)
}
return nil
} }
// SignalDocker sends a SIGHUP signal to docker daemon // SignalDocker sends a SIGHUP signal to docker daemon

View File

@@ -21,6 +21,7 @@ import (
"testing" "testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/docker" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/engine/docker"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -56,14 +57,16 @@ func TestUpdateConfigDefaultRuntime(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
o := &options{ o := &options{
setAsDefault: tc.setAsDefault, Options: container.Options{
runtimeName: tc.runtimeName, RuntimeName: tc.runtimeName,
runtimeDir: runtimeDir, RuntimeDir: runtimeDir,
SetAsDefault: tc.setAsDefault,
},
} }
config := docker.Config(map[string]interface{}{}) config := docker.Config(map[string]interface{}{})
err := UpdateConfig(&config, o) err := o.UpdateConfig(&config)
require.NoError(t, err, "%d: %v", i, tc) require.NoError(t, err, "%d: %v", i, tc)
defaultRuntimeName := config["default-runtime"] defaultRuntimeName := config["default-runtime"]
@@ -314,13 +317,15 @@ func TestUpdateConfig(t *testing.T) {
} }
for i, tc := range testCases { for i, tc := range testCases {
options := &options{ o := &options{
setAsDefault: tc.setAsDefault, Options: container.Options{
runtimeName: tc.runtimeName, RuntimeName: tc.runtimeName,
runtimeDir: runtimeDir, RuntimeDir: runtimeDir,
SetAsDefault: tc.setAsDefault,
},
} }
err := UpdateConfig(&tc.config, options) err := o.UpdateConfig(&tc.config)
require.NoError(t, err, "%d: %v", i, tc) require.NoError(t, err, "%d: %v", i, tc)
configContent, err := json.MarshalIndent(tc.config, "", " ") configContent, err := json.MarshalIndent(tc.config, "", " ")
@@ -457,7 +462,8 @@ func TestRevertConfig(t *testing.T) {
} }
for i, tc := range testCases { for i, tc := range testCases {
err := RevertConfig(&tc.config, &options{}) o := &options{}
err := o.RevertConfig(&tc.config)
require.NoError(t, err, "%d: %v", i, tc) require.NoError(t, err, "%d: %v", i, tc)

View File

@@ -23,7 +23,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system" "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
@@ -305,7 +305,7 @@ func Install(cli *cli.Context, opts *options) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA container CLI: %v", err)) log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA container CLI: %v", err))
} }
_, err = installRuntimeHook(opts.toolkitRoot, toolkitConfigPath) nvidiaContainerRuntimeHookPath, err := installRuntimeHook(opts.toolkitRoot, toolkitConfigPath)
if err != nil && !opts.ignoreErrors { if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error installing NVIDIA container runtime hook: %v", err) return fmt.Errorf("error installing NVIDIA container runtime hook: %v", err)
} else if err != nil { } else if err != nil {
@@ -319,7 +319,7 @@ func Install(cli *cli.Context, opts *options) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA Container Toolkit CLI: %v", err)) log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA Container Toolkit CLI: %v", err))
} }
err = installToolkitConfig(cli, toolkitConfigPath, nvidiaContainerCliExecutable, nvidiaCTKPath, opts) err = installToolkitConfig(cli, toolkitConfigPath, nvidiaContainerCliExecutable, nvidiaCTKPath, nvidiaContainerRuntimeHookPath, opts)
if err != nil && !opts.ignoreErrors { if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error installing NVIDIA container toolkit config: %v", err) return fmt.Errorf("error installing NVIDIA container toolkit config: %v", err)
} else if err != nil { } else if err != nil {
@@ -379,7 +379,7 @@ func installLibrary(libName string, toolkitRoot string) error {
// installToolkitConfig installs the config file for the NVIDIA container toolkit ensuring // installToolkitConfig installs the config file for the NVIDIA container toolkit ensuring
// that the settings are updated to match the desired install and nvidia driver directories. // that the settings are updated to match the desired install and nvidia driver directories.
func installToolkitConfig(c *cli.Context, toolkitConfigPath string, nvidiaContainerCliExecutablePath string, nvidiaCTKPath string, opts *options) error { func installToolkitConfig(c *cli.Context, toolkitConfigPath string, nvidiaContainerCliExecutablePath string, nvidiaCTKPath string, nvidaContainerRuntimeHookPath string, opts *options) error {
log.Infof("Installing NVIDIA container toolkit config '%v'", toolkitConfigPath) log.Infof("Installing NVIDIA container toolkit config '%v'", toolkitConfigPath)
config, err := loadConfig(nvidiaContainerToolkitConfigSource) config, err := loadConfig(nvidiaContainerToolkitConfigSource)
@@ -410,6 +410,7 @@ func installToolkitConfig(c *cli.Context, toolkitConfigPath string, nvidiaContai
// Set nvidia-ctk options // Set nvidia-ctk options
"nvidia-ctk.path": nvidiaCTKPath, "nvidia-ctk.path": nvidiaCTKPath,
// Set the nvidia-container-runtime-hook options // Set the nvidia-container-runtime-hook options
"nvidia-container-runtime-hook.path": nvidaContainerRuntimeHookPath,
"nvidia-container-runtime-hook.skip-mode-detection": opts.ContainerRuntimeHookSkipModeDetection, "nvidia-container-runtime-hook.skip-mode-detection": opts.ContainerRuntimeHookSkipModeDetection,
} }
for key, value := range configValues { for key, value := range configValues {
@@ -682,11 +683,13 @@ func generateCDISpec(opts *options, nvidiaCTKPath string) error {
} }
log.Infof("Creating control device nodes at %v", opts.DriverRootCtrPath) log.Infof("Creating control device nodes at %v", opts.DriverRootCtrPath)
s, err := system.New() devices, err := nvdevices.New(
nvdevices.WithDevRoot(opts.DriverRootCtrPath),
)
if err != nil { if err != nil {
return fmt.Errorf("failed to create library: %v", err) return fmt.Errorf("failed to create library: %v", err)
} }
if err := s.CreateNVIDIAControlDeviceNodesAt(opts.DriverRootCtrPath); err != nil { if err := devices.CreateNVIDIAControlDevices(); err != nil {
return fmt.Errorf("failed to create control device nodes: %v", err) return fmt.Errorf("failed to create control device nodes: %v", err)
} }

View File

@@ -1,2 +1,2 @@
toml.test /toml.test
/toml-test /toml-test

View File

@@ -1 +0,0 @@
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0).

View File

@@ -1,6 +1,5 @@
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml` reflection interface similar to Go's standard library `json` and `xml` packages.
packages.
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0). Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0).
@@ -10,7 +9,7 @@ See the [releases page](https://github.com/BurntSushi/toml/releases) for a
changelog; this information is also in the git tag annotations (e.g. `git show changelog; this information is also in the git tag annotations (e.g. `git show
v0.4.0`). v0.4.0`).
This library requires Go 1.13 or newer; install it with: This library requires Go 1.13 or newer; add it to your go.mod with:
% go get github.com/BurntSushi/toml@latest % go get github.com/BurntSushi/toml@latest
@@ -19,16 +18,7 @@ It also comes with a TOML validator CLI tool:
% go install github.com/BurntSushi/toml/cmd/tomlv@latest % go install github.com/BurntSushi/toml/cmd/tomlv@latest
% tomlv some-toml-file.toml % tomlv some-toml-file.toml
### Testing
This package passes all tests in [toml-test] for both the decoder and the
encoder.
[toml-test]: https://github.com/BurntSushi/toml-test
### Examples ### Examples
This package works similar to how the Go standard library handles XML and JSON.
Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys and For the simplest example, consider some TOML file as just a list of keys and
values: values:
@@ -40,7 +30,7 @@ Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z DOB = 1987-07-05T05:45:00Z
``` ```
Which could be defined in Go as: Which can be decoded with:
```go ```go
type Config struct { type Config struct {
@@ -48,20 +38,15 @@ type Config struct {
Cats []string Cats []string
Pi float64 Pi float64
Perfection []int Perfection []int
DOB time.Time // requires `import time` DOB time.Time
} }
```
And then decoded with:
```go
var conf Config var conf Config
err := toml.Decode(tomlData, &conf) _, err := toml.Decode(tomlData, &conf)
// handle error
``` ```
You can also use struct tags if your struct field name doesn't map to a TOML You can also use struct tags if your struct field name doesn't map to a TOML key
key value directly: value directly:
```toml ```toml
some_key_NAME = "wat" some_key_NAME = "wat"
@@ -73,139 +58,63 @@ type TOML struct {
} }
``` ```
Beware that like other most other decoders **only exported fields** are Beware that like other decoders **only exported fields** are considered when
considered when encoding and decoding; private fields are silently ignored. encoding and decoding; private fields are silently ignored.
### Using the `Marshaler` and `encoding.TextUnmarshaler` interfaces ### Using the `Marshaler` and `encoding.TextUnmarshaler` interfaces
Here's an example that automatically parses duration strings into Here's an example that automatically parses values in a `mail.Address`:
`time.Duration` values:
```toml ```toml
[[song]] contacts = [
name = "Thunder Road" "Donald Duck <donald@duckburg.com>",
duration = "4m49s" "Scrooge McDuck <scrooge@duckburg.com>",
]
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
``` ```
Which can be decoded with: Can be decoded with:
```go ```go
type song struct { // Create address type which satisfies the encoding.TextUnmarshaler interface.
Name string type address struct {
Duration duration *mail.Address
}
type songs struct {
Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
log.Fatal(err)
} }
for _, s := range favorites.Song { func (a *address) UnmarshalText(text []byte) error {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```
And you'll also need a `duration` type that satisfies the
`encoding.TextUnmarshaler` interface:
```go
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error var err error
d.Duration, err = time.ParseDuration(string(text)) a.Address, err = mail.ParseAddress(string(text))
return err return err
} }
// Decode it.
func decode() {
blob := `
contacts = [
"Donald Duck <donald@duckburg.com>",
"Scrooge McDuck <scrooge@duckburg.com>",
]
`
var contacts struct {
Contacts []address
}
_, err := toml.Decode(blob, &contacts)
if err != nil {
log.Fatal(err)
}
for _, c := range contacts.Contacts {
fmt.Printf("%#v\n", c.Address)
}
// Output:
// &mail.Address{Name:"Donald Duck", Address:"donald@duckburg.com"}
// &mail.Address{Name:"Scrooge McDuck", Address:"scrooge@duckburg.com"}
}
``` ```
To target TOML specifically you can implement `UnmarshalTOML` TOML interface in To target TOML specifically you can implement `UnmarshalTOML` TOML interface in
a similar way. a similar way.
### More complex usage ### More complex usage
Here's an example of how to load the example from the official spec page: See the [`_example/`](/_example) directory for a more complex example.
```toml
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]
```
And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
```
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_example/example.{go,toml}`.

View File

@@ -1,14 +1,18 @@
package toml package toml
import ( import (
"bytes"
"encoding" "encoding"
"encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math" "math"
"os" "os"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time"
) )
// Unmarshaler is the interface implemented by objects that can unmarshal a // Unmarshaler is the interface implemented by objects that can unmarshal a
@@ -17,16 +21,35 @@ type Unmarshaler interface {
UnmarshalTOML(interface{}) error UnmarshalTOML(interface{}) error
} }
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`. // Unmarshal decodes the contents of data in TOML format into a pointer v.
func Unmarshal(p []byte, v interface{}) error { //
_, err := Decode(string(p), v) // See [Decoder] for a description of the decoding process.
func Unmarshal(data []byte, v interface{}) error {
_, err := NewDecoder(bytes.NewReader(data)).Decode(v)
return err return err
} }
// Decode the TOML data in to the pointer v.
//
// See [Decoder] for a description of the decoding process.
func Decode(data string, v interface{}) (MetaData, error) {
return NewDecoder(strings.NewReader(data)).Decode(v)
}
// DecodeFile reads the contents of a file and decodes it with [Decode].
func DecodeFile(path string, v interface{}) (MetaData, error) {
fp, err := os.Open(path)
if err != nil {
return MetaData{}, err
}
defer fp.Close()
return NewDecoder(fp).Decode(v)
}
// Primitive is a TOML value that hasn't been decoded into a Go value. // Primitive is a TOML value that hasn't been decoded into a Go value.
// //
// This type can be used for any value, which will cause decoding to be delayed. // This type can be used for any value, which will cause decoding to be delayed.
// You can use the PrimitiveDecode() function to "manually" decode these values. // You can use [PrimitiveDecode] to "manually" decode these values.
// //
// NOTE: The underlying representation of a `Primitive` value is subject to // NOTE: The underlying representation of a `Primitive` value is subject to
// change. Do not rely on it. // change. Do not rely on it.
@@ -42,36 +65,22 @@ type Primitive struct {
// The significand precision for float32 and float64 is 24 and 53 bits; this is // The significand precision for float32 and float64 is 24 and 53 bits; this is
// the range a natural number can be stored in a float without loss of data. // the range a natural number can be stored in a float without loss of data.
const ( const (
maxSafeFloat32Int = 16777215 // 2^24-1 maxSafeFloat32Int = 16777215 // 2^24-1
maxSafeFloat64Int = 9007199254740991 // 2^53-1 maxSafeFloat64Int = int64(9007199254740991) // 2^53-1
) )
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// Decoder decodes TOML data. // Decoder decodes TOML data.
// //
// TOML tables correspond to Go structs or maps (dealer's choice they can be // TOML tables correspond to Go structs or maps; they can be used
// used interchangeably). // interchangeably, but structs offer better type safety.
// //
// TOML table arrays correspond to either a slice of structs or a slice of maps. // TOML table arrays correspond to either a slice of structs or a slice of maps.
// //
// TOML datetimes correspond to Go time.Time values. Local datetimes are parsed // TOML datetimes correspond to [time.Time]. Local datetimes are parsed in the
// in the local timezone. // local timezone.
//
// [time.Duration] types are treated as nanoseconds if the TOML value is an
// integer, or they're parsed with time.ParseDuration() if they're strings.
// //
// All other TOML types (float, string, int, bool and array) correspond to the // All other TOML types (float, string, int, bool and array) correspond to the
// obvious Go types. // obvious Go types.
@@ -80,9 +89,9 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
// interface, in which case any primitive TOML value (floats, strings, integers, // interface, in which case any primitive TOML value (floats, strings, integers,
// booleans, datetimes) will be converted to a []byte and given to the value's // booleans, datetimes) will be converted to a []byte and given to the value's
// UnmarshalText method. See the Unmarshaler example for a demonstration with // UnmarshalText method. See the Unmarshaler example for a demonstration with
// time duration strings. // email addresses.
// //
// Key mapping // ### Key mapping
// //
// TOML keys can map to either keys in a Go map or field names in a Go struct. // TOML keys can map to either keys in a Go map or field names in a Go struct.
// The special `toml` struct tag can be used to map TOML keys to struct fields // The special `toml` struct tag can be used to map TOML keys to struct fields
@@ -109,6 +118,7 @@ func NewDecoder(r io.Reader) *Decoder {
var ( var (
unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem() unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
primitiveType = reflect.TypeOf((*Primitive)(nil)).Elem()
) )
// Decode TOML data in to the pointer `v`. // Decode TOML data in to the pointer `v`.
@@ -120,10 +130,10 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
s = "%v" s = "%v"
} }
return MetaData{}, e("cannot decode to non-pointer "+s, reflect.TypeOf(v)) return MetaData{}, fmt.Errorf("toml: cannot decode to non-pointer "+s, reflect.TypeOf(v))
} }
if rv.IsNil() { if rv.IsNil() {
return MetaData{}, e("cannot decode to nil value of %q", reflect.TypeOf(v)) return MetaData{}, fmt.Errorf("toml: cannot decode to nil value of %q", reflect.TypeOf(v))
} }
// Check if this is a supported type: struct, map, interface{}, or something // Check if this is a supported type: struct, map, interface{}, or something
@@ -133,7 +143,7 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map && if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
!(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) && !(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) &&
!rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) { !rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) {
return MetaData{}, e("cannot decode to type %s", rt) return MetaData{}, fmt.Errorf("toml: cannot decode to type %s", rt)
} }
// TODO: parser should read from io.Reader? Or at the very least, make it // TODO: parser should read from io.Reader? Or at the very least, make it
@@ -150,30 +160,29 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
md := MetaData{ md := MetaData{
mapping: p.mapping, mapping: p.mapping,
types: p.types, keyInfo: p.keyInfo,
keys: p.ordered, keys: p.ordered,
decoded: make(map[string]struct{}, len(p.ordered)), decoded: make(map[string]struct{}, len(p.ordered)),
context: nil, context: nil,
data: data,
} }
return md, md.unify(p.mapping, rv) return md, md.unify(p.mapping, rv)
} }
// Decode the TOML data in to the pointer v. // PrimitiveDecode is just like the other Decode* functions, except it decodes a
// TOML value that has already been parsed. Valid primitive values can *only* be
// obtained from values filled by the decoder functions, including this method.
// (i.e., v may contain more [Primitive] values.)
// //
// See the documentation on Decoder for a description of the decoding process. // Meta data for primitive values is included in the meta data returned by the
func Decode(data string, v interface{}) (MetaData, error) { // Decode* functions with one exception: keys returned by the Undecoded method
return NewDecoder(strings.NewReader(data)).Decode(v) // will only reflect keys that were decoded. Namely, any keys hidden behind a
} // Primitive will be considered undecoded. Executing this method will update the
// undecoded keys in the meta data. (See the example.)
// DecodeFile is just like Decode, except it will automatically read the func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
// contents of the file at path and decode it for you. md.context = primValue.context
func DecodeFile(path string, v interface{}) (MetaData, error) { defer func() { md.context = nil }()
fp, err := os.Open(path) return md.unify(primValue.undecoded, rvalue(v))
if err != nil {
return MetaData{}, err
}
defer fp.Close()
return NewDecoder(fp).Decode(v)
} }
// unify performs a sort of type unification based on the structure of `rv`, // unify performs a sort of type unification based on the structure of `rv`,
@@ -184,7 +193,7 @@ func DecodeFile(path string, v interface{}) (MetaData, error) {
func (md *MetaData) unify(data interface{}, rv reflect.Value) error { func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value. // Special case. Look for a `Primitive` value.
// TODO: #76 would make this superfluous after implemented. // TODO: #76 would make this superfluous after implemented.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() { if rv.Type() == primitiveType {
// Save the undecoded data and the key context into the primitive // Save the undecoded data and the key context into the primitive
// value. // value.
context := make(Key, len(md.context)) context := make(Key, len(md.context))
@@ -196,17 +205,14 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
return nil return nil
} }
// Special case. Unmarshaler Interface support. rvi := rv.Interface()
if rv.CanAddr() { if v, ok := rvi.(Unmarshaler); ok {
if v, ok := rv.Addr().Interface().(Unmarshaler); ok { return v.UnmarshalTOML(data)
return v.UnmarshalTOML(data)
}
} }
if v, ok := rvi.(encoding.TextUnmarshaler); ok {
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
return md.unifyText(data, v) return md.unifyText(data, v)
} }
// TODO: // TODO:
// The behavior here is incorrect whenever a Go type satisfies the // The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML hash or // encoding.TextUnmarshaler interface but also corresponds to a TOML hash or
@@ -217,7 +223,6 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
k := rv.Kind() k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 { if k >= reflect.Int && k <= reflect.Uint64 {
return md.unifyInt(data, rv) return md.unifyInt(data, rv)
} }
@@ -243,15 +248,14 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
case reflect.Bool: case reflect.Bool:
return md.unifyBool(data, rv) return md.unifyBool(data, rv)
case reflect.Interface: case reflect.Interface:
// we only support empty interfaces. if rv.NumMethod() > 0 { // Only support empty interfaces are supported.
if rv.NumMethod() > 0 { return md.e("unsupported type %s", rv.Type())
return e("unsupported type %s", rv.Type())
} }
return md.unifyAnything(data, rv) return md.unifyAnything(data, rv)
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return md.unifyFloat64(data, rv) return md.unifyFloat64(data, rv)
} }
return e("unsupported type %s", rv.Kind()) return md.e("unsupported type %s", rv.Kind())
} }
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
@@ -260,7 +264,7 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
if mapping == nil { if mapping == nil {
return nil return nil
} }
return e("type mismatch for %s: expected table but found %T", return md.e("type mismatch for %s: expected table but found %T",
rv.Type().String(), mapping) rv.Type().String(), mapping)
} }
@@ -286,13 +290,14 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
if isUnifiable(subv) { if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = struct{}{} md.decoded[md.context.add(key).String()] = struct{}{}
md.context = append(md.context, key) md.context = append(md.context, key)
err := md.unify(datum, subv) err := md.unify(datum, subv)
if err != nil { if err != nil {
return err return err
} }
md.context = md.context[0 : len(md.context)-1] md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" { } else if f.name != "" {
return e("cannot write unexported field %s.%s", rv.Type().String(), f.name) return md.e("cannot write unexported field %s.%s", rv.Type().String(), f.name)
} }
} }
} }
@@ -300,10 +305,10 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
} }
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
if k := rv.Type().Key().Kind(); k != reflect.String { keyType := rv.Type().Key().Kind()
return fmt.Errorf( if keyType != reflect.String && keyType != reflect.Interface {
"toml: cannot decode to a map with non-string key type (%s in %q)", return fmt.Errorf("toml: cannot decode to a map with non-string key type (%s in %q)",
k, rv.Type()) keyType, rv.Type())
} }
tmap, ok := mapping.(map[string]interface{}) tmap, ok := mapping.(map[string]interface{})
@@ -321,13 +326,22 @@ func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
md.context = append(md.context, k) md.context = append(md.context, k)
rvval := reflect.Indirect(reflect.New(rv.Type().Elem())) rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
err := md.unify(v, indirect(rvval))
if err != nil {
return err return err
} }
md.context = md.context[0 : len(md.context)-1] md.context = md.context[0 : len(md.context)-1]
rvkey := indirect(reflect.New(rv.Type().Key())) rvkey := indirect(reflect.New(rv.Type().Key()))
rvkey.SetString(k)
switch keyType {
case reflect.Interface:
rvkey.Set(reflect.ValueOf(k))
case reflect.String:
rvkey.SetString(k)
}
rv.SetMapIndex(rvkey, rvval) rv.SetMapIndex(rvkey, rvval)
} }
return nil return nil
@@ -342,7 +356,7 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
return md.badtype("slice", data) return md.badtype("slice", data)
} }
if l := datav.Len(); l != rv.Len() { if l := datav.Len(); l != rv.Len() {
return e("expected array length %d; got TOML array of length %d", rv.Len(), l) return md.e("expected array length %d; got TOML array of length %d", rv.Len(), l)
} }
return md.unifySliceArray(datav, rv) return md.unifySliceArray(datav, rv)
} }
@@ -375,6 +389,18 @@ func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
} }
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error { func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
_, ok := rv.Interface().(json.Number)
if ok {
if i, ok := data.(int64); ok {
rv.SetString(strconv.FormatInt(i, 10))
} else if f, ok := data.(float64); ok {
rv.SetString(strconv.FormatFloat(f, 'f', -1, 64))
} else {
return md.badtype("string", data)
}
return nil
}
if s, ok := data.(string); ok { if s, ok := data.(string); ok {
rv.SetString(s) rv.SetString(s)
return nil return nil
@@ -383,11 +409,13 @@ func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
} }
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error { func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
rvk := rv.Kind()
if num, ok := data.(float64); ok { if num, ok := data.(float64); ok {
switch rv.Kind() { switch rvk {
case reflect.Float32: case reflect.Float32:
if num < -math.MaxFloat32 || num > math.MaxFloat32 { if num < -math.MaxFloat32 || num > math.MaxFloat32 {
return e("value %f is out of range for float32", num) return md.parseErr(errParseRange{i: num, size: rvk.String()})
} }
fallthrough fallthrough
case reflect.Float64: case reflect.Float64:
@@ -399,20 +427,11 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
} }
if num, ok := data.(int64); ok { if num, ok := data.(int64); ok {
switch rv.Kind() { if (rvk == reflect.Float32 && (num < -maxSafeFloat32Int || num > maxSafeFloat32Int)) ||
case reflect.Float32: (rvk == reflect.Float64 && (num < -maxSafeFloat64Int || num > maxSafeFloat64Int)) {
if num < -maxSafeFloat32Int || num > maxSafeFloat32Int { return md.parseErr(errParseRange{i: num, size: rvk.String()})
return e("value %d is out of range for float32", num)
}
fallthrough
case reflect.Float64:
if num < -maxSafeFloat64Int || num > maxSafeFloat64Int {
return e("value %d is out of range for float64", num)
}
rv.SetFloat(float64(num))
default:
panic("bug")
} }
rv.SetFloat(float64(num))
return nil return nil
} }
@@ -420,50 +439,46 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
} }
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error { func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok { _, ok := rv.Interface().(time.Duration)
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 { if ok {
switch rv.Kind() { // Parse as string duration, and fall back to regular integer parsing
case reflect.Int, reflect.Int64: // (as nanosecond) if this is not a string.
// No bounds checking necessary. if s, ok := data.(string); ok {
case reflect.Int8: dur, err := time.ParseDuration(s)
if num < math.MinInt8 || num > math.MaxInt8 { if err != nil {
return e("value %d is out of range for int8", num) return md.parseErr(errParseDuration{s})
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("value %d is out of range for int16", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("value %d is out of range for int32", num)
}
} }
rv.SetInt(num) rv.SetInt(int64(dur))
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 { return nil
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("value %d is out of range for uint8", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("value %d is out of range for uint16", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("value %d is out of range for uint32", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
} }
return nil
} }
return md.badtype("integer", data)
num, ok := data.(int64)
if !ok {
return md.badtype("integer", data)
}
rvk := rv.Kind()
switch {
case rvk >= reflect.Int && rvk <= reflect.Int64:
if (rvk == reflect.Int8 && (num < math.MinInt8 || num > math.MaxInt8)) ||
(rvk == reflect.Int16 && (num < math.MinInt16 || num > math.MaxInt16)) ||
(rvk == reflect.Int32 && (num < math.MinInt32 || num > math.MaxInt32)) {
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
rv.SetInt(num)
case rvk >= reflect.Uint && rvk <= reflect.Uint64:
unum := uint64(num)
if rvk == reflect.Uint8 && (num < 0 || unum > math.MaxUint8) ||
rvk == reflect.Uint16 && (num < 0 || unum > math.MaxUint16) ||
rvk == reflect.Uint32 && (num < 0 || unum > math.MaxUint32) {
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
rv.SetUint(unum)
default:
panic("unreachable")
}
return nil
} }
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error { func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
@@ -488,7 +503,7 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro
return err return err
} }
s = string(text) s = string(text)
case TextMarshaler: case encoding.TextMarshaler:
text, err := sdata.MarshalText() text, err := sdata.MarshalText()
if err != nil { if err != nil {
return err return err
@@ -514,7 +529,30 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro
} }
func (md *MetaData) badtype(dst string, data interface{}) error { func (md *MetaData) badtype(dst string, data interface{}) error {
return e("incompatible types: TOML key %q has type %T; destination has type %s", md.context, data, dst) return md.e("incompatible types: TOML value has type %T; destination has type %s", data, dst)
}
func (md *MetaData) parseErr(err error) error {
k := md.context.String()
return ParseError{
LastKey: k,
Position: md.keyInfo[k].pos,
Line: md.keyInfo[k].pos.Line,
err: err,
input: string(md.data),
}
}
func (md *MetaData) e(format string, args ...interface{}) error {
f := "toml: "
if len(md.context) > 0 {
f = fmt.Sprintf("toml: (last key %q): ", md.context)
p := md.keyInfo[md.context.String()].pos
if p.Line > 0 {
f = fmt.Sprintf("toml: line %d (last key %q): ", p.Line, md.context)
}
}
return fmt.Errorf(f+format, args...)
} }
// rvalue returns a reflect.Value of `v`. All pointers are resolved. // rvalue returns a reflect.Value of `v`. All pointers are resolved.
@@ -533,7 +571,11 @@ func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr { if v.Kind() != reflect.Ptr {
if v.CanSet() { if v.CanSet() {
pv := v.Addr() pv := v.Addr()
if _, ok := pv.Interface().(encoding.TextUnmarshaler); ok { pvi := pv.Interface()
if _, ok := pvi.(encoding.TextUnmarshaler); ok {
return pv
}
if _, ok := pvi.(Unmarshaler); ok {
return pv return pv
} }
} }
@@ -549,12 +591,12 @@ func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() { if rv.CanSet() {
return true return true
} }
if _, ok := rv.Interface().(encoding.TextUnmarshaler); ok { rvi := rv.Interface()
if _, ok := rvi.(encoding.TextUnmarshaler); ok {
return true
}
if _, ok := rvi.(Unmarshaler); ok {
return true return true
} }
return false return false
} }
func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}

View File

@@ -7,8 +7,8 @@ import (
"io/fs" "io/fs"
) )
// DecodeFS is just like Decode, except it will automatically read the contents // DecodeFS reads the contents of a file from [fs.FS] and decodes it with
// of the file at `path` from a fs.FS instance. // [Decode].
func DecodeFS(fsys fs.FS, path string, v interface{}) (MetaData, error) { func DecodeFS(fsys fs.FS, path string, v interface{}) (MetaData, error) {
fp, err := fsys.Open(path) fp, err := fsys.Open(path)
if err != nil { if err != nil {

View File

@@ -1,13 +1,11 @@
/* // Package toml implements decoding and encoding of TOML files.
Package toml implements decoding and encoding of TOML files. //
// This package supports TOML v1.0.0, as specified at https://toml.io
This package supports TOML v1.0.0, as listed on https://toml.io //
// There is also support for delaying decoding with the Primitive type, and
There is also support for delaying decoding with the Primitive type, and // querying the set of keys in a TOML document with the MetaData type.
querying the set of keys in a TOML document with the MetaData type. //
// The github.com/BurntSushi/toml/cmd/tomlv package implements a TOML validator,
The github.com/BurntSushi/toml/cmd/tomlv package implements a TOML validator, // and can be used to verify if TOML document is valid. It can also be used to
and can be used to verify if TOML document is valid. It can also be used to // print the type of each key.
print the type of each key.
*/
package toml package toml

View File

@@ -3,6 +3,7 @@ package toml
import ( import (
"bufio" "bufio"
"encoding" "encoding"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -63,6 +64,12 @@ var dblQuotedReplacer = strings.NewReplacer(
"\x7f", `\u007f`, "\x7f", `\u007f`,
) )
var (
marshalToml = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalText = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
)
// Marshaler is the interface implemented by types that can marshal themselves // Marshaler is the interface implemented by types that can marshal themselves
// into valid TOML. // into valid TOML.
type Marshaler interface { type Marshaler interface {
@@ -72,9 +79,12 @@ type Marshaler interface {
// Encoder encodes a Go to a TOML document. // Encoder encodes a Go to a TOML document.
// //
// The mapping between Go values and TOML values should be precisely the same as // The mapping between Go values and TOML values should be precisely the same as
// for the Decode* functions. // for [Decode].
// //
// The toml.Marshaler and encoder.TextMarshaler interfaces are supported to // time.Time is encoded as a RFC 3339 string, and time.Duration as its string
// representation.
//
// The [Marshaler] and [encoding.TextMarshaler] interfaces are supported to
// encoding the value as custom TOML. // encoding the value as custom TOML.
// //
// If you want to write arbitrary binary data then you will need to use // If you want to write arbitrary binary data then you will need to use
@@ -85,6 +95,17 @@ type Marshaler interface {
// //
// Go maps will be sorted alphabetically by key for deterministic output. // Go maps will be sorted alphabetically by key for deterministic output.
// //
// The toml struct tag can be used to provide the key name; if omitted the
// struct field name will be used. If the "omitempty" option is present the
// following value will be skipped:
//
// - arrays, slices, maps, and string with len of 0
// - struct with all zero values
// - bool false
//
// If omitzero is given all int and float types with a value of 0 will be
// skipped.
//
// Encoding Go values without a corresponding TOML representation will return an // Encoding Go values without a corresponding TOML representation will return an
// error. Examples of this includes maps with non-string keys, slices with nil // error. Examples of this includes maps with non-string keys, slices with nil
// elements, embedded non-struct types, and nested slices containing maps or // elements, embedded non-struct types, and nested slices containing maps or
@@ -109,7 +130,7 @@ func NewEncoder(w io.Writer) *Encoder {
} }
} }
// Encode writes a TOML representation of the Go value to the Encoder's writer. // Encode writes a TOML representation of the Go value to the [Encoder]'s writer.
// //
// An error is returned if the value given cannot be encoded to a valid TOML // An error is returned if the value given cannot be encoded to a valid TOML
// document. // document.
@@ -136,18 +157,15 @@ func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
} }
func (enc *Encoder) encode(key Key, rv reflect.Value) { func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case: time needs to be in ISO8601 format. // If we can marshal the type to text, then we use that. This prevents the
// // encoder for handling these types as generic structs (or whatever the
// Special case: if we can marshal the type to text, then we used that. This // underlying type of a TextMarshaler is).
// prevents the encoder for handling these types as generic structs (or switch {
// whatever the underlying type of a TextMarshaler is). case isMarshaler(rv):
switch t := rv.Interface().(type) {
case time.Time, encoding.TextMarshaler, Marshaler:
enc.writeKeyValue(key, rv, false) enc.writeKeyValue(key, rv, false)
return return
// TODO: #76 would make this superfluous after implemented. case rv.Type() == primitiveType: // TODO: #76 would make this superfluous after implemented.
case Primitive: enc.encode(key, reflect.ValueOf(rv.Interface().(Primitive).undecoded))
enc.encode(key, reflect.ValueOf(t.undecoded))
return return
} }
@@ -212,18 +230,44 @@ func (enc *Encoder) eElement(rv reflect.Value) {
if err != nil { if err != nil {
encPanic(err) encPanic(err)
} }
enc.writeQuoted(string(s)) if s == nil {
encPanic(errors.New("MarshalTOML returned nil and no error"))
}
enc.w.Write(s)
return return
case encoding.TextMarshaler: case encoding.TextMarshaler:
s, err := v.MarshalText() s, err := v.MarshalText()
if err != nil { if err != nil {
encPanic(err) encPanic(err)
} }
if s == nil {
encPanic(errors.New("MarshalText returned nil and no error"))
}
enc.writeQuoted(string(s)) enc.writeQuoted(string(s))
return return
case time.Duration:
enc.writeQuoted(v.String())
return
case json.Number:
n, _ := rv.Interface().(json.Number)
if n == "" { /// Useful zero value.
enc.w.WriteByte('0')
return
} else if v, err := n.Int64(); err == nil {
enc.eElement(reflect.ValueOf(v))
return
} else if v, err := n.Float64(); err == nil {
enc.eElement(reflect.ValueOf(v))
return
}
encPanic(fmt.Errorf("unable to convert %q to int64 or float64", n))
} }
switch rv.Kind() { switch rv.Kind() {
case reflect.Ptr:
enc.eElement(rv.Elem())
return
case reflect.String: case reflect.String:
enc.writeQuoted(rv.String()) enc.writeQuoted(rv.String())
case reflect.Bool: case reflect.Bool:
@@ -259,7 +303,7 @@ func (enc *Encoder) eElement(rv reflect.Value) {
case reflect.Interface: case reflect.Interface:
enc.eElement(rv.Elem()) enc.eElement(rv.Elem())
default: default:
encPanic(fmt.Errorf("unexpected primitive type: %T", rv.Interface())) encPanic(fmt.Errorf("unexpected type: %T", rv.Interface()))
} }
} }
@@ -280,7 +324,7 @@ func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len() length := rv.Len()
enc.wf("[") enc.wf("[")
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
elem := rv.Index(i) elem := eindirect(rv.Index(i))
enc.eElement(elem) enc.eElement(elem)
if i != length-1 { if i != length-1 {
enc.wf(", ") enc.wf(", ")
@@ -294,7 +338,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
encPanic(errNoKey) encPanic(errNoKey)
} }
for i := 0; i < rv.Len(); i++ { for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i) trv := eindirect(rv.Index(i))
if isNil(trv) { if isNil(trv) {
continue continue
} }
@@ -319,7 +363,7 @@ func (enc *Encoder) eTable(key Key, rv reflect.Value) {
} }
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) { func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) {
switch rv := eindirect(rv); rv.Kind() { switch rv.Kind() {
case reflect.Map: case reflect.Map:
enc.eMap(key, rv, inline) enc.eMap(key, rv, inline)
case reflect.Struct: case reflect.Struct:
@@ -341,7 +385,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
var mapKeysDirect, mapKeysSub []string var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() { for _, mapKey := range rv.MapKeys() {
k := mapKey.String() k := mapKey.String()
if typeIsTable(tomlTypeOfGo(rv.MapIndex(mapKey))) { if typeIsTable(tomlTypeOfGo(eindirect(rv.MapIndex(mapKey)))) {
mapKeysSub = append(mapKeysSub, k) mapKeysSub = append(mapKeysSub, k)
} else { } else {
mapKeysDirect = append(mapKeysDirect, k) mapKeysDirect = append(mapKeysDirect, k)
@@ -351,7 +395,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
var writeMapKeys = func(mapKeys []string, trailC bool) { var writeMapKeys = func(mapKeys []string, trailC bool) {
sort.Strings(mapKeys) sort.Strings(mapKeys)
for i, mapKey := range mapKeys { for i, mapKey := range mapKeys {
val := rv.MapIndex(reflect.ValueOf(mapKey)) val := eindirect(rv.MapIndex(reflect.ValueOf(mapKey)))
if isNil(val) { if isNil(val) {
continue continue
} }
@@ -379,6 +423,13 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
const is32Bit = (32 << (^uint(0) >> 63)) == 32 const is32Bit = (32 << (^uint(0) >> 63)) == 32
func pointerTo(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
return pointerTo(t.Elem())
}
return t
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) { func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
// Write keys for fields directly under this key first, because if we write // Write keys for fields directly under this key first, because if we write
// a field that creates a new table then all keys under it will be in that // a field that creates a new table then all keys under it will be in that
@@ -395,31 +446,25 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
addFields = func(rt reflect.Type, rv reflect.Value, start []int) { addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ { for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i) f := rt.Field(i)
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields. isEmbed := f.Anonymous && pointerTo(f.Type).Kind() == reflect.Struct
if f.PkgPath != "" && !isEmbed { /// Skip unexported fields.
continue
}
opts := getOptions(f.Tag)
if opts.skip {
continue continue
} }
frv := rv.Field(i) frv := eindirect(rv.Field(i))
// Treat anonymous struct fields with tag names as though they are // Treat anonymous struct fields with tag names as though they are
// not anonymous, like encoding/json does. // not anonymous, like encoding/json does.
// //
// Non-struct anonymous fields use the normal encoding logic. // Non-struct anonymous fields use the normal encoding logic.
if f.Anonymous { if isEmbed {
t := f.Type if getOptions(f.Tag).name == "" && frv.Kind() == reflect.Struct {
switch t.Kind() { addFields(frv.Type(), frv, append(start, f.Index...))
case reflect.Struct: continue
if getOptions(f.Tag).name == "" {
addFields(t, frv, append(start, f.Index...))
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), append(start, f.Index...))
}
continue
}
} }
} }
@@ -445,7 +490,7 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
writeFields := func(fields [][]int) { writeFields := func(fields [][]int) {
for _, fieldIndex := range fields { for _, fieldIndex := range fields {
fieldType := rt.FieldByIndex(fieldIndex) fieldType := rt.FieldByIndex(fieldIndex)
fieldVal := rv.FieldByIndex(fieldIndex) fieldVal := eindirect(rv.FieldByIndex(fieldIndex))
if isNil(fieldVal) { /// Don't write anything for nil fields. if isNil(fieldVal) { /// Don't write anything for nil fields.
continue continue
@@ -459,7 +504,8 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
if opts.name != "" { if opts.name != "" {
keyName = opts.name keyName = opts.name
} }
if opts.omitempty && isEmpty(fieldVal) {
if opts.omitempty && enc.isEmpty(fieldVal) {
continue continue
} }
if opts.omitzero && isZero(fieldVal) { if opts.omitzero && isZero(fieldVal) {
@@ -498,6 +544,21 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() { if isNil(rv) || !rv.IsValid() {
return nil return nil
} }
if rv.Kind() == reflect.Struct {
if rv.Type() == timeType {
return tomlDatetime
}
if isMarshaler(rv) {
return tomlString
}
return tomlHash
}
if isMarshaler(rv) {
return tomlString
}
switch rv.Kind() { switch rv.Kind() {
case reflect.Bool: case reflect.Bool:
return tomlBool return tomlBool
@@ -509,7 +570,7 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return tomlFloat return tomlFloat
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) { if isTableArray(rv) {
return tomlArrayHash return tomlArrayHash
} }
return tomlArray return tomlArray
@@ -519,67 +580,35 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
return tomlString return tomlString
case reflect.Map: case reflect.Map:
return tomlHash return tomlHash
case reflect.Struct:
if _, ok := rv.Interface().(time.Time); ok {
return tomlDatetime
}
if isMarshaler(rv) {
return tomlString
}
return tomlHash
default: default:
if isMarshaler(rv) {
return tomlString
}
encPanic(errors.New("unsupported type: " + rv.Kind().String())) encPanic(errors.New("unsupported type: " + rv.Kind().String()))
panic("unreachable") panic("unreachable")
} }
} }
func isMarshaler(rv reflect.Value) bool { func isMarshaler(rv reflect.Value) bool {
switch rv.Interface().(type) { return rv.Type().Implements(marshalText) || rv.Type().Implements(marshalToml)
case encoding.TextMarshaler:
return true
case Marshaler:
return true
}
// Someone used a pointer receiver: we can make it work for pointer values.
if rv.CanAddr() {
if _, ok := rv.Addr().Interface().(encoding.TextMarshaler); ok {
return true
}
if _, ok := rv.Addr().Interface().(Marshaler); ok {
return true
}
}
return false
} }
// tomlArrayType returns the element type of a TOML array. The type returned // isTableArray reports if all entries in the array or slice are a table.
// may be nil if it cannot be determined (e.g., a nil slice or a zero length func isTableArray(arr reflect.Value) bool {
// slize). This function may also panic if it finds a type that cannot be if isNil(arr) || !arr.IsValid() || arr.Len() == 0 {
// expressed in TOML (such as nil elements, heterogeneous arrays or directly return false
// nested arrays of tables).
func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
} }
/// Don't allow nil. ret := true
rvlen := rv.Len() for i := 0; i < arr.Len(); i++ {
for i := 1; i < rvlen; i++ { tt := tomlTypeOfGo(eindirect(arr.Index(i)))
if tomlTypeOfGo(rv.Index(i)) == nil { // Don't allow nil.
if tt == nil {
encPanic(errArrayNilElement) encPanic(errArrayNilElement)
} }
}
firstType := tomlTypeOfGo(rv.Index(0)) if ret && !typeEqual(tomlHash, tt) {
if firstType == nil { ret = false
encPanic(errArrayNilElement) }
} }
return firstType return ret
} }
type tagOptions struct { type tagOptions struct {
@@ -620,10 +649,26 @@ func isZero(rv reflect.Value) bool {
return false return false
} }
func isEmpty(rv reflect.Value) bool { func (enc *Encoder) isEmpty(rv reflect.Value) bool {
switch rv.Kind() { switch rv.Kind() {
case reflect.Array, reflect.Slice, reflect.Map, reflect.String: case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
return rv.Len() == 0 return rv.Len() == 0
case reflect.Struct:
if rv.Type().Comparable() {
return reflect.Zero(rv.Type()).Interface() == rv.Interface()
}
// Need to also check if all the fields are empty, otherwise something
// like this with uncomparable types will always return true:
//
// type a struct{ field b }
// type b struct{ s []string }
// s := a{field: b{s: []string{"AAA"}}}
for i := 0; i < rv.NumField(); i++ {
if !enc.isEmpty(rv.Field(i)) {
return false
}
}
return true
case reflect.Bool: case reflect.Bool:
return !rv.Bool() return !rv.Bool()
} }
@@ -638,16 +683,15 @@ func (enc *Encoder) newline() {
// Write a key/value pair: // Write a key/value pair:
// //
// key = <any value> // key = <any value>
// //
// This is also used for "k = v" in inline tables; so something like this will // This is also used for "k = v" in inline tables; so something like this will
// be written in three calls: // be written in three calls:
// //
// ┌────────────────────┐ //───────────────────┐
// │ ┌───┐ ┌────┐│ // │ ┌───┐ ┌────┐│
// v v v v vv // v v v v vv
// key = {k = v, k2 = v2} // key = {k = 1, k2 = 2}
//
func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) { func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
if len(key) == 0 { if len(key) == 0 {
encPanic(errNoKey) encPanic(errNoKey)
@@ -675,13 +719,25 @@ func encPanic(err error) {
panic(tomlEncodeError{err}) panic(tomlEncodeError{err})
} }
// Resolve any level of pointers to the actual value (e.g. **string → string).
func eindirect(v reflect.Value) reflect.Value { func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() { if v.Kind() != reflect.Ptr && v.Kind() != reflect.Interface {
case reflect.Ptr, reflect.Interface: if isMarshaler(v) {
return eindirect(v.Elem()) return v
default: }
if v.CanAddr() { /// Special case for marshalers; see #358.
if pv := v.Addr(); isMarshaler(pv) {
return pv
}
}
return v return v
} }
if v.IsNil() {
return v
}
return eindirect(v.Elem())
} }
func isNil(rv reflect.Value) bool { func isNil(rv reflect.Value) bool {

View File

@@ -5,57 +5,60 @@ import (
"strings" "strings"
) )
// ParseError is returned when there is an error parsing the TOML syntax. // ParseError is returned when there is an error parsing the TOML syntax such as
// // invalid syntax, duplicate keys, etc.
// For example invalid syntax, duplicate keys, etc.
// //
// In addition to the error message itself, you can also print detailed location // In addition to the error message itself, you can also print detailed location
// information with context by using ErrorWithLocation(): // information with context by using [ErrorWithPosition]:
// //
// toml: error: Key 'fruit' was already created and cannot be used as an array. // toml: error: Key 'fruit' was already created and cannot be used as an array.
// //
// At line 4, column 2-7: // At line 4, column 2-7:
// //
// 2 | fruit = [] // 2 | fruit = []
// 3 | // 3 |
// 4 | [[fruit]] # Not allowed // 4 | [[fruit]] # Not allowed
// ^^^^^ // ^^^^^
// //
// Furthermore, the ErrorWithUsage() can be used to print the above with some // [ErrorWithUsage] can be used to print the above with some more detailed usage
// more detailed usage guidance: // guidance:
// //
// toml: error: newlines not allowed within inline tables // toml: error: newlines not allowed within inline tables
// //
// At line 1, column 18: // At line 1, column 18:
// //
// 1 | x = [{ key = 42 # // 1 | x = [{ key = 42 #
// ^ // ^
// //
// Error help: // Error help:
// //
// Inline tables must always be on a single line: // Inline tables must always be on a single line:
// //
// table = {key = 42, second = 43} // table = {key = 42, second = 43}
// //
// It is invalid to split them over multiple lines like so: // It is invalid to split them over multiple lines like so:
// //
// # INVALID // # INVALID
// table = { // table = {
// key = 42, // key = 42,
// second = 43 // second = 43
// } // }
// //
// Use regular for this: // Use regular for this:
// //
// [table] // [table]
// key = 42 // key = 42
// second = 43 // second = 43
type ParseError struct { type ParseError struct {
Message string // Short technical message. Message string // Short technical message.
Usage string // Longer message with usage guidance; may be blank. Usage string // Longer message with usage guidance; may be blank.
Position Position // Position of the error Position Position // Position of the error
LastKey string // Last parsed key, may be blank. LastKey string // Last parsed key, may be blank.
Line int // Line the error occurred. Deprecated: use Position.
// Line the error occurred.
//
// Deprecated: use [Position].
Line int
err error err error
input string input string
@@ -83,7 +86,7 @@ func (pe ParseError) Error() string {
// ErrorWithUsage() returns the error with detailed location context. // ErrorWithUsage() returns the error with detailed location context.
// //
// See the documentation on ParseError. // See the documentation on [ParseError].
func (pe ParseError) ErrorWithPosition() string { func (pe ParseError) ErrorWithPosition() string {
if pe.input == "" { // Should never happen, but just in case. if pe.input == "" { // Should never happen, but just in case.
return pe.Error() return pe.Error()
@@ -124,13 +127,17 @@ func (pe ParseError) ErrorWithPosition() string {
// ErrorWithUsage() returns the error with detailed location context and usage // ErrorWithUsage() returns the error with detailed location context and usage
// guidance. // guidance.
// //
// See the documentation on ParseError. // See the documentation on [ParseError].
func (pe ParseError) ErrorWithUsage() string { func (pe ParseError) ErrorWithUsage() string {
m := pe.ErrorWithPosition() m := pe.ErrorWithPosition()
if u, ok := pe.err.(interface{ Usage() string }); ok && u.Usage() != "" { if u, ok := pe.err.(interface{ Usage() string }); ok && u.Usage() != "" {
return m + "Error help:\n\n " + lines := strings.Split(strings.TrimSpace(u.Usage()), "\n")
strings.ReplaceAll(strings.TrimSpace(u.Usage()), "\n", "\n ") + for i := range lines {
"\n" if lines[i] != "" {
lines[i] = " " + lines[i]
}
}
return m + "Error help:\n\n" + strings.Join(lines, "\n") + "\n"
} }
return m return m
} }
@@ -160,6 +167,11 @@ type (
errLexInvalidDate struct{ v string } errLexInvalidDate struct{ v string }
errLexInlineTableNL struct{} errLexInlineTableNL struct{}
errLexStringNL struct{} errLexStringNL struct{}
errParseRange struct {
i interface{} // int or float
size string // "int64", "uint16", etc.
}
errParseDuration struct{ d string }
) )
func (e errLexControl) Error() string { func (e errLexControl) Error() string {
@@ -179,6 +191,10 @@ func (e errLexInlineTableNL) Error() string { return "newlines not allowed withi
func (e errLexInlineTableNL) Usage() string { return usageInlineNewline } func (e errLexInlineTableNL) Usage() string { return usageInlineNewline }
func (e errLexStringNL) Error() string { return "strings cannot contain newlines" } func (e errLexStringNL) Error() string { return "strings cannot contain newlines" }
func (e errLexStringNL) Usage() string { return usageStringNewline } func (e errLexStringNL) Usage() string { return usageStringNewline }
func (e errParseRange) Error() string { return fmt.Sprintf("%v is out of range for %s", e.i, e.size) }
func (e errParseRange) Usage() string { return usageIntOverflow }
func (e errParseDuration) Error() string { return fmt.Sprintf("invalid duration: %q", e.d) }
func (e errParseDuration) Usage() string { return usageDuration }
const usageEscape = ` const usageEscape = `
A '\' inside a "-delimited string is interpreted as an escape character. A '\' inside a "-delimited string is interpreted as an escape character.
@@ -227,3 +243,37 @@ Instead use """ or ''' to split strings over multiple lines:
string = """Hello, string = """Hello,
world!""" world!"""
` `
const usageIntOverflow = `
This number is too large; this may be an error in the TOML, but it can also be a
bug in the program that uses too small of an integer.
The maximum and minimum values are:
size │ lowest │ highest
───────┼────────────────┼──────────
int8 │ -128 │ 127
int16 │ -32,768 │ 32,767
int32 │ -2,147,483,648 │ 2,147,483,647
int64 │ -9.2 × 10¹⁷ │ 9.2 × 10¹⁷
uint8 │ 0 │ 255
uint16 │ 0 │ 65535
uint32 │ 0 │ 4294967295
uint64 │ 0 │ 1.8 × 10¹⁸
int refers to int32 on 32-bit systems and int64 on 64-bit systems.
`
const usageDuration = `
A duration must be as "number<unit>", without any spaces. Valid units are:
ns nanoseconds (billionth of a second)
us, µs microseconds (millionth of a second)
ms milliseconds (thousands of a second)
s seconds
m minutes
h hours
You can combine multiple units; for example "5m10s" for 5 minutes and 10
seconds.
`

View File

@@ -82,7 +82,7 @@ func (lx *lexer) nextItem() item {
return item return item
default: default:
lx.state = lx.state(lx) lx.state = lx.state(lx)
//fmt.Printf(" STATE %-24s current: %-10q stack: %s\n", lx.state, lx.current(), lx.stack) //fmt.Printf(" STATE %-24s current: %-10s stack: %s\n", lx.state, lx.current(), lx.stack)
} }
} }
} }
@@ -128,6 +128,11 @@ func (lx lexer) getPos() Position {
} }
func (lx *lexer) emit(typ itemType) { func (lx *lexer) emit(typ itemType) {
// Needed for multiline strings ending with an incomplete UTF-8 sequence.
if lx.start > lx.pos {
lx.error(errLexUTF8{lx.input[lx.pos]})
return
}
lx.items <- item{typ: typ, pos: lx.getPos(), val: lx.current()} lx.items <- item{typ: typ, pos: lx.getPos(), val: lx.current()}
lx.start = lx.pos lx.start = lx.pos
} }
@@ -711,7 +716,17 @@ func lexMultilineString(lx *lexer) stateFn {
if lx.peek() == '"' { if lx.peek() == '"' {
/// Check if we already lexed 5 's; if so we have 6 now, and /// Check if we already lexed 5 's; if so we have 6 now, and
/// that's just too many man! /// that's just too many man!
if strings.HasSuffix(lx.current(), `"""""`) { ///
/// Second check is for the edge case:
///
/// two quotes allowed.
/// vv
/// """lol \""""""
/// ^^ ^^^---- closing three
/// escaped
///
/// But ugly, but it works
if strings.HasSuffix(lx.current(), `"""""`) && !strings.HasSuffix(lx.current(), `\"""""`) {
return lx.errorf(`unexpected '""""""'`) return lx.errorf(`unexpected '""""""'`)
} }
lx.backup() lx.backup()
@@ -756,7 +771,7 @@ func lexRawString(lx *lexer) stateFn {
} }
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such // lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'''" has already been consumed and // a string. It assumes that the beginning ''' has already been consumed and
// ignored. // ignored.
func lexMultilineRawString(lx *lexer) stateFn { func lexMultilineRawString(lx *lexer) stateFn {
r := lx.next() r := lx.next()
@@ -802,8 +817,7 @@ func lexMultilineRawString(lx *lexer) stateFn {
// lexMultilineStringEscape consumes an escaped character. It assumes that the // lexMultilineStringEscape consumes an escaped character. It assumes that the
// preceding '\\' has already been consumed. // preceding '\\' has already been consumed.
func lexMultilineStringEscape(lx *lexer) stateFn { func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first: if isNL(lx.next()) { /// \ escaping newline.
if isNL(lx.next()) {
return lexMultilineString return lexMultilineString
} }
lx.backup() lx.backup()

View File

@@ -12,10 +12,11 @@ import (
type MetaData struct { type MetaData struct {
context Key // Used only during decoding. context Key // Used only during decoding.
keyInfo map[string]keyInfo
mapping map[string]interface{} mapping map[string]interface{}
types map[string]tomlType
keys []Key keys []Key
decoded map[string]struct{} decoded map[string]struct{}
data []byte // Input file; for errors.
} }
// IsDefined reports if the key exists in the TOML data. // IsDefined reports if the key exists in the TOML data.
@@ -50,8 +51,8 @@ func (md *MetaData) IsDefined(key ...string) bool {
// Type will return the empty string if given an empty key or a key that does // Type will return the empty string if given an empty key or a key that does
// not exist. Keys are case sensitive. // not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string { func (md *MetaData) Type(key ...string) string {
if typ, ok := md.types[Key(key).String()]; ok { if ki, ok := md.keyInfo[Key(key).String()]; ok {
return typ.typeString() return ki.tomlType.typeString()
} }
return "" return ""
} }
@@ -70,7 +71,7 @@ func (md *MetaData) Keys() []Key {
// Undecoded returns all keys that have not been decoded in the order in which // Undecoded returns all keys that have not been decoded in the order in which
// they appear in the original TOML document. // they appear in the original TOML document.
// //
// This includes keys that haven't been decoded because of a Primitive value. // This includes keys that haven't been decoded because of a [Primitive] value.
// Once the Primitive value is decoded, the keys will be considered decoded. // Once the Primitive value is decoded, the keys will be considered decoded.
// //
// Also note that decoding into an empty interface will result in no decoding, // Also note that decoding into an empty interface will result in no decoding,
@@ -88,7 +89,7 @@ func (md *MetaData) Undecoded() []Key {
return undecoded return undecoded
} }
// Key represents any TOML key, including key groups. Use (MetaData).Keys to get // Key represents any TOML key, including key groups. Use [MetaData.Keys] to get
// values of this type. // values of this type.
type Key []string type Key []string

View File

@@ -16,12 +16,18 @@ type parser struct {
currentKey string // Base key name for everything except hashes. currentKey string // Base key name for everything except hashes.
pos Position // Current position in the TOML file. pos Position // Current position in the TOML file.
ordered []Key // List of keys in the order that they appear in the TOML data. ordered []Key // List of keys in the order that they appear in the TOML data.
keyInfo map[string]keyInfo // Map keyname → info about the TOML key.
mapping map[string]interface{} // Map keyname → key value. mapping map[string]interface{} // Map keyname → key value.
types map[string]tomlType // Map keyname → TOML type.
implicits map[string]struct{} // Record implicit keys (e.g. "key.group.names"). implicits map[string]struct{} // Record implicit keys (e.g. "key.group.names").
} }
type keyInfo struct {
pos Position
tomlType tomlType
}
func parse(data string) (p *parser, err error) { func parse(data string) (p *parser, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -57,8 +63,8 @@ func parse(data string) (p *parser, err error) {
} }
p = &parser{ p = &parser{
keyInfo: make(map[string]keyInfo),
mapping: make(map[string]interface{}), mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data), lx: lex(data),
ordered: make([]Key, 0), ordered: make([]Key, 0),
implicits: make(map[string]struct{}), implicits: make(map[string]struct{}),
@@ -74,6 +80,15 @@ func parse(data string) (p *parser, err error) {
return p, nil return p, nil
} }
func (p *parser) panicErr(it item, err error) {
panic(ParseError{
err: err,
Position: it.pos,
Line: it.pos.Len,
LastKey: p.current(),
})
}
func (p *parser) panicItemf(it item, format string, v ...interface{}) { func (p *parser) panicItemf(it item, format string, v ...interface{}) {
panic(ParseError{ panic(ParseError{
Message: fmt.Sprintf(format, v...), Message: fmt.Sprintf(format, v...),
@@ -94,7 +109,7 @@ func (p *parser) panicf(format string, v ...interface{}) {
func (p *parser) next() item { func (p *parser) next() item {
it := p.lx.nextItem() it := p.lx.nextItem()
//fmt.Printf("ITEM %-18s line %-3d │ %q\n", it.typ, it.line, it.val) //fmt.Printf("ITEM %-18s line %-3d │ %q\n", it.typ, it.pos.Line, it.val)
if it.typ == itemError { if it.typ == itemError {
if it.err != nil { if it.err != nil {
panic(ParseError{ panic(ParseError{
@@ -146,7 +161,7 @@ func (p *parser) topLevel(item item) {
p.assertEqual(itemTableEnd, name.typ) p.assertEqual(itemTableEnd, name.typ)
p.addContext(key, false) p.addContext(key, false)
p.setType("", tomlHash) p.setType("", tomlHash, item.pos)
p.ordered = append(p.ordered, key) p.ordered = append(p.ordered, key)
case itemArrayTableStart: // [[ .. ]] case itemArrayTableStart: // [[ .. ]]
name := p.nextPos() name := p.nextPos()
@@ -158,7 +173,7 @@ func (p *parser) topLevel(item item) {
p.assertEqual(itemArrayTableEnd, name.typ) p.assertEqual(itemArrayTableEnd, name.typ)
p.addContext(key, true) p.addContext(key, true)
p.setType("", tomlArrayHash) p.setType("", tomlArrayHash, item.pos)
p.ordered = append(p.ordered, key) p.ordered = append(p.ordered, key)
case itemKeyStart: // key = .. case itemKeyStart: // key = ..
outerContext := p.context outerContext := p.context
@@ -181,8 +196,9 @@ func (p *parser) topLevel(item item) {
} }
/// Set value. /// Set value.
val, typ := p.value(p.next(), false) vItem := p.next()
p.set(p.currentKey, val, typ) val, typ := p.value(vItem, false)
p.set(p.currentKey, val, typ, vItem.pos)
p.ordered = append(p.ordered, p.context.add(p.currentKey)) p.ordered = append(p.ordered, p.context.add(p.currentKey))
/// Remove the context we added (preserving any context from [tbl] lines). /// Remove the context we added (preserving any context from [tbl] lines).
@@ -220,7 +236,7 @@ func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) {
case itemString: case itemString:
return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it) return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it)
case itemMultilineString: case itemMultilineString:
return p.replaceEscapes(it, stripFirstNewline(stripEscapedNewlines(it.val))), p.typeOfPrimitive(it) return p.replaceEscapes(it, stripFirstNewline(p.stripEscapedNewlines(it.val))), p.typeOfPrimitive(it)
case itemRawString: case itemRawString:
return it.val, p.typeOfPrimitive(it) return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString: case itemRawMultilineString:
@@ -266,7 +282,7 @@ func (p *parser) valueInteger(it item) (interface{}, tomlType) {
// So mark the former as a bug but the latter as a legitimate user // So mark the former as a bug but the latter as a legitimate user
// error. // error.
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
p.panicItemf(it, "Integer '%s' is out of the range of 64-bit signed integers.", it.val) p.panicErr(it, errParseRange{i: it.val, size: "int64"})
} else { } else {
p.bug("Expected integer value, but got '%s'.", it.val) p.bug("Expected integer value, but got '%s'.", it.val)
} }
@@ -304,7 +320,7 @@ func (p *parser) valueFloat(it item) (interface{}, tomlType) {
num, err := strconv.ParseFloat(val, 64) num, err := strconv.ParseFloat(val, 64)
if err != nil { if err != nil {
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
p.panicItemf(it, "Float '%s' is out of the range of 64-bit IEEE-754 floating-point numbers.", it.val) p.panicErr(it, errParseRange{i: it.val, size: "float64"})
} else { } else {
p.panicItemf(it, "Invalid float value: %q", it.val) p.panicItemf(it, "Invalid float value: %q", it.val)
} }
@@ -343,9 +359,8 @@ func (p *parser) valueDatetime(it item) (interface{}, tomlType) {
} }
func (p *parser) valueArray(it item) (interface{}, tomlType) { func (p *parser) valueArray(it item) (interface{}, tomlType) {
p.setType(p.currentKey, tomlArray) p.setType(p.currentKey, tomlArray, it.pos)
// p.setType(p.currentKey, typ)
var ( var (
types []tomlType types []tomlType
@@ -414,7 +429,7 @@ func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tom
/// Set the value. /// Set the value.
val, typ := p.value(p.next(), false) val, typ := p.value(p.next(), false)
p.set(p.currentKey, val, typ) p.set(p.currentKey, val, typ, it.pos)
p.ordered = append(p.ordered, p.context.add(p.currentKey)) p.ordered = append(p.ordered, p.context.add(p.currentKey))
hash[p.currentKey] = val hash[p.currentKey] = val
@@ -533,9 +548,10 @@ func (p *parser) addContext(key Key, array bool) {
} }
// set calls setValue and setType. // set calls setValue and setType.
func (p *parser) set(key string, val interface{}, typ tomlType) { func (p *parser) set(key string, val interface{}, typ tomlType, pos Position) {
p.setValue(key, val) p.setValue(key, val)
p.setType(key, typ) p.setType(key, typ, pos)
} }
// setValue sets the given key to the given value in the current context. // setValue sets the given key to the given value in the current context.
@@ -599,7 +615,7 @@ func (p *parser) setValue(key string, value interface{}) {
// //
// Note that if `key` is empty, then the type given will be applied to the // Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables). // current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) { func (p *parser) setType(key string, typ tomlType, pos Position) {
keyContext := make(Key, 0, len(p.context)+1) keyContext := make(Key, 0, len(p.context)+1)
keyContext = append(keyContext, p.context...) keyContext = append(keyContext, p.context...)
if len(key) > 0 { // allow type setting for hashes if len(key) > 0 { // allow type setting for hashes
@@ -611,7 +627,7 @@ func (p *parser) setType(key string, typ tomlType) {
if len(keyContext) == 0 { if len(keyContext) == 0 {
keyContext = Key{""} keyContext = Key{""}
} }
p.types[keyContext.String()] = typ p.keyInfo[keyContext.String()] = keyInfo{tomlType: typ, pos: pos}
} }
// Implicit keys need to be created when tables are implied in "a.b.c.d = 1" and // Implicit keys need to be created when tables are implied in "a.b.c.d = 1" and
@@ -619,7 +635,7 @@ func (p *parser) setType(key string, typ tomlType) {
func (p *parser) addImplicit(key Key) { p.implicits[key.String()] = struct{}{} } func (p *parser) addImplicit(key Key) { p.implicits[key.String()] = struct{}{} }
func (p *parser) removeImplicit(key Key) { delete(p.implicits, key.String()) } func (p *parser) removeImplicit(key Key) { delete(p.implicits, key.String()) }
func (p *parser) isImplicit(key Key) bool { _, ok := p.implicits[key.String()]; return ok } func (p *parser) isImplicit(key Key) bool { _, ok := p.implicits[key.String()]; return ok }
func (p *parser) isArray(key Key) bool { return p.types[key.String()] == tomlArray } func (p *parser) isArray(key Key) bool { return p.keyInfo[key.String()].tomlType == tomlArray }
func (p *parser) addImplicitContext(key Key) { func (p *parser) addImplicitContext(key Key) {
p.addImplicit(key) p.addImplicit(key)
p.addContext(key, false) p.addContext(key, false)
@@ -647,7 +663,7 @@ func stripFirstNewline(s string) string {
} }
// Remove newlines inside triple-quoted strings if a line ends with "\". // Remove newlines inside triple-quoted strings if a line ends with "\".
func stripEscapedNewlines(s string) string { func (p *parser) stripEscapedNewlines(s string) string {
split := strings.Split(s, "\n") split := strings.Split(s, "\n")
if len(split) < 1 { if len(split) < 1 {
return s return s
@@ -679,6 +695,10 @@ func stripEscapedNewlines(s string) string {
continue continue
} }
if i == len(split)-1 {
p.panicf("invalid escape: '\\ '")
}
split[i] = line[:len(line)-1] // Remove \ split[i] = line[:len(line)-1] // Remove \
if len(split)-1 > i { if len(split)-1 > i {
split[i+1] = strings.TrimLeft(split[i+1], " \t\r") split[i+1] = strings.TrimLeft(split[i+1], " \t\r")
@@ -706,10 +726,8 @@ func (p *parser) replaceEscapes(it item, str string) string {
switch s[r] { switch s[r] {
default: default:
p.bug("Expected valid escape code after \\, but got %q.", s[r]) p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case ' ', '\t': case ' ', '\t':
p.panicItemf(it, "invalid escape: '\\%c'", s[r]) p.panicItemf(it, "invalid escape: '\\%c'", s[r])
return ""
case 'b': case 'b':
replaced = append(replaced, rune(0x0008)) replaced = append(replaced, rune(0x0008))
r += 1 r += 1

View File

@@ -191,6 +191,8 @@ type Linux struct {
IntelRdt *LinuxIntelRdt `json:"intelRdt,omitempty"` IntelRdt *LinuxIntelRdt `json:"intelRdt,omitempty"`
// Personality contains configuration for the Linux personality syscall // Personality contains configuration for the Linux personality syscall
Personality *LinuxPersonality `json:"personality,omitempty"` Personality *LinuxPersonality `json:"personality,omitempty"`
// TimeOffsets specifies the offset for supporting time namespaces.
TimeOffsets map[string]LinuxTimeOffset `json:"timeOffsets,omitempty"`
} }
// LinuxNamespace is the configuration for a Linux namespace // LinuxNamespace is the configuration for a Linux namespace
@@ -220,6 +222,8 @@ const (
UserNamespace LinuxNamespaceType = "user" UserNamespace LinuxNamespaceType = "user"
// CgroupNamespace for isolating cgroup hierarchies // CgroupNamespace for isolating cgroup hierarchies
CgroupNamespace LinuxNamespaceType = "cgroup" CgroupNamespace LinuxNamespaceType = "cgroup"
// TimeNamespace for isolating the clocks
TimeNamespace LinuxNamespaceType = "time"
) )
// LinuxIDMapping specifies UID/GID mappings // LinuxIDMapping specifies UID/GID mappings
@@ -232,6 +236,14 @@ type LinuxIDMapping struct {
Size uint32 `json:"size"` Size uint32 `json:"size"`
} }
// LinuxTimeOffset specifies the offset for Time Namespace
type LinuxTimeOffset struct {
// Secs is the offset of clock (in secs) in the container
Secs int64 `json:"secs,omitempty"`
// Nanosecs is the additional offset for Secs (in nanosecs)
Nanosecs uint32 `json:"nanosecs,omitempty"`
}
// POSIXRlimit type and restrictions // POSIXRlimit type and restrictions
type POSIXRlimit struct { type POSIXRlimit struct {
// Type of the rlimit to set // Type of the rlimit to set
@@ -242,12 +254,13 @@ type POSIXRlimit struct {
Soft uint64 `json:"soft"` Soft uint64 `json:"soft"`
} }
// LinuxHugepageLimit structure corresponds to limiting kernel hugepages // LinuxHugepageLimit structure corresponds to limiting kernel hugepages.
// Default to reservation limits if supported. Otherwise fallback to page fault limits.
type LinuxHugepageLimit struct { type LinuxHugepageLimit struct {
// Pagesize is the hugepage size // Pagesize is the hugepage size.
// Format: "<size><unit-prefix>B' (e.g. 64KB, 2MB, 1GB, etc.) // Format: "<size><unit-prefix>B' (e.g. 64KB, 2MB, 1GB, etc.).
Pagesize string `json:"pageSize"` Pagesize string `json:"pageSize"`
// Limit is the limit of "hugepagesize" hugetlb usage // Limit is the limit of "hugepagesize" hugetlb reservations (if supported) or usage.
Limit uint64 `json:"limit"` Limit uint64 `json:"limit"`
} }
@@ -319,6 +332,10 @@ type LinuxMemory struct {
DisableOOMKiller *bool `json:"disableOOMKiller,omitempty"` DisableOOMKiller *bool `json:"disableOOMKiller,omitempty"`
// Enables hierarchical memory accounting // Enables hierarchical memory accounting
UseHierarchy *bool `json:"useHierarchy,omitempty"` UseHierarchy *bool `json:"useHierarchy,omitempty"`
// CheckBeforeUpdate enables checking if a new memory limit is lower
// than the current usage during update, and if so, rejecting the new
// limit.
CheckBeforeUpdate *bool `json:"checkBeforeUpdate,omitempty"`
} }
// LinuxCPU for Linux cgroup 'cpu' resource management // LinuxCPU for Linux cgroup 'cpu' resource management
@@ -327,6 +344,9 @@ type LinuxCPU struct {
Shares *uint64 `json:"shares,omitempty"` Shares *uint64 `json:"shares,omitempty"`
// CPU hardcap limit (in usecs). Allowed cpu time in a given period. // CPU hardcap limit (in usecs). Allowed cpu time in a given period.
Quota *int64 `json:"quota,omitempty"` Quota *int64 `json:"quota,omitempty"`
// CPU hardcap burst limit (in usecs). Allowed accumulated cpu time additionally for burst in a
// given period.
Burst *uint64 `json:"burst,omitempty"`
// CPU period to be used for hardcapping (in usecs). // CPU period to be used for hardcapping (in usecs).
Period *uint64 `json:"period,omitempty"` Period *uint64 `json:"period,omitempty"`
// How much time realtime scheduling may use (in usecs). // How much time realtime scheduling may use (in usecs).
@@ -375,7 +395,7 @@ type LinuxResources struct {
Pids *LinuxPids `json:"pids,omitempty"` Pids *LinuxPids `json:"pids,omitempty"`
// BlockIO restriction configuration // BlockIO restriction configuration
BlockIO *LinuxBlockIO `json:"blockIO,omitempty"` BlockIO *LinuxBlockIO `json:"blockIO,omitempty"`
// Hugetlb limit (in bytes) // Hugetlb limits (in bytes). Default to reservation limits if supported.
HugepageLimits []LinuxHugepageLimit `json:"hugepageLimits,omitempty"` HugepageLimits []LinuxHugepageLimit `json:"hugepageLimits,omitempty"`
// Network restriction configuration // Network restriction configuration
Network *LinuxNetwork `json:"network,omitempty"` Network *LinuxNetwork `json:"network,omitempty"`
@@ -645,6 +665,10 @@ const (
// LinuxSeccompFlagSpecAllow can be used to disable Speculative Store // LinuxSeccompFlagSpecAllow can be used to disable Speculative Store
// Bypass mitigation. (since Linux 4.17) // Bypass mitigation. (since Linux 4.17)
LinuxSeccompFlagSpecAllow LinuxSeccompFlag = "SECCOMP_FILTER_FLAG_SPEC_ALLOW" LinuxSeccompFlagSpecAllow LinuxSeccompFlag = "SECCOMP_FILTER_FLAG_SPEC_ALLOW"
// LinuxSeccompFlagWaitKillableRecv can be used to switch to the wait
// killable semantics. (since Linux 5.19)
LinuxSeccompFlagWaitKillableRecv LinuxSeccompFlag = "SECCOMP_FILTER_FLAG_WAIT_KILLABLE_RECV"
) )
// Additional architectures permitted to be used for system calls // Additional architectures permitted to be used for system calls

View File

@@ -6,12 +6,12 @@ const (
// VersionMajor is for an API incompatible changes // VersionMajor is for an API incompatible changes
VersionMajor = 1 VersionMajor = 1
// VersionMinor is for functionality in a backwards-compatible manner // VersionMinor is for functionality in a backwards-compatible manner
VersionMinor = 0 VersionMinor = 1
// VersionPatch is for backwards-compatible bug fixes // VersionPatch is for backwards-compatible bug fixes
VersionPatch = 2 VersionPatch = 0
// VersionDev indicates development branch. Releases will be empty string. // VersionDev indicates development branch. Releases will be empty string.
VersionDev = "-dev" VersionDev = "-rc.2"
) )
// Version is the specification version that the package types support. // Version is the specification version that the package types support.

View File

@@ -1,8 +1,10 @@
package assert package assert
import ( import (
"bytes"
"fmt" "fmt"
"reflect" "reflect"
"time"
) )
type CompareType int type CompareType int
@@ -30,6 +32,9 @@ var (
float64Type = reflect.TypeOf(float64(1)) float64Type = reflect.TypeOf(float64(1))
stringType = reflect.TypeOf("") stringType = reflect.TypeOf("")
timeType = reflect.TypeOf(time.Time{})
bytesType = reflect.TypeOf([]byte{})
) )
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
@@ -299,6 +304,47 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return compareLess, true return compareLess, true
} }
} }
// Check for known struct types we can check for compare results.
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) {
break
}
// time.Time can compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
}
timeObj2, ok := obj2.(time.Time)
if !ok {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
case reflect.Slice:
{
// We only care about the []byte type.
if !canConvert(obj1Value, bytesType) {
break
}
// []byte can be compared!
bytesObj1, ok := obj1.([]byte)
if !ok {
bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
}
bytesObj2, ok := obj2.([]byte)
if !ok {
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
}
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
}
} }
return compareEqual, false return compareEqual, false
@@ -310,7 +356,10 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
// assert.Greater(t, float64(2), float64(1)) // assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a") // assert.Greater(t, "b", "a")
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs) if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
} }
// GreaterOrEqual asserts that the first element is greater than or equal to the second // GreaterOrEqual asserts that the first element is greater than or equal to the second
@@ -320,7 +369,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface
// assert.GreaterOrEqual(t, "b", "a") // assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b") // assert.GreaterOrEqual(t, "b", "b")
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs) if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
} }
// Less asserts that the first element is less than the second // Less asserts that the first element is less than the second
@@ -329,7 +381,10 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in
// assert.Less(t, float64(1), float64(2)) // assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b") // assert.Less(t, "a", "b")
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs) if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
} }
// LessOrEqual asserts that the first element is less than or equal to the second // LessOrEqual asserts that the first element is less than or equal to the second
@@ -339,7 +394,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{})
// assert.LessOrEqual(t, "a", "b") // assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b") // assert.LessOrEqual(t, "b", "b")
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs) if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
} }
// Positive asserts that the specified element is positive // Positive asserts that the specified element is positive
@@ -347,8 +405,11 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter
// assert.Positive(t, 1) // assert.Positive(t, 1)
// assert.Positive(t, 1.23) // assert.Positive(t, 1.23)
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e)) zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs) return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
} }
// Negative asserts that the specified element is negative // Negative asserts that the specified element is negative
@@ -356,8 +417,11 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
// assert.Negative(t, -1) // assert.Negative(t, -1)
// assert.Negative(t, -1.23) // assert.Negative(t, -1.23)
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e)) zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs) return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
} }
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool { func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {

View File

@@ -0,0 +1,16 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatibility
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View File

@@ -0,0 +1,16 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View File

@@ -123,6 +123,18 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...) return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
} }
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
}
// ErrorIsf asserts that at least one of the errors in err's chain matches target. // ErrorIsf asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool { func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
@@ -724,6 +736,16 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim
return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...) return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
} }
// WithinRangef asserts that a time is within a time range (inclusive).
//
// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return WithinRange(t, actual, start, end, append([]interface{}{msg}, args...)...)
}
// YAMLEqf asserts that two YAML strings are equivalent. // YAMLEqf asserts that two YAML strings are equivalent.
func YAMLEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool { func YAMLEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {

View File

@@ -222,6 +222,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
return ErrorAsf(a.t, err, target, msg, args...) return ErrorAsf(a.t, err, target, msg, args...)
} }
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target. // ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) bool { func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) bool {
@@ -1437,6 +1461,26 @@ func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta
return WithinDurationf(a.t, expected, actual, delta, msg, args...) return WithinDurationf(a.t, expected, actual, delta, msg, args...)
} }
// WithinRange asserts that a time is within a time range (inclusive).
//
// a.WithinRange(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func (a *Assertions) WithinRange(actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return WithinRange(a.t, actual, start, end, msgAndArgs...)
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// a.WithinRangef(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func (a *Assertions) WithinRangef(actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return WithinRangef(a.t, actual, start, end, msg, args...)
}
// YAMLEq asserts that two YAML strings are equivalent. // YAMLEq asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) bool { func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {

View File

@@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
// assert.IsIncreasing(t, []float{1, 2}) // assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"}) // assert.IsIncreasing(t, []string{"a", "b"})
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs) return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
} }
// IsNonIncreasing asserts that the collection is not increasing // IsNonIncreasing asserts that the collection is not increasing
@@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonIncreasing(t, []float{2, 1}) // assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"}) // assert.IsNonIncreasing(t, []string{"b", "a"})
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs) return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
} }
// IsDecreasing asserts that the collection is decreasing // IsDecreasing asserts that the collection is decreasing
@@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{})
// assert.IsDecreasing(t, []float{2, 1}) // assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"}) // assert.IsDecreasing(t, []string{"b", "a"})
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs) return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
} }
// IsNonDecreasing asserts that the collection is not decreasing // IsNonDecreasing asserts that the collection is not decreasing
@@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonDecreasing(t, []float{1, 2}) // assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"}) // assert.IsNonDecreasing(t, []string{"a", "b"})
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs) return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
} }

View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"math" "math"
"os" "os"
"path/filepath"
"reflect" "reflect"
"regexp" "regexp"
"runtime" "runtime"
@@ -144,7 +145,8 @@ func CallerInfo() []string {
if len(parts) > 1 { if len(parts) > 1 {
dir := parts[len(parts)-2] dir := parts[len(parts)-2]
if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" { if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" {
callers = append(callers, fmt.Sprintf("%s:%d", file, line)) path, _ := filepath.Abs(file)
callers = append(callers, fmt.Sprintf("%s:%d", path, line))
} }
} }
@@ -563,16 +565,17 @@ func isEmpty(object interface{}) bool {
switch objValue.Kind() { switch objValue.Kind() {
// collection types are empty when they have no element // collection types are empty when they have no element
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice: case reflect.Chan, reflect.Map, reflect.Slice:
return objValue.Len() == 0 return objValue.Len() == 0
// pointers are empty if nil or if the value they point to is empty // pointers are empty if nil or if the value they point to is empty
case reflect.Ptr: case reflect.Ptr:
if objValue.IsNil() { if objValue.IsNil() {
return true return true
} }
deref := objValue.Elem().Interface() deref := objValue.Elem().Interface()
return isEmpty(deref) return isEmpty(deref)
// for all other types, compare against the zero value // for all other types, compare against the zero value
// array types are empty when they match their zero-initialized state
default: default:
zero := reflect.Zero(objValue.Type()) zero := reflect.Zero(objValue.Type())
return reflect.DeepEqual(object, zero.Interface()) return reflect.DeepEqual(object, zero.Interface())
@@ -718,10 +721,14 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
// return (false, false) if impossible. // return (false, false) if impossible.
// return (true, false) if element was not found. // return (true, false) if element was not found.
// return (true, true) if element was found. // return (true, true) if element was found.
func includeElement(list interface{}, element interface{}) (ok, found bool) { func containsElement(list interface{}, element interface{}) (ok, found bool) {
listValue := reflect.ValueOf(list) listValue := reflect.ValueOf(list)
listKind := reflect.TypeOf(list).Kind() listType := reflect.TypeOf(list)
if listType == nil {
return false, false
}
listKind := listType.Kind()
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
ok = false ok = false
@@ -764,7 +771,7 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
h.Helper() h.Helper()
} }
ok, found := includeElement(s, contains) ok, found := containsElement(s, contains)
if !ok { if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...) return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
} }
@@ -787,7 +794,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
h.Helper() h.Helper()
} }
ok, found := includeElement(s, contains) ok, found := containsElement(s, contains)
if !ok { if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...)
} }
@@ -811,7 +818,6 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true // we consider nil to be equal to the nil set return true // we consider nil to be equal to the nil set
} }
subsetValue := reflect.ValueOf(subset)
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
ok = false ok = false
@@ -821,17 +827,35 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
listKind := reflect.TypeOf(list).Kind() listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind() subsetKind := reflect.TypeOf(subset).Kind()
if listKind != reflect.Array && listKind != reflect.Slice { if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
} }
if subsetKind != reflect.Array && subsetKind != reflect.Slice { if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
} }
subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()
for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()
if !ObjectsAreEqual(subsetElement, listElement) {
return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...)
}
}
return true
}
for i := 0; i < subsetValue.Len(); i++ { for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface() element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element) ok, found := containsElement(list, element)
if !ok { if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
} }
@@ -852,10 +876,9 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
h.Helper() h.Helper()
} }
if subset == nil { if subset == nil {
return Fail(t, fmt.Sprintf("nil is the empty set which is a subset of every set"), msgAndArgs...) return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
} }
subsetValue := reflect.ValueOf(subset)
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
ok = false ok = false
@@ -865,17 +888,35 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
listKind := reflect.TypeOf(list).Kind() listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind() subsetKind := reflect.TypeOf(subset).Kind()
if listKind != reflect.Array && listKind != reflect.Slice { if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
} }
if subsetKind != reflect.Array && subsetKind != reflect.Slice { if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
} }
subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()
for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()
if !ObjectsAreEqual(subsetElement, listElement) {
return true
}
}
return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...)
}
for i := 0; i < subsetValue.Len(); i++ { for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface() element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element) ok, found := containsElement(list, element)
if !ok { if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
} }
@@ -1000,27 +1041,21 @@ func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
type PanicTestFunc func() type PanicTestFunc func()
// didPanic returns true if the function passed to it panics. Otherwise, it returns false. // didPanic returns true if the function passed to it panics. Otherwise, it returns false.
func didPanic(f PanicTestFunc) (bool, interface{}, string) { func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string) {
didPanic = true
didPanic := false
var message interface{}
var stack string
func() {
defer func() {
if message = recover(); message != nil {
didPanic = true
stack = string(debug.Stack())
}
}()
// call the target function
f()
defer func() {
message = recover()
if didPanic {
stack = string(debug.Stack())
}
}() }()
return didPanic, message, stack // call the target function
f()
didPanic = false
return
} }
// Panics asserts that the code inside the specified PanicTestFunc panics. // Panics asserts that the code inside the specified PanicTestFunc panics.
@@ -1111,6 +1146,27 @@ func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration,
return true return true
} }
// WithinRange asserts that a time is within a time range (inclusive).
//
// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func WithinRange(t TestingT, actual, start, end time.Time, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if end.Before(start) {
return Fail(t, "Start should be before end", msgAndArgs...)
}
if actual.Before(start) {
return Fail(t, fmt.Sprintf("Time %v expected to be in time range %v to %v, but is before the range", actual, start, end), msgAndArgs...)
} else if actual.After(end) {
return Fail(t, fmt.Sprintf("Time %v expected to be in time range %v to %v, but is after the range", actual, start, end), msgAndArgs...)
}
return true
}
func toFloat(x interface{}) (float64, bool) { func toFloat(x interface{}) (float64, bool) {
var xf float64 var xf float64
xok := true xok := true
@@ -1161,11 +1217,15 @@ func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs
bf, bok := toFloat(actual) bf, bok := toFloat(actual)
if !aok || !bok { if !aok || !bok {
return Fail(t, fmt.Sprintf("Parameters must be numerical"), msgAndArgs...) return Fail(t, "Parameters must be numerical", msgAndArgs...)
}
if math.IsNaN(af) && math.IsNaN(bf) {
return true
} }
if math.IsNaN(af) { if math.IsNaN(af) {
return Fail(t, fmt.Sprintf("Expected must not be NaN"), msgAndArgs...) return Fail(t, "Expected must not be NaN", msgAndArgs...)
} }
if math.IsNaN(bf) { if math.IsNaN(bf) {
@@ -1188,7 +1248,7 @@ func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAn
if expected == nil || actual == nil || if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice || reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice { reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...) return Fail(t, "Parameters must be slice", msgAndArgs...)
} }
actualSlice := reflect.ValueOf(actual) actualSlice := reflect.ValueOf(actual)
@@ -1250,8 +1310,12 @@ func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, m
func calcRelativeError(expected, actual interface{}) (float64, error) { func calcRelativeError(expected, actual interface{}) (float64, error) {
af, aok := toFloat(expected) af, aok := toFloat(expected)
if !aok { bf, bok := toFloat(actual)
return 0, fmt.Errorf("expected value %q cannot be converted to float", expected) if !aok || !bok {
return 0, fmt.Errorf("Parameters must be numerical")
}
if math.IsNaN(af) && math.IsNaN(bf) {
return 0, nil
} }
if math.IsNaN(af) { if math.IsNaN(af) {
return 0, errors.New("expected value must not be NaN") return 0, errors.New("expected value must not be NaN")
@@ -1259,10 +1323,6 @@ func calcRelativeError(expected, actual interface{}) (float64, error) {
if af == 0 { if af == 0 {
return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error") return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error")
} }
bf, bok := toFloat(actual)
if !bok {
return 0, fmt.Errorf("actual value %q cannot be converted to float", actual)
}
if math.IsNaN(bf) { if math.IsNaN(bf) {
return 0, errors.New("actual value must not be NaN") return 0, errors.New("actual value must not be NaN")
} }
@@ -1298,7 +1358,7 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if expected == nil || actual == nil || if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice || reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice { reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...) return Fail(t, "Parameters must be slice", msgAndArgs...)
} }
actualSlice := reflect.ValueOf(actual) actualSlice := reflect.ValueOf(actual)
@@ -1375,6 +1435,27 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte
return true return true
} }
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !Error(t, theError, msgAndArgs...) {
return false
}
actual := theError.Error()
if !strings.Contains(actual, contains) {
return Fail(t, fmt.Sprintf("Error %#v does not contain %#v", actual, contains), msgAndArgs...)
}
return true
}
// matchRegexp return true if a specified regexp matches a string. // matchRegexp return true if a specified regexp matches a string.
func matchRegexp(rx interface{}, str interface{}) bool { func matchRegexp(rx interface{}, str interface{}) bool {
@@ -1588,12 +1669,17 @@ func diff(expected interface{}, actual interface{}) string {
} }
var e, a string var e, a string
if et != reflect.TypeOf("") {
e = spewConfig.Sdump(expected) switch et {
a = spewConfig.Sdump(actual) case reflect.TypeOf(""):
} else {
e = reflect.ValueOf(expected).String() e = reflect.ValueOf(expected).String()
a = reflect.ValueOf(actual).String() a = reflect.ValueOf(actual).String()
case reflect.TypeOf(time.Time{}):
e = spewConfigStringerEnabled.Sdump(expected)
a = spewConfigStringerEnabled.Sdump(actual)
default:
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
} }
diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
@@ -1625,6 +1711,14 @@ var spewConfig = spew.ConfigState{
MaxDepth: 10, MaxDepth: 10,
} }
var spewConfigStringerEnabled = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
MaxDepth: 10,
}
type tHelper interface { type tHelper interface {
Helper() Helper()
} }

View File

@@ -280,6 +280,36 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
t.FailNow() t.FailNow()
} }
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContains(t, theError, contains, msgAndArgs...) {
return
}
t.FailNow()
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContainsf(t, theError, contains, msg, args...) {
return
}
t.FailNow()
}
// ErrorIs asserts that at least one of the errors in err's chain matches target. // ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func ErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) { func ErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) {
@@ -1834,6 +1864,32 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim
t.FailNow() t.FailNow()
} }
// WithinRange asserts that a time is within a time range (inclusive).
//
// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func WithinRange(t TestingT, actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.WithinRange(t, actual, start, end, msgAndArgs...) {
return
}
t.FailNow()
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.WithinRangef(t, actual, start, end, msg, args...) {
return
}
t.FailNow()
}
// YAMLEq asserts that two YAML strings are equivalent. // YAMLEq asserts that two YAML strings are equivalent.
func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) { func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {

View File

@@ -223,6 +223,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
ErrorAsf(a.t, err, target, msg, args...) ErrorAsf(a.t, err, target, msg, args...)
} }
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target. // ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is. // This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) { func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) {
@@ -1438,6 +1462,26 @@ func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta
WithinDurationf(a.t, expected, actual, delta, msg, args...) WithinDurationf(a.t, expected, actual, delta, msg, args...)
} }
// WithinRange asserts that a time is within a time range (inclusive).
//
// a.WithinRange(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func (a *Assertions) WithinRange(actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
WithinRange(a.t, actual, start, end, msgAndArgs...)
}
// WithinRangef asserts that a time is within a time range (inclusive).
//
// a.WithinRangef(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func (a *Assertions) WithinRangef(actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
WithinRangef(a.t, actual, start, end, msg, args...)
}
// YAMLEq asserts that two YAML strings are equivalent. // YAMLEq asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) { func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {

31
vendor/golang.org/x/sys/unix/asm_bsd_ppc64.s generated vendored Normal file
View File

@@ -0,0 +1,31 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || freebsd || netbsd || openbsd) && gc
// +build darwin freebsd netbsd openbsd
// +build gc
#include "textflag.h"
//
// System call support for ppc64, BSD
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package unix package unix

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build gccgo && !aix //go:build gccgo && !aix && !hurd
// +build gccgo,!aix // +build gccgo,!aix,!hurd
package unix package unix

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build gccgo //go:build gccgo && !aix && !hurd
// +build !aix // +build gccgo,!aix,!hurd
#include <errno.h> #include <errno.h>
#include <stdint.h> #include <stdint.h>

70
vendor/golang.org/x/sys/unix/ioctl_signed.go generated vendored Normal file
View File

@@ -0,0 +1,70 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || solaris
// +build aix solaris
package unix
import (
"unsafe"
)
// ioctl itself should not be exposed directly, but additional get/set
// functions for specific types are permissible.
// IoctlSetInt performs an ioctl operation which sets an integer value
// on fd, using the specified request number.
func IoctlSetInt(fd int, req int, value int) error {
return ioctl(fd, req, uintptr(value))
}
// IoctlSetPointerInt performs an ioctl operation which sets an
// integer value on fd, using the specified request number. The ioctl
// argument is called with a pointer to the integer value, rather than
// passing the integer value directly.
func IoctlSetPointerInt(fd int, req int, value int) error {
v := int32(value)
return ioctlPtr(fd, req, unsafe.Pointer(&v))
}
// IoctlSetWinsize performs an ioctl on fd with a *Winsize argument.
//
// To change fd's window size, the req argument should be TIOCSWINSZ.
func IoctlSetWinsize(fd int, req int, value *Winsize) error {
// TODO: if we get the chance, remove the req parameter and
// hardcode TIOCSWINSZ.
return ioctlPtr(fd, req, unsafe.Pointer(value))
}
// IoctlSetTermios performs an ioctl on fd with a *Termios.
//
// The req value will usually be TCSETA or TIOCSETA.
func IoctlSetTermios(fd int, req int, value *Termios) error {
// TODO: if we get the chance, remove the req parameter.
return ioctlPtr(fd, req, unsafe.Pointer(value))
}
// IoctlGetInt performs an ioctl operation which gets an integer value
// from fd, using the specified request number.
//
// A few ioctl requests use the return value as an output parameter;
// for those, IoctlRetInt should be used instead of this function.
func IoctlGetInt(fd int, req int) (int, error) {
var value int
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return value, err
}
func IoctlGetWinsize(fd int, req int) (*Winsize, error) {
var value Winsize
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err
}
func IoctlGetTermios(fd int, req int) (*Termios, error) {
var value Termios
err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err
}

View File

@@ -2,13 +2,12 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris //go:build darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris // +build darwin dragonfly freebsd hurd linux netbsd openbsd
package unix package unix
import ( import (
"runtime"
"unsafe" "unsafe"
) )
@@ -27,7 +26,7 @@ func IoctlSetInt(fd int, req uint, value int) error {
// passing the integer value directly. // passing the integer value directly.
func IoctlSetPointerInt(fd int, req uint, value int) error { func IoctlSetPointerInt(fd int, req uint, value int) error {
v := int32(value) v := int32(value)
return ioctl(fd, req, uintptr(unsafe.Pointer(&v))) return ioctlPtr(fd, req, unsafe.Pointer(&v))
} }
// IoctlSetWinsize performs an ioctl on fd with a *Winsize argument. // IoctlSetWinsize performs an ioctl on fd with a *Winsize argument.
@@ -36,9 +35,7 @@ func IoctlSetPointerInt(fd int, req uint, value int) error {
func IoctlSetWinsize(fd int, req uint, value *Winsize) error { func IoctlSetWinsize(fd int, req uint, value *Winsize) error {
// TODO: if we get the chance, remove the req parameter and // TODO: if we get the chance, remove the req parameter and
// hardcode TIOCSWINSZ. // hardcode TIOCSWINSZ.
err := ioctl(fd, req, uintptr(unsafe.Pointer(value))) return ioctlPtr(fd, req, unsafe.Pointer(value))
runtime.KeepAlive(value)
return err
} }
// IoctlSetTermios performs an ioctl on fd with a *Termios. // IoctlSetTermios performs an ioctl on fd with a *Termios.
@@ -46,9 +43,7 @@ func IoctlSetWinsize(fd int, req uint, value *Winsize) error {
// The req value will usually be TCSETA or TIOCSETA. // The req value will usually be TCSETA or TIOCSETA.
func IoctlSetTermios(fd int, req uint, value *Termios) error { func IoctlSetTermios(fd int, req uint, value *Termios) error {
// TODO: if we get the chance, remove the req parameter. // TODO: if we get the chance, remove the req parameter.
err := ioctl(fd, req, uintptr(unsafe.Pointer(value))) return ioctlPtr(fd, req, unsafe.Pointer(value))
runtime.KeepAlive(value)
return err
} }
// IoctlGetInt performs an ioctl operation which gets an integer value // IoctlGetInt performs an ioctl operation which gets an integer value
@@ -58,18 +53,18 @@ func IoctlSetTermios(fd int, req uint, value *Termios) error {
// for those, IoctlRetInt should be used instead of this function. // for those, IoctlRetInt should be used instead of this function.
func IoctlGetInt(fd int, req uint) (int, error) { func IoctlGetInt(fd int, req uint) (int, error) {
var value int var value int
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value))) err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return value, err return value, err
} }
func IoctlGetWinsize(fd int, req uint) (*Winsize, error) { func IoctlGetWinsize(fd int, req uint) (*Winsize, error) {
var value Winsize var value Winsize
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value))) err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err return &value, err
} }
func IoctlGetTermios(fd int, req uint) (*Termios, error) { func IoctlGetTermios(fd int, req uint) (*Termios, error) {
var value Termios var value Termios
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value))) err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err return &value, err
} }

View File

@@ -17,25 +17,23 @@ import (
// IoctlSetInt performs an ioctl operation which sets an integer value // IoctlSetInt performs an ioctl operation which sets an integer value
// on fd, using the specified request number. // on fd, using the specified request number.
func IoctlSetInt(fd int, req uint, value int) error { func IoctlSetInt(fd int, req int, value int) error {
return ioctl(fd, req, uintptr(value)) return ioctl(fd, req, uintptr(value))
} }
// IoctlSetWinsize performs an ioctl on fd with a *Winsize argument. // IoctlSetWinsize performs an ioctl on fd with a *Winsize argument.
// //
// To change fd's window size, the req argument should be TIOCSWINSZ. // To change fd's window size, the req argument should be TIOCSWINSZ.
func IoctlSetWinsize(fd int, req uint, value *Winsize) error { func IoctlSetWinsize(fd int, req int, value *Winsize) error {
// TODO: if we get the chance, remove the req parameter and // TODO: if we get the chance, remove the req parameter and
// hardcode TIOCSWINSZ. // hardcode TIOCSWINSZ.
err := ioctl(fd, req, uintptr(unsafe.Pointer(value))) return ioctlPtr(fd, req, unsafe.Pointer(value))
runtime.KeepAlive(value)
return err
} }
// IoctlSetTermios performs an ioctl on fd with a *Termios. // IoctlSetTermios performs an ioctl on fd with a *Termios.
// //
// The req value is expected to be TCSETS, TCSETSW, or TCSETSF // The req value is expected to be TCSETS, TCSETSW, or TCSETSF
func IoctlSetTermios(fd int, req uint, value *Termios) error { func IoctlSetTermios(fd int, req int, value *Termios) error {
if (req != TCSETS) && (req != TCSETSW) && (req != TCSETSF) { if (req != TCSETS) && (req != TCSETSW) && (req != TCSETSF) {
return ENOSYS return ENOSYS
} }
@@ -49,22 +47,22 @@ func IoctlSetTermios(fd int, req uint, value *Termios) error {
// //
// A few ioctl requests use the return value as an output parameter; // A few ioctl requests use the return value as an output parameter;
// for those, IoctlRetInt should be used instead of this function. // for those, IoctlRetInt should be used instead of this function.
func IoctlGetInt(fd int, req uint) (int, error) { func IoctlGetInt(fd int, req int) (int, error) {
var value int var value int
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value))) err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return value, err return value, err
} }
func IoctlGetWinsize(fd int, req uint) (*Winsize, error) { func IoctlGetWinsize(fd int, req int) (*Winsize, error) {
var value Winsize var value Winsize
err := ioctl(fd, req, uintptr(unsafe.Pointer(&value))) err := ioctlPtr(fd, req, unsafe.Pointer(&value))
return &value, err return &value, err
} }
// IoctlGetTermios performs an ioctl on fd with a *Termios. // IoctlGetTermios performs an ioctl on fd with a *Termios.
// //
// The req value is expected to be TCGETS // The req value is expected to be TCGETS
func IoctlGetTermios(fd int, req uint) (*Termios, error) { func IoctlGetTermios(fd int, req int) (*Termios, error) {
var value Termios var value Termios
if req != TCGETS { if req != TCGETS {
return &value, ENOSYS return &value, ENOSYS

View File

@@ -174,10 +174,28 @@ openbsd_arm64)
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char" mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;; ;;
openbsd_mips64) openbsd_mips64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64" mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd" mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_ppc64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_riscv64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go" mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
# Let the type of C char be signed for making the bare syscall # Let the type of C char be signed for making the bare syscall
# API consistent across platforms. # API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char" mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"

View File

@@ -66,6 +66,7 @@ includes_Darwin='
#include <sys/ptrace.h> #include <sys/ptrace.h>
#include <sys/select.h> #include <sys/select.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h> #include <sys/un.h>
#include <sys/sockio.h> #include <sys/sockio.h>
#include <sys/sys_domain.h> #include <sys/sys_domain.h>
@@ -521,6 +522,7 @@ ccflags="$@"
$2 ~ /^NFC_(GENL|PROTO|COMM|RF|SE|DIRECTION|LLCP|SOCKPROTO)_/ || $2 ~ /^NFC_(GENL|PROTO|COMM|RF|SE|DIRECTION|LLCP|SOCKPROTO)_/ ||
$2 ~ /^NFC_.*_(MAX)?SIZE$/ || $2 ~ /^NFC_.*_(MAX)?SIZE$/ ||
$2 ~ /^RAW_PAYLOAD_/ || $2 ~ /^RAW_PAYLOAD_/ ||
$2 ~ /^[US]F_/ ||
$2 ~ /^TP_STATUS_/ || $2 ~ /^TP_STATUS_/ ||
$2 ~ /^FALLOC_/ || $2 ~ /^FALLOC_/ ||
$2 ~ /^ICMPV?6?_(FILTER|SEC)/ || $2 ~ /^ICMPV?6?_(FILTER|SEC)/ ||
@@ -642,7 +644,7 @@ errors=$(
signals=$( signals=$(
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags | echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' | awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' |
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' | grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
sort sort
) )
@@ -652,7 +654,7 @@ echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags |
sort >_error.grep sort >_error.grep
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags | echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' | awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' |
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' | grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
sort >_signal.grep sort >_signal.grep
echo '// mkerrors.sh' "$@" echo '// mkerrors.sh' "$@"

View File

@@ -7,6 +7,12 @@
package unix package unix
import "unsafe"
func ptrace(request int, pid int, addr uintptr, data uintptr) error { func ptrace(request int, pid int, addr uintptr, data uintptr) error {
return ptrace1(request, pid, addr, data) return ptrace1(request, pid, addr, data)
} }
func ptracePtr(request int, pid int, addr uintptr, data unsafe.Pointer) error {
return ptrace1Ptr(request, pid, addr, data)
}

View File

@@ -7,6 +7,12 @@
package unix package unix
import "unsafe"
func ptrace(request int, pid int, addr uintptr, data uintptr) (err error) { func ptrace(request int, pid int, addr uintptr, data uintptr) (err error) {
return ENOTSUP return ENOTSUP
} }
func ptracePtr(request int, pid int, addr uintptr, data unsafe.Pointer) (err error) {
return ENOTSUP
}

View File

@@ -52,6 +52,20 @@ func ParseSocketControlMessage(b []byte) ([]SocketControlMessage, error) {
return msgs, nil return msgs, nil
} }
// ParseOneSocketControlMessage parses a single socket control message from b, returning the message header,
// message data (a slice of b), and the remainder of b after that single message.
// When there are no remaining messages, len(remainder) == 0.
func ParseOneSocketControlMessage(b []byte) (hdr Cmsghdr, data []byte, remainder []byte, err error) {
h, dbuf, err := socketControlMessageHeaderAndData(b)
if err != nil {
return Cmsghdr{}, nil, nil, err
}
if i := cmsgAlignOf(int(h.Len)); i < len(b) {
remainder = b[i:]
}
return *h, dbuf, remainder, nil
}
func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) { func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) {
h := (*Cmsghdr)(unsafe.Pointer(&b[0])) h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) { if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) {

View File

@@ -292,9 +292,7 @@ func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) {
break break
} }
} }
sa.Name = string(unsafe.Slice((*byte)(unsafe.Pointer(&pp.Path[0])), n))
bytes := (*[len(pp.Path)]byte)(unsafe.Pointer(&pp.Path[0]))[0:n]
sa.Name = string(bytes)
return sa, nil return sa, nil
case AF_INET: case AF_INET:
@@ -410,7 +408,8 @@ func (w WaitStatus) CoreDump() bool { return w&0x80 == 0x80 }
func (w WaitStatus) TrapCause() int { return -1 } func (w WaitStatus) TrapCause() int { return -1 }
//sys ioctl(fd int, req uint, arg uintptr) (err error) //sys ioctl(fd int, req int, arg uintptr) (err error)
//sys ioctlPtr(fd int, req int, arg unsafe.Pointer) (err error) = ioctl
// fcntl must never be called with cmd=F_DUP2FD because it doesn't work on AIX // fcntl must never be called with cmd=F_DUP2FD because it doesn't work on AIX
// There is no way to create a custom fcntl and to keep //sys fcntl easily, // There is no way to create a custom fcntl and to keep //sys fcntl easily,

View File

@@ -8,7 +8,6 @@
package unix package unix
//sysnb Getrlimit(resource int, rlim *Rlimit) (err error) = getrlimit64 //sysnb Getrlimit(resource int, rlim *Rlimit) (err error) = getrlimit64
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error) = setrlimit64
//sys Seek(fd int, offset int64, whence int) (off int64, err error) = lseek64 //sys Seek(fd int, offset int64, whence int) (off int64, err error) = lseek64
//sys mmap(addr uintptr, length uintptr, prot int, flags int, fd int, offset int64) (xaddr uintptr, err error) //sys mmap(addr uintptr, length uintptr, prot int, flags int, fd int, offset int64) (xaddr uintptr, err error)

View File

@@ -8,7 +8,6 @@
package unix package unix
//sysnb Getrlimit(resource int, rlim *Rlimit) (err error) //sysnb Getrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error)
//sys Seek(fd int, offset int64, whence int) (off int64, err error) = lseek //sys Seek(fd int, offset int64, whence int) (off int64, err error) = lseek
//sys mmap(addr uintptr, length uintptr, prot int, flags int, fd int, offset int64) (xaddr uintptr, err error) = mmap64 //sys mmap(addr uintptr, length uintptr, prot int, flags int, fd int, offset int64) (xaddr uintptr, err error) = mmap64

View File

@@ -245,8 +245,7 @@ func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) {
break break
} }
} }
bytes := (*[len(pp.Path)]byte)(unsafe.Pointer(&pp.Path[0]))[0:n] sa.Name = string(unsafe.Slice((*byte)(unsafe.Pointer(&pp.Path[0])), n))
sa.Name = string(bytes)
return sa, nil return sa, nil
case AF_INET: case AF_INET:

View File

@@ -14,7 +14,6 @@ package unix
import ( import (
"fmt" "fmt"
"runtime"
"syscall" "syscall"
"unsafe" "unsafe"
) )
@@ -230,6 +229,7 @@ func direntNamlen(buf []byte) (uint64, bool) {
func PtraceAttach(pid int) (err error) { return ptrace(PT_ATTACH, pid, 0, 0) } func PtraceAttach(pid int) (err error) { return ptrace(PT_ATTACH, pid, 0, 0) }
func PtraceDetach(pid int) (err error) { return ptrace(PT_DETACH, pid, 0, 0) } func PtraceDetach(pid int) (err error) { return ptrace(PT_DETACH, pid, 0, 0) }
func PtraceDenyAttach() (err error) { return ptrace(PT_DENY_ATTACH, 0, 0, 0) }
//sysnb pipe(p *[2]int32) (err error) //sysnb pipe(p *[2]int32) (err error)
@@ -375,11 +375,10 @@ func Flistxattr(fd int, dest []byte) (sz int, err error) {
func Kill(pid int, signum syscall.Signal) (err error) { return kill(pid, int(signum), 1) } func Kill(pid int, signum syscall.Signal) (err error) { return kill(pid, int(signum), 1) }
//sys ioctl(fd int, req uint, arg uintptr) (err error) //sys ioctl(fd int, req uint, arg uintptr) (err error)
//sys ioctlPtr(fd int, req uint, arg unsafe.Pointer) (err error) = SYS_IOCTL
func IoctlCtlInfo(fd int, ctlInfo *CtlInfo) error { func IoctlCtlInfo(fd int, ctlInfo *CtlInfo) error {
err := ioctl(fd, CTLIOCGINFO, uintptr(unsafe.Pointer(ctlInfo))) return ioctlPtr(fd, CTLIOCGINFO, unsafe.Pointer(ctlInfo))
runtime.KeepAlive(ctlInfo)
return err
} }
// IfreqMTU is struct ifreq used to get or set a network device's MTU. // IfreqMTU is struct ifreq used to get or set a network device's MTU.
@@ -393,16 +392,14 @@ type IfreqMTU struct {
func IoctlGetIfreqMTU(fd int, ifname string) (*IfreqMTU, error) { func IoctlGetIfreqMTU(fd int, ifname string) (*IfreqMTU, error) {
var ifreq IfreqMTU var ifreq IfreqMTU
copy(ifreq.Name[:], ifname) copy(ifreq.Name[:], ifname)
err := ioctl(fd, SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq))) err := ioctlPtr(fd, SIOCGIFMTU, unsafe.Pointer(&ifreq))
return &ifreq, err return &ifreq, err
} }
// IoctlSetIfreqMTU performs the SIOCSIFMTU ioctl operation on fd to set the MTU // IoctlSetIfreqMTU performs the SIOCSIFMTU ioctl operation on fd to set the MTU
// of the network device specified by ifreq.Name. // of the network device specified by ifreq.Name.
func IoctlSetIfreqMTU(fd int, ifreq *IfreqMTU) error { func IoctlSetIfreqMTU(fd int, ifreq *IfreqMTU) error {
err := ioctl(fd, SIOCSIFMTU, uintptr(unsafe.Pointer(ifreq))) return ioctlPtr(fd, SIOCSIFMTU, unsafe.Pointer(ifreq))
runtime.KeepAlive(ifreq)
return err
} }
//sys sysctl(mib []_C_int, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) = SYS_SYSCTL //sys sysctl(mib []_C_int, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) = SYS_SYSCTL
@@ -616,6 +613,7 @@ func SysctlKinfoProcSlice(name string, args ...int) ([]KinfoProc, error) {
//sys Rmdir(path string) (err error) //sys Rmdir(path string) (err error)
//sys Seek(fd int, offset int64, whence int) (newoffset int64, err error) = SYS_LSEEK //sys Seek(fd int, offset int64, whence int) (newoffset int64, err error) = SYS_LSEEK
//sys Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err error) //sys Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err error)
//sys Setattrlist(path string, attrlist *Attrlist, attrBuf []byte, options int) (err error)
//sys Setegid(egid int) (err error) //sys Setegid(egid int) (err error)
//sysnb Seteuid(euid int) (err error) //sysnb Seteuid(euid int) (err error)
//sysnb Setgid(gid int) (err error) //sysnb Setgid(gid int) (err error)
@@ -625,7 +623,6 @@ func SysctlKinfoProcSlice(name string, args ...int) ([]KinfoProc, error) {
//sys Setprivexec(flag int) (err error) //sys Setprivexec(flag int) (err error)
//sysnb Setregid(rgid int, egid int) (err error) //sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setreuid(ruid int, euid int) (err error) //sysnb Setreuid(ruid int, euid int) (err error)
//sysnb Setrlimit(which int, lim *Rlimit) (err error)
//sysnb Setsid() (pid int, err error) //sysnb Setsid() (pid int, err error)
//sysnb Settimeofday(tp *Timeval) (err error) //sysnb Settimeofday(tp *Timeval) (err error)
//sysnb Setuid(uid int) (err error) //sysnb Setuid(uid int) (err error)
@@ -679,7 +676,6 @@ func SysctlKinfoProcSlice(name string, args ...int) ([]KinfoProc, error) {
// Kqueue_from_portset_np // Kqueue_from_portset_np
// Kqueue_portset // Kqueue_portset
// Getattrlist // Getattrlist
// Setattrlist
// Getdirentriesattr // Getdirentriesattr
// Searchfs // Searchfs
// Delete // Delete

Some files were not shown because too many files have changed in this diff Show More