Merge branch 'golangci-lint' into 'main'

Use golanglint-ci to check code in NVIDIA Container Toolkit

See merge request nvidia/container-toolkit/container-toolkit!474
This commit is contained in:
Evan Lezar 2023-10-24 18:59:46 +00:00
commit 1b1aae9c4a
71 changed files with 239 additions and 213 deletions

View File

@ -25,43 +25,15 @@ build-dev-image:
.requires-build-image: .requires-build-image:
image: "${BUILDIMAGE}" image: "${BUILDIMAGE}"
needs:
- job: build-dev-image
.go-check: check:
extends: extends:
- .requires-build-image - .requires-build-image
stage: go-checks stage: go-checks
fmt:
extends:
- .go-check
script: script:
- make assert-fmt - make check
vet:
extends:
- .go-check
script:
- make vet
lint:
extends:
- .go-check
script:
- make lint
allow_failure: true
ineffassign:
extends:
- .go-check
script:
- make ineffassign
allow_failure: true
misspell:
extends:
- .go-check
script:
- make misspell
go-build: go-build:
extends: extends:

29
.golangci.yml Normal file
View File

@ -0,0 +1,29 @@
run:
deadline: 10m
linters:
enable:
- contextcheck
- gocritic
- gofmt
- goimports
- gosec
- gosimple
- govet
- ineffassign
- misspell
- staticcheck
- unconvert
issues:
exclude-rules:
# Exclude the gocritic dupSubExpr issue for cgo files.
- path: internal/dxcore/dxcore.go
linters:
- gocritic
text: dupSubExpr
# Exclude the checks for usage of returns to config.Delete(Path) in the crio and containerd config packages.
- path: pkg/config/engine/
linters:
- errcheck
text: config.Delete

View File

@ -38,7 +38,7 @@ EXAMPLE_TARGETS := $(patsubst %,example-%, $(EXAMPLES))
CMDS := $(patsubst ./cmd/%/,%,$(sort $(dir $(wildcard ./cmd/*/)))) CMDS := $(patsubst ./cmd/%/,%,$(sort $(dir $(wildcard ./cmd/*/))))
CMD_TARGETS := $(patsubst %,cmd-%, $(CMDS)) CMD_TARGETS := $(patsubst %,cmd-%, $(CMDS))
CHECK_TARGETS := assert-fmt vet lint ineffassign misspell CHECK_TARGETS := golangci-lint
MAKE_TARGETS := binaries build check fmt lint-internal test examples cmds coverage generate licenses $(CHECK_TARGETS) MAKE_TARGETS := binaries build check fmt lint-internal test examples cmds coverage generate licenses $(CHECK_TARGETS)
TARGETS := $(MAKE_TARGETS) $(EXAMPLE_TARGETS) $(CMD_TARGETS) TARGETS := $(MAKE_TARGETS) $(EXAMPLE_TARGETS) $(CMD_TARGETS)
@ -78,30 +78,8 @@ fmt:
go list -f '{{.Dir}}' $(MODULE)/... \ go list -f '{{.Dir}}' $(MODULE)/... \
| xargs gofmt -s -l -w | xargs gofmt -s -l -w
assert-fmt: golangci-lint:
go list -f '{{.Dir}}' $(MODULE)/... \ golangci-lint run ./...
| xargs gofmt -s -l > fmt.out
@if [ -s fmt.out ]; then \
echo "\nERROR: The following files are not formatted:\n"; \
cat fmt.out; \
rm fmt.out; \
exit 1; \
else \
rm fmt.out; \
fi
ineffassign:
ineffassign $(MODULE)/...
lint:
# We use `go list -f '{{.Dir}}' $(MODULE)/...` to skip the `vendor` folder.
go list -f '{{.Dir}}' $(MODULE)/... | xargs golint -set_exit_status
misspell:
misspell $(MODULE)/...
vet:
go vet $(MODULE)/...
licenses: licenses:
go-licenses csv $(MODULE)/... go-licenses csv $(MODULE)/...
@ -141,6 +119,7 @@ $(DOCKER_TARGETS): docker-%: .build-image
$(DOCKER) run \ $(DOCKER) run \
--rm \ --rm \
-e GOCACHE=/tmp/.cache \ -e GOCACHE=/tmp/.cache \
-e GOLANGCI_LINT_CACHE=/tmp/.cache \
-v $(PWD):$(PWD) \ -v $(PWD):$(PWD) \
-w $(PWD) \ -w $(PWD) \
--user $$(id -u):$$(id -g) \ --user $$(id -u):$$(id -g) \
@ -154,6 +133,7 @@ PHONY: .shell
--rm \ --rm \
-ti \ -ti \
-e GOCACHE=/tmp/.cache \ -e GOCACHE=/tmp/.cache \
-e GOLANGCI_LINT_CACHE=/tmp/.cache \
-v $(PWD):$(PWD) \ -v $(PWD):$(PWD) \
-w $(PWD) \ -w $(PWD) \
--user $$(id -u):$$(id -g) \ --user $$(id -u):$$(id -g) \

View File

@ -1015,14 +1015,14 @@ func TestGetDriverCapabilities(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
var capabilites string var capabilities string
c := HookConfig{ c := HookConfig{
SupportedDriverCapabilities: tc.supportedCapabilities, SupportedDriverCapabilities: tc.supportedCapabilities,
} }
getDriverCapabilities := func() { getDriverCapabilities := func() {
capabilites = c.getDriverCapabilities(tc.env, tc.legacyImage).String() capabilities = c.getDriverCapabilities(tc.env, tc.legacyImage).String()
} }
if tc.expectedPanic { if tc.expectedPanic {
@ -1031,7 +1031,7 @@ func TestGetDriverCapabilities(t *testing.T) {
} }
getDriverCapabilities() getDriverCapabilities()
require.EqualValues(t, tc.expectedCapabilities, capabilites) require.EqualValues(t, tc.expectedCapabilities, capabilities)
}) })
} }
} }

View File

@ -17,8 +17,6 @@ const (
driverPath = "/run/nvidia/driver" driverPath = "/run/nvidia/driver"
) )
var defaultPaths = [...]string{}
// HookConfig : options for the nvidia-container-runtime-hook. // HookConfig : options for the nvidia-container-runtime-hook.
type HookConfig config.Config type HookConfig config.Config

View File

@ -142,6 +142,7 @@ func doPrestart() {
args = append(args, rootfs) args = append(args, rootfs)
env := append(os.Environ(), cli.Environment...) env := append(os.Environ(), cli.Environment...)
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection?
err = syscall.Exec(args[0], args, env) err = syscall.Exec(args[0], args, env)
log.Panicln("exec failed:", err) log.Panicln("exec failed:", err)
} }

View File

@ -3,7 +3,7 @@ package main
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"io/ioutil" "io"
"log" "log"
"os" "os"
"os/exec" "os/exec"
@ -86,6 +86,7 @@ func TestBadInput(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle") cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle")
t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " ")) t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " "))
err = cmdCreate.Run() err = cmdCreate.Run()
@ -103,6 +104,7 @@ func TestGoodInput(t *testing.T) {
t.Fatalf("error generating runtime spec: %v", err) t.Fatalf("error generating runtime spec: %v", err)
} }
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
cmdRun := exec.Command(nvidiaRuntime, "run", "--bundle", cfg.bundlePath(), "testcontainer") cmdRun := exec.Command(nvidiaRuntime, "run", "--bundle", cfg.bundlePath(), "testcontainer")
t.Logf("executing: %s\n", strings.Join(cmdRun.Args, " ")) t.Logf("executing: %s\n", strings.Join(cmdRun.Args, " "))
output, err := cmdRun.CombinedOutput() output, err := cmdRun.CombinedOutput()
@ -113,6 +115,7 @@ func TestGoodInput(t *testing.T) {
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json") require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
require.Empty(t, spec.Hooks, "there should be no hooks in config.json") require.Empty(t, spec.Hooks, "there should be no hooks in config.json")
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle", cfg.bundlePath(), "testcontainer") cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle", cfg.bundlePath(), "testcontainer")
t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " ")) t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " "))
err = cmdCreate.Run() err = cmdCreate.Run()
@ -158,6 +161,7 @@ func TestDuplicateHook(t *testing.T) {
} }
// Test how runtime handles already existing prestart hook in config.json // Test how runtime handles already existing prestart hook in config.json
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle", cfg.bundlePath(), "testcontainer") cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle", cfg.bundlePath(), "testcontainer")
t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " ")) t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " "))
output, err := cmdCreate.CombinedOutput() output, err := cmdCreate.CombinedOutput()
@ -188,15 +192,16 @@ func (c testConfig) getRuntimeSpec() (specs.Spec, error) {
} }
defer jsonFile.Close() defer jsonFile.Close()
jsonContent, err := ioutil.ReadAll(jsonFile) jsonContent, err := io.ReadAll(jsonFile)
if err != nil { switch {
case err != nil:
return spec, err return spec, err
} else if json.Valid(jsonContent) { case json.Valid(jsonContent):
err = json.Unmarshal(jsonContent, &spec) err = json.Unmarshal(jsonContent, &spec)
if err != nil { if err != nil {
return spec, err return spec, err
} }
} else { default:
err = json.NewDecoder(bytes.NewReader(jsonContent)).Decode(&spec) err = json.NewDecoder(bytes.NewReader(jsonContent)).Decode(&spec)
if err != nil { if err != nil {
return spec, err return spec, err
@ -226,6 +231,7 @@ func (c testConfig) generateNewRuntimeSpec() error {
return err return err
} }
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
cmd := exec.Command("cp", c.unmodifiedSpecFile(), c.specFilePath()) cmd := exec.Command("cp", c.unmodifiedSpecFile(), c.specFilePath())
err = cmd.Run() err = cmd.Run()
if err != nil { if err != nil {

View File

@ -28,7 +28,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" cdi "github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
@ -238,7 +238,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
nvcdi.WithDriverRoot(opts.driverRoot), nvcdi.WithDriverRoot(opts.driverRoot),
nvcdi.WithNVIDIACTKPath(opts.nvidiaCTKPath), nvcdi.WithNVIDIACTKPath(opts.nvidiaCTKPath),
nvcdi.WithDeviceNamer(deviceNamer), nvcdi.WithDeviceNamer(deviceNamer),
nvcdi.WithMode(string(opts.mode)), nvcdi.WithMode(opts.mode),
nvcdi.WithLibrarySearchPaths(opts.librarySearchPaths.Value()), nvcdi.WithLibrarySearchPaths(opts.librarySearchPaths.Value()),
nvcdi.WithCSVFiles(opts.csv.files.Value()), nvcdi.WithCSVFiles(opts.csv.files.Value()),
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns.Value()), nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns.Value()),

View File

@ -28,11 +28,6 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
type loadSaver interface {
Load() (spec.Interface, error)
Save(spec.Interface) error
}
type command struct { type command struct {
logger logger.Interface logger logger.Interface
} }

View File

@ -119,10 +119,10 @@ func run(c *cli.Context, opts *options) error {
} }
defer output.Close() defer output.Close()
if err != nil { if _, err := cfgToml.Save(output); err != nil {
return err return fmt.Errorf("failed to save config: %v", err)
} }
cfgToml.Save(output)
return nil return nil
} }
@ -146,8 +146,7 @@ func (c *configToml) setFlagToKeyValue(setFlag string) (string, interface{}, err
if v == nil { if v == nil {
return key, nil, errInvalidConfigOption return key, nil, errInvalidConfigOption
} }
switch v.(type) { if _, ok := v.(bool); ok {
case bool:
if len(setParts) == 1 { if len(setParts) == 1 {
return key, true, nil return key, true, nil
} }

View File

@ -85,8 +85,7 @@ func (m command) run(c *cli.Context, opts *flags.Options) error {
} }
defer output.Close() defer output.Close()
_, err = cfgToml.Save(output) if _, err = cfgToml.Save(output); err != nil {
if err != nil {
return fmt.Errorf("failed to write output: %v", err) return fmt.Errorf("failed to write output: %v", err)
} }

View File

@ -66,7 +66,7 @@ func (m command) build() *cli.Command {
c.Flags = []cli.Flag{ c.Flags = []cli.Flag{
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "path", Name: "path",
Usage: "Specifiy a path to apply the specified mode to", Usage: "Specify a path to apply the specified mode to",
Destination: &cfg.paths, Destination: &cfg.paths,
}, },
&cli.StringFlag{ &cli.StringFlag{
@ -127,6 +127,7 @@ func (m command) run(c *cli.Context, cfg *config) error {
args := append([]string{filepath.Base(chmodPath), cfg.mode}, paths...) args := append([]string{filepath.Base(chmodPath), cfg.mode}, paths...)
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
return syscall.Exec(chmodPath, args, nil) return syscall.Exec(chmodPath, args, nil)
} }

View File

@ -56,7 +56,7 @@ func (m command) build() *cli.Command {
// Create the '' command // Create the '' command
c := cli.Command{ c := cli.Command{
Name: "create-symlinks", Name: "create-symlinks",
Usage: "A hook to create symlinks in the container. This can be used to proces CSV mount specs", Usage: "A hook to create symlinks in the container. This can be used to process CSV mount specs",
Action: func(c *cli.Context) error { Action: func(c *cli.Context) error {
return m.run(c, &cfg) return m.run(c, &cfg)
}, },

View File

@ -60,7 +60,7 @@ func (m command) build() *cli.Command {
c.Flags = []cli.Flag{ c.Flags = []cli.Flag{
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "folder", Name: "folder",
Usage: "Specifiy a folder to add to /etc/ld.so.conf before updating the ld cache", Usage: "Specify a folder to add to /etc/ld.so.conf before updating the ld cache",
Destination: &cfg.folders, Destination: &cfg.folders,
}, },
&cli.StringFlag{ &cli.StringFlag{
@ -100,6 +100,7 @@ func (m command) run(c *cli.Context, cfg *config) error {
args = append(args, "-r", containerRoot) args = append(args, "-r", containerRoot)
} }
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
return syscall.Exec(args[0], args, nil) return syscall.Exec(args[0], args, nil)
} }

View File

@ -63,20 +63,13 @@ func (m existing) DeviceNodes() ([]deviceNode, error) {
if m.nodeIsBlocked(d) { if m.nodeIsBlocked(d) {
continue continue
} }
var stat unix.Stat_t var stat unix.Stat_t
err := unix.Stat(d, &stat) err := unix.Stat(d, &stat)
if err != nil { if err != nil {
m.logger.Warningf("Could not stat device: %v", err) m.logger.Warningf("Could not stat device: %v", err)
continue continue
} }
deviceNode := deviceNode{ deviceNodes = append(deviceNodes, newDeviceNode(d, stat))
path: d,
major: unix.Major(uint64(stat.Rdev)),
minor: unix.Minor(uint64(stat.Rdev)),
}
deviceNodes = append(deviceNodes, deviceNode)
} }
return deviceNodes, nil return deviceNodes, nil

View File

@ -0,0 +1,28 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package devchar
import "golang.org/x/sys/unix"
func newDeviceNode(d string, stat unix.Stat_t) deviceNode {
deviceNode := deviceNode{
path: d,
major: unix.Major(stat.Rdev),
minor: unix.Minor(stat.Rdev),
}
return deviceNode
}

View File

@ -0,0 +1,30 @@
//go:build !linux
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package devchar
import "golang.org/x/sys/unix"
func newDeviceNode(d string, stat unix.Stat_t) deviceNode {
deviceNode := deviceNode{
path: d,
major: unix.Major(uint64(stat.Rdev)),
minor: unix.Minor(uint64(stat.Rdev)),
}
return deviceNode
}

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
ARG GOLANG_VERSION=x.x.x ARG GOLANG_VERSION=x.x.x
ARG GOLANGCI_LINT_VERSION=v1.54.1
FROM golang:${GOLANG_VERSION} FROM golang:${GOLANG_VERSION}
RUN go install golang.org/x/lint/golint@6edffad5e6160f5949cdefc81710b2706fbcd4f6 RUN go install golang.org/x/lint/golint@6edffad5e6160f5949cdefc81710b2706fbcd4f6
@ -19,3 +20,4 @@ RUN go install github.com/matryer/moq@latest
RUN go install github.com/gordonklaus/ineffassign@d2c82e48359b033cde9cf1307f6d5550b8d61321 RUN go install github.com/gordonklaus/ineffassign@d2c82e48359b033cde9cf1307f6d5550b8d61321
RUN go install github.com/client9/misspell/cmd/misspell@latest RUN go install github.com/client9/misspell/cmd/misspell@latest
RUN go install github.com/google/go-licenses@latest RUN go install github.com/google/go-licenses@latest
RUN curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin ${GOLANGCI_LINT_VERSION}

View File

@ -35,7 +35,7 @@ func TestGetConfigWithCustomConfig(t *testing.T) {
contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"") contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"")
require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766)) require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766))
require.NoError(t, os.WriteFile(filename, contents, 0766)) require.NoError(t, os.WriteFile(filename, contents, 0600))
cfg, err := GetConfig() cfg, err := GetConfig()
require.NoError(t, err) require.NoError(t, err)

View File

@ -73,7 +73,7 @@ func (c DriverCapabilities) Has(capability DriverCapability) bool {
return c[capability] return c[capability]
} }
// Any checks whether any of the specified capabilites are set // Any checks whether any of the specified capabilities are set
func (c DriverCapabilities) Any(capabilities ...DriverCapability) bool { func (c DriverCapabilities) Any(capabilities ...DriverCapability) bool {
if c.IsAll() { if c.IsAll() {
return true return true

View File

@ -139,12 +139,12 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
func (i CUDA) GetDriverCapabilities() DriverCapabilities { func (i CUDA) GetDriverCapabilities() DriverCapabilities {
env := i[envNVDriverCapabilities] env := i[envNVDriverCapabilities]
capabilites := make(DriverCapabilities) capabilities := make(DriverCapabilities)
for _, c := range strings.Split(env, ",") { for _, c := range strings.Split(env, ",") {
capabilites[DriverCapability(c)] = true capabilities[DriverCapability(c)] = true
} }
return capabilites return capabilities
} }
func (i CUDA) legacyVersion() (string, error) { func (i CUDA) legacyVersion() (string, error) {

View File

@ -154,10 +154,7 @@ func (t Toml) contents() ([]byte, error) {
// format fixes the comments for the config to ensure that they start in column // format fixes the comments for the config to ensure that they start in column
// 1 and are not followed by a space. // 1 and are not followed by a space.
func (t Toml) format(contents []byte) ([]byte, error) { func (t Toml) format(contents []byte) ([]byte, error) {
r, err := regexp.Compile(`(\n*)\s*?#\s*(\S.*)`) r := regexp.MustCompile(`(\n*)\s*?#\s*(\S.*)`)
if err != nil {
return nil, fmt.Errorf("unable to compile regexp: %v", err)
}
replaced := r.ReplaceAll(contents, []byte("$1#$2")) replaced := r.ReplaceAll(contents, []byte("$1#$2"))
return replaced, nil return replaced, nil

View File

@ -239,7 +239,7 @@ func newDRMDeviceFilter(logger logger.Interface, devices image.VisibleDevices, d
return nil, fmt.Errorf("failed to determine DRM devices for %v: %v", busID, err) return nil, fmt.Errorf("failed to determine DRM devices for %v: %v", busID, err)
} }
for _, drmDeviceNode := range drmDeviceNodes { for _, drmDeviceNode := range drmDeviceNodes {
filter[filepath.Join(drmDeviceNode)] = true filter[drmDeviceNode] = true
} }
} }

View File

@ -70,7 +70,7 @@ func (d *ipcMounts) Mounts() ([]Mount, error) {
var modifiedMounts []Mount var modifiedMounts []Mount
for _, m := range mounts { for _, m := range mounts {
mount := m mount := m
mount.Options = append(m.Options, "noexec") mount.Options = append(mount.Options, "noexec")
modifiedMounts = append(modifiedMounts, mount) modifiedMounts = append(modifiedMounts, mount)
} }

View File

@ -27,7 +27,7 @@ type list struct {
var _ Discover = (*list)(nil) var _ Discover = (*list)(nil)
// Merge creates a discoverer that is the composite of a list of discoveres. // Merge creates a discoverer that is the composite of a list of discoverers.
func Merge(d ...Discover) Discover { func Merge(d ...Discover) Discover {
l := list{ l := list{
discoverers: d, discoverers: d,

View File

@ -16,15 +16,6 @@
package dxcore package dxcore
import (
"github.com/NVIDIA/go-nvml/pkg/dl"
)
const (
libraryName = "libdxcore.so"
libraryLoadFlags = dl.RTLD_LAZY | dl.RTLD_GLOBAL
)
// dxcore stores a reference the dxcore dynamic library // dxcore stores a reference the dxcore dynamic library
var dxcore *context var dxcore *context

View File

@ -52,7 +52,9 @@ func (i additionalInfo) UsesNVGPUModule() (uses bool, reason string) {
if ret != nvml.SUCCESS { if ret != nvml.SUCCESS {
return false, fmt.Sprintf("failed to initialize nvml: %v", ret) return false, fmt.Sprintf("failed to initialize nvml: %v", ret)
} }
defer i.nvmllib.Shutdown() defer func() {
_ = i.nvmllib.Shutdown()
}()
var names []string var names []string

View File

@ -56,7 +56,7 @@ func ParseGPUInformationFile(path string) (GPUInfo, error) {
} }
// gpuInfoFrom parses a GPUInfo struct from the specified reader // gpuInfoFrom parses a GPUInfo struct from the specified reader
// An information file has the following strucutre: // An information file has the following structure:
// $ cat /proc/driver/nvidia/gpus/0000\:06\:00.0/information // $ cat /proc/driver/nvidia/gpus/0000\:06\:00.0/information
// Model: Tesla V100-SXM2-16GB // Model: Tesla V100-SXM2-16GB
// IRQ: 408 // IRQ: 408

View File

@ -234,7 +234,7 @@ func (c *ldcache) getEntries(selected func(string) bool) []entry {
return entries return entries
} }
// List creates a list of libraires in the ldcache. // List creates a list of libraries in the ldcache.
// The 32-bit and 64-bit libraries are returned separately. // The 32-bit and 64-bit libraries are returned separately.
func (c *ldcache) List() ([]string, []string) { func (c *ldcache) List() ([]string, []string) {
all := func(s string) bool { return true } all := func(s string) bool { return true }
@ -287,7 +287,7 @@ func (c *ldcache) resolveSelected(selected func(string) bool) ([]string, []strin
func (c *ldcache) resolve(target string) (string, error) { func (c *ldcache) resolve(target string) (string, error) {
name := filepath.Join(c.root, target) name := filepath.Join(c.root, target)
c.logger.Debugf("checking %v", string(name)) c.logger.Debugf("checking %v", name)
link, err := symlinks.Resolve(name) link, err := symlinks.Resolve(name)
if err != nil { if err != nil {

View File

@ -28,14 +28,8 @@ import (
"github.com/container-orchestrated-devices/container-device-interface/pkg/parser" "github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
) )
type cdiModifier struct {
logger logger.Interface
specDirs []string
devices []string
}
// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the // NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES enviroment variable is // CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
// used to select the devices to include. // used to select the devices to include.
func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
devices, err := getDevicesFromSpec(logger, ociSpec, cfg) devices, err := getDevicesFromSpec(logger, ociSpec, cfg)

View File

@ -34,6 +34,7 @@ var _ oci.SpecModifier = (*fromCDISpec)(nil)
// Modify applies the mofiications defined by the raw CDI spec to the incomming OCI spec. // Modify applies the mofiications defined by the raw CDI spec to the incomming OCI spec.
func (m fromCDISpec) Modify(spec *specs.Spec) error { func (m fromCDISpec) Modify(spec *specs.Spec) error {
for _, device := range m.cdiSpec.Devices { for _, device := range m.cdiSpec.Devices {
device := device
cdiDevice := cdi.Device{ cdiDevice := cdi.Device{
Device: &device, Device: &device,
} }

View File

@ -22,7 +22,6 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda" "github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
@ -31,12 +30,6 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
) )
// csvMode represents the modifications as performed by the csv runtime mode
type csvMode struct {
logger logger.Interface
discoverer discover.Discover
}
const ( const (
visibleDevicesEnvvar = "NVIDIA_VISIBLE_DEVICES" visibleDevicesEnvvar = "NVIDIA_VISIBLE_DEVICES"
visibleDevicesVoid = "void" visibleDevicesVoid = "void"
@ -62,7 +55,7 @@ func NewCSVModifier(logger logger.Interface, cfg *config.Config, image image.CUD
return nil, fmt.Errorf("failed to get list of CSV files: %v", err) return nil, fmt.Errorf("failed to get list of CSV files: %v", err)
} }
if nvidiaRequireJetpack, _ := image[nvidiaRequireJetpackEnvvar]; nvidiaRequireJetpack != "csv-mounts=all" { if nvidiaRequireJetpack := image[nvidiaRequireJetpackEnvvar]; nvidiaRequireJetpack != "csv-mounts=all" {
csvFiles = csv.BaseFilesOnly(csvFiles) csvFiles = csv.BaseFilesOnly(csvFiles)
} }

View File

@ -38,7 +38,7 @@ func NewGDSModifier(logger logger.Interface, cfg *config.Config, image image.CUD
return nil, nil return nil, nil
} }
if gds, _ := image[nvidiaGDSEnvvar]; gds != "enabled" { if gds := image[nvidiaGDSEnvvar]; gds != "enabled" {
return nil, nil return nil, nil
} }

View File

@ -49,6 +49,7 @@ func (m nvidiaContainerRuntimeHookRemover) Modify(spec *specs.Spec) error {
var newPrestart []specs.Hook var newPrestart []specs.Hook
for _, hook := range spec.Hooks.Prestart { for _, hook := range spec.Hooks.Prestart {
hook := hook
if isNVIDIAContainerRuntimeHook(&hook) { if isNVIDIAContainerRuntimeHook(&hook) {
m.logger.Debugf("Removing hook %v", hook) m.logger.Debugf("Removing hook %v", hook)
continue continue

View File

@ -38,7 +38,7 @@ func NewMOFEDModifier(logger logger.Interface, cfg *config.Config, image image.C
return nil, nil return nil, nil
} }
if mofed, _ := image[nvidiaMOFEDEnvvar]; mofed != "enabled" { if mofed := image[nvidiaMOFEDEnvvar]; mofed != "enabled" {
return nil, nil return nil, nil
} }

View File

@ -48,6 +48,7 @@ func (m stableRuntimeModifier) Modify(spec *specs.Spec) error {
// If an NVIDIA Container Runtime Hook already exists, we don't make any modifications to the spec. // If an NVIDIA Container Runtime Hook already exists, we don't make any modifications to the spec.
if spec.Hooks != nil { if spec.Hooks != nil {
for _, hook := range spec.Hooks.Prestart { for _, hook := range spec.Hooks.Prestart {
hook := hook
if isNVIDIAContainerRuntimeHook(&hook) { if isNVIDIAContainerRuntimeHook(&hook) {
m.logger.Infof("Existing nvidia prestart hook (%v) found in OCI spec", hook.Path) m.logger.Infof("Existing nvidia prestart hook (%v) found in OCI spec", hook.Path)
return nil return nil

View File

@ -150,6 +150,8 @@ func TestAddHookModifier(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
logHook.Reset() logHook.Reset()
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {

View File

@ -23,7 +23,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
) )
// pathRuntime wraps the path that a binary and defines the semanitcs for how to exec into it. // pathRuntime wraps the path that a binary and defines the semantics for how to exec into it.
// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the // This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the
// Runtime internface. // Runtime internface.
type pathRuntime struct { type pathRuntime struct {

View File

@ -27,6 +27,7 @@ type syscallExec struct{}
var _ Runtime = (*syscallExec)(nil) var _ Runtime = (*syscallExec)(nil)
func (r syscallExec) Exec(args []string) error { func (r syscallExec) Exec(args []string) error {
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
err := syscall.Exec(args[0], args, os.Environ()) err := syscall.Exec(args[0], args, os.Environ())
if err != nil { if err != nil {
return fmt.Errorf("could not exec '%v': %v", args[0], err) return fmt.Errorf("could not exec '%v': %v", args[0], err)

View File

@ -23,7 +23,7 @@ import (
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
) )
// SpecModifier defines an interace for modifying a (raw) OCI spec // SpecModifier defines an interface for modifying a (raw) OCI spec
type SpecModifier interface { type SpecModifier interface {
// Modify is a method that accepts a pointer to an OCI Srec and returns an // Modify is a method that accepts a pointer to an OCI Srec and returns an
// error. The intention is that the function would modify the spec in-place. // error. The intention is that the function would modify the spec in-place.

View File

@ -79,7 +79,7 @@ func (s *fileSpec) Modify(m SpecModifier) error {
return s.memorySpec.Modify(m) return s.memorySpec.Modify(m)
} }
// Flush writes the stored OCI specification to the filepath specifed by the path member. // Flush writes the stored OCI specification to the filepath specified by the path member.
// The file is truncated upon opening, overwriting any existing contents. // The file is truncated upon opening, overwriting any existing contents.
func (s fileSpec) Flush() error { func (s fileSpec) Flush() error {
if s.Spec == nil { if s.Spec == nil {

View File

@ -152,12 +152,13 @@ func TestModify(t *testing.T) {
err := spec.Modify(modifier{tc.modifierError}) err := spec.Modify(modifier{tc.modifierError})
if tc.spec == nil { switch {
case tc.spec == nil:
require.Error(t, err, "%d: %v", i, tc) require.Error(t, err, "%d: %v", i, tc)
} else if tc.modifierError != nil { case tc.modifierError != nil:
require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc) require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc) require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc)
} else { default:
require.NoError(t, err, "%d: %v", i, tc) require.NoError(t, err, "%d: %v", i, tc)
require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc) require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc)
} }

View File

@ -22,7 +22,8 @@ func TestMaintainSpec(t *testing.T) {
spec := NewFileSpec(inputSpecPath).(*fileSpec) spec := NewFileSpec(inputSpecPath).(*fileSpec)
spec.Load() _, err := spec.Load()
require.NoError(t, err)
outputSpecPath := filepath.Join(moduleRoot, "test/output", f) outputSpecPath := filepath.Join(moduleRoot, "test/output", f)
spec.path = outputSpecPath spec.path = outputSpecPath

View File

@ -70,6 +70,7 @@ func TestNewMountSpecFromLine(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
tc := tc
target, err := NewMountSpecFromLine(tc.line) target, err := NewMountSpecFromLine(tc.line)
if tc.expectedError != nil { if tc.expectedError != nil {
require.Error(t, err) require.Error(t, err)

View File

@ -80,19 +80,20 @@ func (d symlinkHook) getSpecificLinks() ([]string, error) {
lib := filepath.Base(m.Path) lib := filepath.Base(m.Path)
if strings.HasPrefix(lib, "libcuda.so") { switch {
case strings.HasPrefix(lib, "libcuda.so"):
// XXX Many applications wrongly assume that libcuda.so exists (e.g. with dlopen). // XXX Many applications wrongly assume that libcuda.so exists (e.g. with dlopen).
target = "libcuda.so.1" target = "libcuda.so.1"
link = "libcuda.so" link = "libcuda.so"
} else if strings.HasPrefix(lib, "libGLX_nvidia.so") { case strings.HasPrefix(lib, "libGLX_nvidia.so"):
// XXX GLVND requires this symlink for indirect GLX support. // XXX GLVND requires this symlink for indirect GLX support.
target = lib target = lib
link = "libGLX_indirect.so.0" link = "libGLX_indirect.so.0"
} else if strings.HasPrefix(lib, "libnvidia-opticalflow.so") { case strings.HasPrefix(lib, "libnvidia-opticalflow.so"):
// XXX Fix missing symlink for libnvidia-opticalflow.so. // XXX Fix missing symlink for libnvidia-opticalflow.so.
target = "libnvidia-opticalflow.so.1" target = "libnvidia-opticalflow.so.1"
link = "libnvidia-opticalflow.so" link = "libnvidia-opticalflow.so"
} else { default:
continue continue
} }
if linkProcessed[link] { if linkProcessed[link] {

View File

@ -51,11 +51,10 @@ func New(opts ...Option) (discover.Discover, error) {
} }
if o.symlinkLocator == nil { if o.symlinkLocator == nil {
searchPaths := append(o.librarySearchPaths, "/")
o.symlinkLocator = lookup.NewSymlinkLocator( o.symlinkLocator = lookup.NewSymlinkLocator(
lookup.WithLogger(o.logger), lookup.WithLogger(o.logger),
lookup.WithRoot(o.driverRoot), lookup.WithRoot(o.driverRoot),
lookup.WithSearchPaths(searchPaths...), lookup.WithSearchPaths(append(o.librarySearchPaths, "/")...),
) )
} }

View File

@ -57,7 +57,7 @@ func (c binary) eval() (bool, error) {
return false, err return false, err
} }
switch string(c.operator) { switch c.operator {
case equal: case equal:
return compare == 0, nil return compare == 0, nil
case notEqual: case notEqual:

View File

@ -16,8 +16,6 @@
package constraints package constraints
import "fmt"
const ( const (
equal = "=" equal = "="
notEqual = "!=" notEqual = "!="
@ -37,15 +35,3 @@ func (c always) Assert() error {
func (c always) String() string { func (c always) String() string {
return "true" return "true"
} }
// invalid is an invalid constraint and can never be met
type invalid string
func (c invalid) Assert() error {
return fmt.Errorf("invalid constraint: %v", c.String())
}
// String returns the string representation of the contraint
func (c invalid) String() string {
return string(c)
}

View File

@ -104,11 +104,12 @@ func (l *Logger) Update(filename string, logLevel string, argv []string) {
newLogger.SetFormatter(new(logrus.JSONFormatter)) newLogger.SetFormatter(new(logrus.JSONFormatter))
} }
if len(logFiles) == 0 { switch len(logFiles) {
case 0:
newLogger.SetOutput(io.Discard) newLogger.SetOutput(io.Discard)
} else if len(logFiles) == 1 { case 1:
newLogger.SetOutput(logFiles[0]) newLogger.SetOutput(logFiles[0])
} else if len(logFiles) > 1 { default:
var writers []io.Writer var writers []io.Writer
for _, f := range logFiles { for _, f := range logFiles {
writers = append(writers, f) writers = append(writers, f)
@ -234,12 +235,13 @@ func parseArgs(args []string) loggerConfig {
} }
var value string var value string
if len(parts) == 2 { switch {
case len(parts) == 2:
value = parts[2] value = parts[2]
} else if i+1 < len(args) { case i+1 < len(args):
value = args[i+1] value = args[i+1]
i++ i++
} else { default:
continue continue
} }

View File

@ -18,6 +18,7 @@ package runtime
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
@ -53,7 +54,9 @@ func (r rt) Run(argv []string) (rerr error) {
if rerr != nil { if rerr != nil {
r.logger.Errorf("%v", rerr) r.logger.Errorf("%v", rerr)
} }
r.logger.Reset() if err := r.logger.Reset(); err != nil {
rerr = errors.Join(rerr, fmt.Errorf("failed to reset logger: %v", err))
}
}() }()
// We apply some config updates here to ensure that the config is valid in // We apply some config updates here to ensure that the config is valid in

View File

@ -38,8 +38,7 @@ func (c *ConfigV1) AddRuntime(name string, path string, setAsDefault bool) error
config.Set("version", int64(1)) config.Set("version", int64(1))
switch runc := config.GetPath([]string{"plugins", "cri", "containerd", "runtimes", "runc"}).(type) { if runc, ok := config.GetPath([]string{"plugins", "cri", "containerd", "runtimes", "runc"}).(*toml.Tree); ok {
case *toml.Tree:
runc, _ = toml.Load(runc.String()) runc, _ = toml.Load(runc.String())
config.SetPath([]string{"plugins", "cri", "containerd", "runtimes", name}, runc) config.SetPath([]string{"plugins", "cri", "containerd", "runtimes", name}, runc)
} }

View File

@ -32,8 +32,7 @@ func (c *Config) AddRuntime(name string, path string, setAsDefault bool) error {
config.Set("version", int64(2)) config.Set("version", int64(2))
switch runc := config.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "containerd", "runtimes", "runc"}).(type) { if runc, ok := config.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "containerd", "runtimes", "runc"}).(*toml.Tree); ok {
case *toml.Tree:
runc, _ = toml.Load(runc.String()) runc, _ = toml.Load(runc.String())
config.SetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "containerd", "runtimes", name}, runc) config.SetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "containerd", "runtimes", name}, runc)
} }

View File

@ -44,8 +44,7 @@ func (c *Config) AddRuntime(name string, path string, setAsDefault bool) error {
config := (toml.Tree)(*c) config := (toml.Tree)(*c)
switch runc := config.Get("crio.runtime.runtimes.runc").(type) { if runc, ok := config.Get("crio.runtime.runtimes.runc").(*toml.Tree); ok {
case *toml.Tree:
runc, _ = toml.Load(runc.String()) runc, _ = toml.Load(runc.String())
config.SetPath([]string{"crio", "runtime", "runtimes", name}, runc) config.SetPath([]string{"crio", "runtime", "runtimes", name}, runc)
} }

View File

@ -20,7 +20,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
@ -72,7 +71,7 @@ func (b *builder) loadConfig(config string) (*Config, error) {
} }
b.logger.Infof("Loading config from %v", config) b.logger.Infof("Loading config from %v", config)
readBytes, err := ioutil.ReadFile(config) readBytes, err := os.ReadFile(config)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read config: %v", err) return nil, fmt.Errorf("unable to read config: %v", err)
} }

View File

@ -34,9 +34,13 @@ import (
// The supplied NVML Library is used to query the expected driver version. // The supplied NVML Library is used to query the expected driver version.
func NewDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { func NewDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) {
if r := nvmllib.Init(); r != nvml.SUCCESS { if r := nvmllib.Init(); r != nvml.SUCCESS {
return nil, fmt.Errorf("failed to initalize NVML: %v", r) return nil, fmt.Errorf("failed to initialize NVML: %v", r)
} }
defer nvmllib.Shutdown() defer func() {
if r := nvmllib.Shutdown(); r != nvml.SUCCESS {
logger.Warningf("failed to shutdown NVML: %v", r)
}
}()
version, r := nvmllib.SystemGetDriverVersion() version, r := nvmllib.SystemGetDriverVersion()
if r != nvml.SUCCESS { if r != nvml.SUCCESS {
@ -124,9 +128,9 @@ func getFirmwareSearchPaths(logger logger.Interface) ([]string, error) {
standardPaths := []string{ standardPaths := []string{
filepath.Join("/lib/firmware/updates/", utsRelease), filepath.Join("/lib/firmware/updates/", utsRelease),
filepath.Join("/lib/firmware/updates/"), "/lib/firmware/updates/",
filepath.Join("/lib/firmware/", utsRelease), filepath.Join("/lib/firmware/", utsRelease),
filepath.Join("/lib/firmware/"), "/lib/firmware/",
} }
return append(firmwarePaths, standardPaths...), nil return append(firmwarePaths, standardPaths...), nil

View File

@ -43,7 +43,11 @@ func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCT
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize dxcore: %v", err) return nil, fmt.Errorf("failed to initialize dxcore: %v", err)
} }
defer dxcore.Shutdown() defer func() {
if err := dxcore.Shutdown(); err != nil {
logger.Warningf("failed to shutdown dxcore: %v", err)
}
}()
driverStorePaths := dxcore.GetDriverStorePaths() driverStorePaths := dxcore.GetDriverStorePaths()
if len(driverStorePaths) == 0 { if len(driverStorePaths) == 0 {

View File

@ -41,9 +41,13 @@ func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
var deviceSpecs []specs.Device var deviceSpecs []specs.Device
if r := l.nvmllib.Init(); r != nvml.SUCCESS { if r := l.nvmllib.Init(); r != nvml.SUCCESS {
return nil, fmt.Errorf("failed to initalize NVML: %v", r) return nil, fmt.Errorf("failed to initialize NVML: %v", r)
} }
defer l.nvmllib.Shutdown() defer func() {
if r := l.nvmllib.Shutdown(); r != nvml.SUCCESS {
l.logger.Warningf("failed to shutdown NVML: %v", r)
}
}()
gpuDeviceSpecs, err := l.getGPUDeviceSpecs() gpuDeviceSpecs, err := l.getGPUDeviceSpecs()
if err != nil { if err != nil {

View File

@ -191,7 +191,11 @@ func (l *nvcdilib) getCudaVersion() (string, error) {
if r != nvml.SUCCESS { if r != nvml.SUCCESS {
return "", fmt.Errorf("failed to initialize nvml: %v", r) return "", fmt.Errorf("failed to initialize nvml: %v", r)
} }
defer l.nvmllib.Shutdown() defer func() {
if r := l.nvmllib.Shutdown(); r != nvml.SUCCESS {
l.logger.Warningf("failed to shutdown NVML: %v", r)
}
}()
version, r := l.nvmllib.SystemGetDriverVersion() version, r := l.nvmllib.SystemGetDriverVersion()
if r != nvml.SUCCESS { if r != nvml.SUCCESS {

View File

@ -104,13 +104,13 @@ type infoMock struct {
} }
func (i infoMock) HasDXCore() (bool, string) { func (i infoMock) HasDXCore() (bool, string) {
return bool(i.hasDXCore), "" return i.hasDXCore, ""
} }
func (i infoMock) HasNvml() (bool, string) { func (i infoMock) HasNvml() (bool, string) {
return bool(i.hasNVML), "" return i.hasNVML, ""
} }
func (i infoMock) IsTegraSystem() (bool, string) { func (i infoMock) IsTegraSystem() (bool, string) {
return bool(i.isTegra), "" return i.isTegra, ""
} }

View File

@ -22,6 +22,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
"github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/container-orchestrated-devices/container-device-interface/specs-go"
) )
@ -47,7 +48,7 @@ func newBuilder(opts ...Option) *builder {
} }
if s.raw != nil { if s.raw != nil {
s.noSimplify = true s.noSimplify = true
vendor, class := cdi.ParseQualifier(s.raw.Kind) vendor, class := parser.ParseQualifier(s.raw.Kind)
s.vendor = vendor s.vendor = vendor
s.class = class s.class = class
} }
@ -85,7 +86,7 @@ func (o *builder) Build() (*spec, error) {
if raw.Version == DetectMinimumVersion { if raw.Version == DetectMinimumVersion {
minVersion, err := cdi.MinimumRequiredVersion(raw) minVersion, err := cdi.MinimumRequiredVersion(raw)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get minumum required CDI spec version: %v", err) return nil, fmt.Errorf("failed to get minimum required CDI spec version: %v", err)
} }
raw.Version = minVersion raw.Version = minVersion
} }

View File

@ -39,6 +39,7 @@ func (d dedupe) Transform(spec *specs.Spec) error {
} }
var updatedDevices []specs.Device var updatedDevices []specs.Device
for _, device := range spec.Devices { for _, device := range spec.Devices {
device := device
if err := d.transformEdits(&device.ContainerEdits); err != nil { if err := d.transformEdits(&device.ContainerEdits); err != nil {
return err return err
} }

View File

@ -20,7 +20,9 @@ import (
"fmt" "fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
"github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/container-orchestrated-devices/container-device-interface/specs-go"
) )
@ -64,7 +66,7 @@ func NewMergedDevice(opts ...MergedDeviceOption) (Transformer, error) {
} }
m.simplifier = NewSimplifier() m.simplifier = NewSimplifier()
if err := cdi.ValidateDeviceName(m.name); err != nil { if err := parser.ValidateDeviceName(m.name); err != nil {
return nil, fmt.Errorf("invalid device name %q: %v", m.name, err) return nil, fmt.Errorf("invalid device name %q: %v", m.name, err)
} }
@ -109,6 +111,7 @@ func mergeDeviceSpecs(deviceSpecs []specs.Device, mergedDeviceName string) (*spe
mergedEdits := edits.NewContainerEdits() mergedEdits := edits.NewContainerEdits()
for _, d := range deviceSpecs { for _, d := range deviceSpecs {
d := d
edit := cdi.ContainerEdits{ edit := cdi.ContainerEdits{
ContainerEdits: &d.ContainerEdits, ContainerEdits: &d.ContainerEdits,
} }

View File

@ -39,6 +39,7 @@ func (r remove) Transform(spec *specs.Spec) error {
} }
for _, device := range spec.Devices { for _, device := range spec.Devices {
device := device
if err := r.transformEdits(&device.ContainerEdits); err != nil { if err := r.transformEdits(&device.ContainerEdits); err != nil {
return fmt.Errorf("failed to remove edits from device %q: %w", device.Name, err) return fmt.Errorf("failed to remove edits from device %q: %w", device.Name, err)
} }

View File

@ -54,6 +54,7 @@ func (t rootTransformer) Transform(spec *specs.Spec) error {
} }
for _, d := range spec.Devices { for _, d := range spec.Devices {
d := d
if err := t.applyToEdits(&d.ContainerEdits); err != nil { if err := t.applyToEdits(&d.ContainerEdits); err != nil {
return fmt.Errorf("failed to apply root transform to device %s: %w", d.Name, err) return fmt.Errorf("failed to apply root transform to device %s: %w", d.Name, err)
} }

View File

@ -44,6 +44,7 @@ func (d sorter) Transform(spec *specs.Spec) error {
} }
var updatedDevices []specs.Device var updatedDevices []specs.Device
for _, device := range spec.Devices { for _, device := range spec.Devices {
device := device
if err := d.transformEdits(&device.ContainerEdits); err != nil { if err := d.transformEdits(&device.ContainerEdits); err != nil {
return err return err
} }

View File

@ -157,6 +157,7 @@ func (o Options) SystemdRestart(service string) error {
logrus.Infof("Restarting %v%v using systemd: %v", service, msg, args) logrus.Infof("Restarting %v%v using systemd: %v", service, msg, args)
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
cmd := exec.Command(args[0], args[1:]...) cmd := exec.Command(args[0], args[1:]...)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr

View File

@ -90,7 +90,6 @@ func TestUpdateV2ConfigDefaultRuntime(t *testing.T) {
func TestUpdateV2Config(t *testing.T) { func TestUpdateV2Config(t *testing.T) {
const runtimeDir = "/test/runtime/dir" const runtimeDir = "/test/runtime/dir"
const expectedVersion = int64(2)
testCases := []struct { testCases := []struct {
runtimeName string runtimeName string

View File

@ -235,6 +235,8 @@ func TestUpdateConfig(t *testing.T) {
} }
for i, tc := range testCases { for i, tc := range testCases {
tc := tc
o := &options{ o := &options{
Options: container.Options{ Options: container.Options{
RuntimeName: tc.runtimeName, RuntimeName: tc.runtimeName,
@ -361,6 +363,7 @@ func TestRevertConfig(t *testing.T) {
} }
for i, tc := range testCases { for i, tc := range testCases {
tc := tc
o := &options{} o := &options{}
err := o.RevertConfig(&tc.config) err := o.RevertConfig(&tc.config)

View File

@ -65,7 +65,7 @@ func main() {
&cli.BoolFlag{ &cli.BoolFlag{
Name: "no-daemon", Name: "no-daemon",
Aliases: []string{"n"}, Aliases: []string{"n"},
Usage: "terminate immediatly after setting up the runtime. Note that no cleanup will be performed", Usage: "terminate immediately after setting up the runtime. Note that no cleanup will be performed",
Destination: &options.noDaemon, Destination: &options.noDaemon,
EnvVars: []string{"NO_DAEMON"}, EnvVars: []string{"NO_DAEMON"},
}, },
@ -229,6 +229,7 @@ func installToolkit(o *options) error {
filepath.Join(o.root, toolkitSubDir), filepath.Join(o.root, toolkitSubDir),
} }
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
cmd := exec.Command("sh", "-c", strings.Join(cmdline, " ")) cmd := exec.Command("sh", "-c", strings.Join(cmdline, " "))
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@ -247,6 +248,7 @@ func setupRuntime(o *options) error {
cmdline := fmt.Sprintf("%v setup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir) cmdline := fmt.Sprintf("%v setup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir)
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
cmd := exec.Command("sh", "-c", cmdline) cmd := exec.Command("sh", "-c", cmdline)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@ -272,6 +274,7 @@ func cleanupRuntime(o *options) error {
cmdline := fmt.Sprintf("%v cleanup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir) cmdline := fmt.Sprintf("%v cleanup %v %v\n", o.runtime, o.runtimeArgs, toolkitDir)
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
cmd := exec.Command("sh", "-c", cmdline) cmd := exec.Command("sh", "-c", cmdline)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr

View File

@ -43,7 +43,7 @@ func installContainerRuntimes(toolkitDir string, driverRoot string) error {
} }
// newNVidiaContainerRuntimeInstaller returns a new executable installer for the NVIDIA container runtime. // newNVidiaContainerRuntimeInstaller returns a new executable installer for the NVIDIA container runtime.
// This installer will copy the specified source exectuable to the toolkit directory. // This installer will copy the specified source executable to the toolkit directory.
// The executable is copied to a file with the same name as the source, but with a ".real" suffix and a wrapper is // The executable is copied to a file with the same name as the source, but with a ".real" suffix and a wrapper is
// created to allow for the configuration of the runtime environment. // created to allow for the configuration of the runtime environment.
func newNvidiaContainerRuntimeInstaller(source string) *executable { func newNvidiaContainerRuntimeInstaller(source string) *executable {
@ -82,16 +82,3 @@ func newRuntimeInstaller(source string, target executableTarget, env map[string]
return &r return &r
} }
func findLibraryRoot(root string) (string, error) {
libnvidiamlPath, err := findManagementLibrary(root)
if err != nil {
return "", fmt.Errorf("error locating NVIDIA management library: %v", err)
}
return filepath.Dir(libnvidiamlPath), nil
}
func findManagementLibrary(root string) (string, error) {
return findLibrary(root, "libnvidia-ml.so")
}

View File

@ -27,6 +27,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
"github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
toml "github.com/pelletier/go-toml" toml "github.com/pelletier/go-toml"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@ -238,11 +239,11 @@ func validateOptions(c *cli.Context, opts *options) error {
return fmt.Errorf("invalid --toolkit-root option: %v", opts.toolkitRoot) return fmt.Errorf("invalid --toolkit-root option: %v", opts.toolkitRoot)
} }
vendor, class := cdi.ParseQualifier(opts.cdiKind) vendor, class := parser.ParseQualifier(opts.cdiKind)
if err := cdi.ValidateVendorName(vendor); err != nil { if err := parser.ValidateVendorName(vendor); err != nil {
return fmt.Errorf("invalid CDI vendor name: %v", err) return fmt.Errorf("invalid CDI vendor name: %v", err)
} }
if err := cdi.ValidateClassName(class); err != nil { if err := parser.ValidateClassName(class); err != nil {
return fmt.Errorf("invalid CDI class name: %v", err) return fmt.Errorf("invalid CDI class name: %v", err)
} }
opts.cdiVendor = vendor opts.cdiVendor = vendor
@ -454,13 +455,14 @@ func installToolkitConfig(c *cli.Context, toolkitConfigPath string, nvidiaContai
config.Set(key, value) config.Set(key, value)
} }
_, err = config.WriteTo(targetConfig) if _, err := config.WriteTo(targetConfig); err != nil {
if err != nil {
return fmt.Errorf("error writing config: %v", err) return fmt.Errorf("error writing config: %v", err)
} }
os.Stdout.WriteString("Using config:\n") os.Stdout.WriteString("Using config:\n")
config.WriteTo(os.Stdout) if _, err = config.WriteTo(os.Stdout); err != nil {
log.Warningf("Failed to output config to STDOUT: %v", err)
}
return nil return nil
} }