Merge branch 'migrate-experimental-runtime' into 'experimental'

Merge code from experimental runtime

See merge request nvidia/container-toolkit/container-toolkit!37
This commit is contained in:
Evan Lezar
2021-07-05 11:49:44 +00:00
491 changed files with 278175 additions and 185 deletions

91
.common-ci.yml Normal file
View File

@@ -0,0 +1,91 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Defaults applied to every job: run inside the Docker CLI image with a
# Docker-in-Docker service started in experimental mode.
default:
image: docker:stable
services:
- name: docker:stable-dind
command: ["--experimental"]
# IMAGE / IMAGE_TAG are picked up by the Makefile when it derives BUILDIMAGE
# ($(IMAGE):$(IMAGE_TAG)-devel).
variables:
IMAGE: "${CI_REGISTRY_IMAGE}"
IMAGE_TAG: "${CI_COMMIT_REF_SLUG}"
# Builds the development image and pushes it to the CI registry so that later
# jobs can pull it instead of rebuilding.
build-dev-image:
stage: image
before_script:
# NOTE(review): `-p` passes the password on the command line; consider
# `--password-stdin` instead.
- docker login -u "${CI_REGISTRY_USER}" -p "${CI_REGISTRY_PASSWORD}" "${CI_REGISTRY}"
script:
- apk --no-cache add make bash
- make .build-image
- make .push-build-image
# Hidden template for jobs that run inside the pre-built development image.
# SKIP_IMAGE_BUILD makes the Makefile's `.build-image` recipe a no-op; the
# image is pulled from the registry instead.
.requires-build-image:
variables:
SKIP_IMAGE_BUILD: "yes"
before_script:
- apk --no-cache add make bash
- docker login -u "${CI_REGISTRY_USER}" -p "${CI_REGISTRY_PASSWORD}" "${CI_REGISTRY}"
- make .pull-build-image
# Hidden template shared by all static-check jobs.
.go-check:
extends:
- .requires-build-image
stage: go-checks
# Each check below runs the matching `docker-<target>` Makefile target.
fmt:
extends:
- .go-check
script:
- make docker-assert-fmt
vet:
extends:
- .go-check
script:
- make docker-vet
# lint and ineffassign are advisory: their failure does not fail the pipeline.
lint:
extends:
- .go-check
script:
- make docker-lint
allow_failure: true
ineffassign:
extends:
- .go-check
script:
- make docker-ineffassign
allow_failure: true
misspell:
extends:
- .go-check
script:
- make docker-misspell
# Compiles the code inside the development container.
go-build:
extends:
- .requires-build-image
stage: go-build
script:
- make docker-build
# Runs the unit tests and prints coverage inside the development container.
unit-tests:
extends:
- .requires-build-image
stage: unit-tests
script:
- make docker-coverage

View File

@@ -12,86 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Build packages for all supported OS / ARCH combinations
# NOTE(review): this hunk appears to interleave removed and added lines of a
# diff (e.g. the `tests` stage and the jobs below duplicate job names defined in
# .common-ci.yml) -- confirm against the real file before editing.
# Pull in the shared job definitions.
include:
- .common-ci.yml
# Pipeline stages, in execution order.
stages:
- tests
- image
- go-checks
- go-build
- unit-tests
- test
- scan
- release
- build-one
- build-all
# Anchor for jobs that run Go tooling directly in a golang image; the project
# is symlinked into GOPATH before each job.
.tests-setup: &tests-setup
image: golang:1.14.4
rules:
- when: always
variables:
GITHUB_ROOT: "github.com/NVIDIA"
PROJECT_GOPATH: "${GITHUB_ROOT}/nvidia-container-toolkit"
before_script:
- mkdir -p ${GOPATH}/src/${GITHUB_ROOT}
- ln -s ${CI_PROJECT_DIR} ${GOPATH}/src/${PROJECT_GOPATH}
# Anchor for package-build jobs: Docker-in-Docker plus the build tools and
# qemu-user-static registration (presumably for multi-arch builds).
.build-setup: &build-setup
image: docker:19.03.8
services:
- name: docker:19.03.8-dind
command: ["--experimental"]
before_script:
- apk update
- apk upgrade
- apk add coreutils build-base sed git bash make
- docker run --rm --privileged multiarch/qemu-user-static --reset -p yes -c yes
# Run a series of sanity-check tests over the code
# Each job below merges the tests-setup anchor and runs one Makefile target.
lint:
<<: *tests-setup
stage: tests
script:
- GO111MODULE=off go get -u golang.org/x/lint/golint
- make lint
vet:
<<: *tests-setup
stage: tests
script:
- make vet
unit_test:
<<: *tests-setup
stage: tests
script:
- make test
coverage:
<<: *tests-setup
stage: tests
script:
- make coverage
fmt:
<<: *tests-setup
stage: tests
script:
- make assert-fmt
ineffassign:
<<: *tests-setup
stage: tests
script:
- GO111MODULE=off go get -u github.com/gordonklaus/ineffassign
- make ineffassign
misspell:
<<: *tests-setup
stage: tests
script:
- GO111MODULE=off go get -u github.com/client9/misspell/cmd/misspell
- make misspell
# build-one jobs build packages for a single OS / ARCH combination.
#
# They are run during the first stage of the pipeline as a smoke test to ensure

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,14 +27,40 @@ MODULE := github.com/NVIDIA/nvidia-container-toolkit
docker-native:
include $(CURDIR)/docker/docker.mk
# Default the image name to $(REGISTRY)/container-toolkit unless IMAGE was
# supplied by the caller (CI sets it via the environment).
ifeq ($(IMAGE),)
REGISTRY ?= nvidia
IMAGE=$(REGISTRY)/container-toolkit
endif
# BUILDIMAGE is the development image used by the docker-% targets.
# GOLANG_VERSION is not defined here; presumably it comes from docker/docker.mk.
IMAGE_TAG ?= $(GOLANG_VERSION)
BUILDIMAGE ?= $(IMAGE):$(IMAGE_TAG)-devel
# One `example-<name>` target per directory found under ./examples.
EXAMPLES := $(patsubst ./examples/%/,%,$(sort $(dir $(wildcard ./examples/*/))))
EXAMPLE_TARGETS := $(patsubst %,example-%, $(EXAMPLES))
CHECK_TARGETS := assert-fmt vet lint ineffassign misspell
MAKE_TARGETS := binary build all check fmt lint-internal test examples coverage generate $(CHECK_TARGETS)
TARGETS := $(MAKE_TARGETS) $(EXAMPLE_TARGETS)
# Every target also gets a `docker-<target>` twin that runs it in BUILDIMAGE.
DOCKER_TARGETS := $(patsubst %,docker-%, $(TARGETS))
.PHONY: $(TARGETS) $(DOCKER_TARGETS)
GOOS ?= linux
# Build the $(LIB_NAME) command with symbol and debug information stripped (-s -w).
binary:
GOOS=$(GOOS) go build -ldflags "-s -w" -o "$(LIB_NAME)" $(MODULE)/cmd/$(LIB_NAME)
# Define the check targets for the Golang codebase
# NOTE(review): the next two lines and the later `check: $(CHECK_TARGETS)` look
# like the removed and added sides of the same diff hunk -- confirm.
.PHONY: check fmt assert-fmt ineffassign lint misspell vet
check: assert-fmt lint misspell vet
# Compile every package in the module.
build:
GOOS=$(GOOS) go build ./...
# Build all example programs; the static pattern rule maps example-<name> to
# ./examples/<name>.
examples: $(EXAMPLE_TARGETS)
$(EXAMPLE_TARGETS): example-%:
GOOS=$(GOOS) go build ./examples/$(*)
all: check test build binary
check: $(CHECK_TARGETS)
# Apply go fmt to the codebase
fmt:
go list -f '{{.Dir}}' $(MODULE)/... \
| xargs gofmt -s -l -w
@@ -55,8 +81,12 @@ ineffassign:
ineffassign $(MODULE)/...
# NOTE(review): the `lint` and `test` rules below each show two variants
# (apparently the removed and added sides of a diff) -- confirm which is current.
# Run golint over the module's package directories; the newer variant excludes
# paths containing /internal/.
lint:
# We use `go list -f '{{.Dir}}' $(GOLANG_PKG_PATH)/...` to skip the `vendor` folder.
go list -f '{{.Dir}}' $(MODULE)/... | xargs golint -set_exit_status
# We use `go list -f '{{.Dir}}' $(MODULE)/...` to skip the `vendor` folder.
go list -f '{{.Dir}}' $(MODULE)/... | grep -v /internal/ | xargs golint -set_exit_status
# Run golint over the internal packages only.
lint-internal:
# We use `go list -f '{{.Dir}}' $(MODULE)/...` to skip the `vendor` folder.
go list -f '{{.Dir}}' $(MODULE)/internal/... | xargs golint -set_exit_status
# Spell-check the module using the misspell tool.
misspell:
misspell $(MODULE)/...
@@ -65,8 +95,42 @@ vet:
go vet $(MODULE)/...
# Coverage profile written by `test` and consumed by `coverage`.
COVERAGE_FILE := coverage.out
test:
go test -coverprofile=$(COVERAGE_FILE) $(MODULE)/...
test: build
go test -v -coverprofile=$(COVERAGE_FILE) $(MODULE)/...
# Print per-function coverage; the newer variant first filters lines matching
# "_mock.go" out of the profile so generated mocks do not affect the totals.
coverage: test
go tool cover -func=$(COVERAGE_FILE)
cat $(COVERAGE_FILE) | grep -v "_mock.go" > $(COVERAGE_FILE).no-mocks
go tool cover -func=$(COVERAGE_FILE).no-mocks
# Re-run all `//go:generate` directives in the module.
generate:
go generate $(MODULE)/...
# Build the development image ($(BUILDIMAGE)) used for containerized builds
# from docker/Dockerfile.devel, using ./docker as the build context.
# The build is skipped whenever SKIP_IMAGE_BUILD is non-empty, e.g. when the
# image has already been pulled with `.pull-build-image`.
.PHONY: .build-image .pull-build-image .push-build-image
.build-image: docker/Dockerfile.devel
	if [ -z "$(SKIP_IMAGE_BUILD)" ]; then \
		$(DOCKER) build --progress=plain \
			--build-arg GOLANG_VERSION="$(GOLANG_VERSION)" \
			--tag $(BUILDIMAGE) \
			-f $< \
			docker; \
	fi
# Fetch the development image $(BUILDIMAGE) from its registry.
.pull-build-image:
$(DOCKER) pull $(BUILDIMAGE)
# Publish the development image $(BUILDIMAGE) to its registry.
.push-build-image:
$(DOCKER) push $(BUILDIMAGE)
# Run `make <target>` inside the development container for every
# `docker-<target>` in $(DOCKER_TARGETS).
#
# The working tree is bind-mounted at the same absolute path and used as the
# working directory. $(CURDIR) is used rather than the environment's $(PWD) so
# that the correct directory is mounted even when make is invoked with `-C`;
# the paths are quoted so that directories containing spaces work.
# The container runs as the invoking user so generated files are not
# root-owned, and GOCACHE points at a location that user can write to.
# The inner `make` is deliberately literal: it is the make binary inside the
# container, not a sub-make of this process.
$(DOCKER_TARGETS): docker-%: .build-image
	@echo "Running 'make $(*)' in docker container $(BUILDIMAGE)"
	$(DOCKER) run \
		--rm \
		-e GOCACHE=/tmp/.cache \
		-v "$(CURDIR):$(CURDIR)" \
		-w "$(CURDIR)" \
		--user $$(id -u):$$(id -g) \
		$(BUILDIMAGE) \
		make $(*)

View File

@@ -0,0 +1,52 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
log "github.com/sirupsen/logrus"
)
// list represents a list of config updaters that are applied in order, with later
// configs having preference.
type list struct {
// logger receives a debug message for each updater that is applied.
logger *log.Logger
// configs holds the updaters in the order in which they are applied.
configs []configUpdater
}
// Compile-time check that *list implements configUpdater.
var _ configUpdater = (*list)(nil)
// newListWithLogger returns a configUpdater that applies each of the updaters
// in ci, in the order given, and reports progress to logger.
func newListWithLogger(logger *log.Logger, ci ...configUpdater) configUpdater {
	return &list{
		logger:  logger,
		configs: ci,
	}
}
// Update applies every updater in the list to cfg in order. It stops at the
// first updater that fails and returns that failure annotated with the
// updater's position; changes made by earlier updaters are left in cfg.
func (c list) Update(cfg *Config) error {
	for idx, updater := range c.configs {
		c.logger.Debugf("Applying config %v: %v", idx, updater)
		if err := updater.Update(cfg); err != nil {
			return fmt.Errorf("error applying config %v: %v", idx, err)
		}
	}
	return nil
}

View File

@@ -0,0 +1,100 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestComposite exercises the list updater: an empty list, a single updater,
// failure short-circuiting in both orders, and the same updater applied twice.
// The mocks' recorded calls are reset between scenarios so that each call-count
// assertion only sees calls made by its own scenario.
func TestComposite(t *testing.T) {
logger, _ := testlog.NewNullLogger()
// Empty list
c := newListWithLogger(logger)
cfg := &Config{}
err := c.Update(cfg)
require.NoError(t, err)
require.EqualValues(t, &Config{}, cfg)
// Add a single mock config
mockConfigUpdater := &configUpdaterMock{
UpdateFunc: func(cfg *Config) error {
cfg.DebugFilePath = "updated"
return nil
},
}
c = newListWithLogger(logger, mockConfigUpdater)
cfg = &Config{}
err = c.Update(cfg)
require.NoError(t, err)
require.EqualValues(t, &Config{DebugFilePath: "updated"}, cfg)
require.Len(t, mockConfigUpdater.UpdateCalls(), 1)
// Reset the calls
mockConfigUpdater.calls = struct{ Update []struct{ Config *Config } }{}
// define an error config
errorConfigUpdater := &configUpdaterMock{
UpdateFunc: func(cfg *Config) error {
return fmt.Errorf("mock error")
},
}
// The failing updater comes first, so the second updater must never run.
c = newListWithLogger(logger, errorConfigUpdater, mockConfigUpdater)
cfg = &Config{}
err = c.Update(cfg)
require.Error(t, err)
require.EqualValues(t, &Config{}, cfg)
require.Len(t, errorConfigUpdater.UpdateCalls(), 1)
require.Len(t, mockConfigUpdater.UpdateCalls(), 0)
// Reset the calls
mockConfigUpdater.calls = struct{ Update []struct{ Config *Config } }{}
errorConfigUpdater.calls = struct{ Update []struct{ Config *Config } }{}
// Change the order of the config and error
// The successful update is applied before the failure and is not rolled back.
c = newListWithLogger(logger, mockConfigUpdater, errorConfigUpdater)
cfg = &Config{}
err = c.Update(cfg)
require.Error(t, err)
require.EqualValues(t, &Config{DebugFilePath: "updated"}, cfg)
require.Len(t, errorConfigUpdater.UpdateCalls(), 1)
require.Len(t, mockConfigUpdater.UpdateCalls(), 1)
// Reset the calls
mockConfigUpdater.calls = struct{ Update []struct{ Config *Config } }{}
errorConfigUpdater.calls = struct{ Update []struct{ Config *Config } }{}
// Call the mock twice
c = newListWithLogger(logger, mockConfigUpdater, mockConfigUpdater)
cfg = &Config{}
err = c.Update(cfg)
require.NoError(t, err)
require.EqualValues(t, &Config{DebugFilePath: "updated"}, cfg)
require.Len(t, mockConfigUpdater.UpdateCalls(), 2)
}

View File

@@ -0,0 +1,63 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
log "github.com/sirupsen/logrus"
)
// Config defines the configuration options for the NVIDIA Container Runtime
// Every field can be populated from the config file or overridden through the
// corresponding NVIDIA_CONTAINER_RUNTIME_* environment variable.
type Config struct {
// Root defines the root of the file system to be used for locating mounts
Root string
// DebugFilePath defines a log file to print debug output to
DebugFilePath string
// RuntimePath defines the path to an OCI compliant runtime
RuntimePath string
// LogLevel defines the logging level for the application
LogLevel string
}
//go:generate moq -stub -out config_mock.go . configUpdater
// configUpdater represents an interface for applying updates to a config.
// Implementations modify the supplied Config in place and return an error if
// the update could not be applied.
type configUpdater interface {
Update(*Config) error
}
// GetConfig builds the effective configuration by layering its sources.
// When the same option is set in several places, the highest-priority source
// wins:
//  1. the associated environment variables
//  2. the loaded config file
//  3. the built-in defaults
// The supplied logger is used while locating and applying the sources.
func GetConfig(logger *log.Logger) (*Config, error) {
	// Updaters are applied in order, so they are listed from lowest to highest
	// priority: each one may overwrite what the previous ones set.
	updaters := newListWithLogger(
		logger,
		newDefaultConfig(),
		newDefaultConfigFileWithLogger(logger),
		newConfigFromEnvironment(),
	)

	resolved := &Config{}
	if err := updaters.Update(resolved); err != nil {
		return nil, fmt.Errorf("error getting config: %v", err)
	}
	return resolved, nil
}

View File

@@ -0,0 +1,76 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package config
import (
"sync"
)
// Ensure, that configUpdaterMock does implement configUpdater.
// If this is not the case, regenerate this file with moq.
var _ configUpdater = &configUpdaterMock{}
// configUpdaterMock is a mock implementation of configUpdater.
//
// func TestSomethingThatUsesconfigUpdater(t *testing.T) {
//
// // make and configure a mocked configUpdater
// mockedconfigUpdater := &configUpdaterMock{
// UpdateFunc: func(config *Config) error {
// panic("mock out the Update method")
// },
// }
//
// // use mockedconfigUpdater in code that requires configUpdater
// // and then make assertions.
//
// }
// NOTE(review): this file is generated by moq ("DO NOT EDIT"); manual comments
// will be lost when it is regenerated.
type configUpdaterMock struct {
// UpdateFunc mocks the Update method.
UpdateFunc func(config *Config) error
// calls tracks calls to the methods.
calls struct {
// Update holds details about calls to the Update method.
Update []struct {
// Config is the config argument value.
Config *Config
}
}
// lockUpdate guards calls.Update.
lockUpdate sync.RWMutex
}
// Update calls UpdateFunc.
// The call is recorded first; when UpdateFunc is nil a nil error is returned.
func (mock *configUpdaterMock) Update(config *Config) error {
callInfo := struct {
Config *Config
}{
Config: config,
}
mock.lockUpdate.Lock()
mock.calls.Update = append(mock.calls.Update, callInfo)
mock.lockUpdate.Unlock()
if mock.UpdateFunc == nil {
var (
errOut error
)
return errOut
}
return mock.UpdateFunc(config)
}
// UpdateCalls gets all the calls that were made to Update.
// Check the length with:
// len(mockedconfigUpdater.UpdateCalls())
func (mock *configUpdaterMock) UpdateCalls() []struct {
Config *Config
} {
var calls []struct {
Config *Config
}
mock.lockUpdate.RLock()
calls = mock.calls.Update
mock.lockUpdate.RUnlock()
return calls
}

View File

@@ -0,0 +1,58 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"io/ioutil"
"os"
"path"
"path/filepath"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestGetConfigWithCustomConfig checks that GetConfig reads the config file
// from the directory named by the configOverride environment variable.
func TestGetConfigWithCustomConfig(t *testing.T) {
wd, err := os.Getwd()
require.NoError(t, err)
// By default debug is disabled
contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"")
testDir := path.Join(wd, "test")
filename := path.Join(testDir, configFileRelativePath)
// Point the override at the test directory and restore the previous value
// (or its absence) when the test finishes.
previousConfig, present := os.LookupEnv(configOverride)
os.Setenv(configOverride, testDir)
defer func() {
if present {
os.Setenv(configOverride, previousConfig)
} else {
os.Unsetenv(configOverride)
}
}()
// Write the config file under the test directory; the directory is removed
// again on exit.
require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766))
require.NoError(t, ioutil.WriteFile(filename, contents, 0766))
defer func() { require.NoError(t, os.RemoveAll(testDir)) }()
logger, _ := testlog.NewNullLogger()
cfg, err := GetConfig(logger)
require.NoError(t, err)
require.Equal(t, "/nvidia-container-toolkit.log", cfg.DebugFilePath)
}

View File

@@ -0,0 +1,47 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
log "github.com/sirupsen/logrus"
)
// defaultConfig is a configUpdater that resets a config to the built-in
// default values.
type defaultConfig struct{}

// Compile-time check that *defaultConfig implements configUpdater.
var _ configUpdater = (*defaultConfig)(nil)

// newDefaultConfig returns the configUpdater that applies the default values.
func newDefaultConfig() configUpdater {
	return &defaultConfig{}
}
// Update defines the default values for the config options.
// The whole of *cfg is overwritten: options without an explicit default are
// reset to their zero value. It always returns nil.
func (c defaultConfig) Update(cfg *Config) error {
	defaults := Config{
		DebugFilePath: "/dev/null",
		LogLevel:      log.InfoLevel.String(),
	}
	*cfg = defaults
	return nil
}
// getDefaultConfig returns a freshly allocated Config holding only the default
// values.
func getDefaultConfig() *Config {
	var cfg Config
	// The returned error is not checked, exactly as before.
	defaultConfig{}.Update(&cfg)
	return &cfg
}

View File

@@ -0,0 +1,37 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestDefaultConfigUpdate verifies that applying the defaults sets the debug
// file path to /dev/null and leaves Root empty.
func TestDefaultConfigUpdate(t *testing.T) {
	var cfg Config
	require.Empty(t, cfg.DebugFilePath)

	updater := defaultConfig{}
	require.NoError(t, updater.Update(&cfg))

	require.Equal(t, "/dev/null", cfg.DebugFilePath)
	require.Equal(t, "", cfg.Root)
}

View File

@@ -0,0 +1,61 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"os"
"strings"
)
// Names of the environment variables that override individual config options;
// each one maps to the Config field of the matching name.
const (
debugFilePathEnvvarName = "NVIDIA_CONTAINER_RUNTIME_DEBUG"
runtimePathEnvvarName = "NVIDIA_CONTAINER_RUNTIME_PATH"
rootEnvvarName = "NVIDIA_CONTAINER_RUNTIME_ROOT"
logLevelEnvvarName = "NVIDIA_CONTAINER_RUNTIME_LOG_LEVEL"
)
// envConfig is a configUpdater that reads overrides from the process
// environment.
type envConfig struct{}

// newConfigFromEnvironment returns the configUpdater that applies overrides
// found in the environment.
func newConfigFromEnvironment() configUpdater {
	return &envConfig{}
}
// Update copies the value of each recognised environment variable into the
// matching field of cfg. A variable that is unset, empty or contains only
// whitespace leaves the field untouched; otherwise the raw (untrimmed) value is
// stored. It always returns nil.
func (c envConfig) Update(cfg *Config) error {
	overrides := []struct {
		envvar string
		apply  func(value string)
	}{
		{debugFilePathEnvvarName, func(value string) { cfg.DebugFilePath = value }},
		{runtimePathEnvvarName, func(value string) { cfg.RuntimePath = value }},
		{rootEnvvarName, func(value string) { cfg.Root = value }},
		{logLevelEnvvarName, func(value string) { cfg.LogLevel = value }},
	}

	for _, override := range overrides {
		value, isSet := os.LookupEnv(override.envvar)
		if !isSet || strings.TrimSpace(value) == "" {
			continue
		}
		override.apply(value)
	}
	return nil
}

View File

@@ -0,0 +1,29 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
// noop is a configUpdater whose Update leaves the config untouched.
type noop struct{}

// newNoopConfigUpdater returns a configUpdater that never changes a config.
func newNoopConfigUpdater() configUpdater {
	return &noop{}
}

// Update does nothing and always reports success.
func (noop) Update(*Config) error {
	return nil
}

View File

@@ -0,0 +1,36 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestNoopDoesNotModifyConfig verifies that the no-op updater succeeds and
// leaves an already-populated field unchanged.
func TestNoopDoesNotModifyConfig(t *testing.T) {
	const debugPath = "test-path"
	cfg := Config{DebugFilePath: debugPath}

	updater := newNoopConfigUpdater()
	require.NoError(t, updater.Update(&cfg))

	require.Equal(t, debugPath, cfg.DebugFilePath)
}

View File

@@ -0,0 +1,158 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/pelletier/go-toml"
log "github.com/sirupsen/logrus"
)
const (
// configFileRelativePath is the config file location relative to the config root.
configFileRelativePath = "nvidia-container-runtime/config.toml"
// configOverride names the environment variable that, when non-empty,
// replaces defaultConfigRoot as the config root.
configOverride = "XDG_CONFIG_HOME"
defaultConfigRoot = "/etc"
// Names of the TOML sections that options are read from.
nvidiaContainerCliSection = "nvidia-container-cli"
nvidiaContainerRuntimeConfigSection = "nvidia-container-runtime"
nvidiaContainerRuntimeExperimentalConfigSection = "nvidia-container-runtime.experimental"
)
// tomlConfig is a configUpdater that reads options from a TOML config file.
type tomlConfig struct {
logger *log.Logger
// path is the location of the TOML file opened by Update.
path string
// sections are consulted in order; a later section overrides values set by
// an earlier one.
sections []tomlSection
}
// tomlSection describes one section of the TOML file that options may be read from.
type tomlSection struct {
// section is the section name used as the key prefix; empty means top level.
section string
// keys restricts which keys may be read from the section; nil allows any key.
keys map[string]struct{}
}
// newDefaultConfigFileWithLogger returns a configUpdater for the config file at
// its default location: configFileRelativePath below the directory named by
// the configOverride environment variable, or below defaultConfigRoot when
// that variable is unset or empty.
func newDefaultConfigFileWithLogger(logger *log.Logger) configUpdater {
	root := os.Getenv(configOverride)
	if root == "" {
		root = defaultConfigRoot
	}
	return newConfigFromFileWithLogger(logger, filepath.Join(root, configFileRelativePath))
}
// newConfigFromFileWithLogger returns a configUpdater that reads the TOML file
// at filepath. When the file does not exist a warning is logged and a no-op
// updater is returned, so a missing config file is not an error.
//
// Options are read from the runtime section, then the experimental runtime
// section, then the CLI section; from the CLI section only the "root" key is
// honoured.
func newConfigFromFileWithLogger(logger *log.Logger, filepath string) configUpdater {
	_, statErr := os.Stat(filepath)
	if os.IsNotExist(statErr) {
		logger.Warnf("The config file '%v' does not exist", filepath)
		return newNoopConfigUpdater()
	}

	return &tomlConfig{
		logger: logger,
		path:   filepath,
		sections: []tomlSection{
			{section: nvidiaContainerRuntimeConfigSection},
			{section: nvidiaContainerRuntimeExperimentalConfigSection},
			{
				section: nvidiaContainerCliSection,
				keys:    map[string]struct{}{"root": {}},
			},
		},
	}
}
// Update opens the configured TOML file and applies the options it contains to
// cfg. It returns an error when the file cannot be opened or parsed.
func (c tomlConfig) Update(cfg *Config) error {
	file, openErr := os.Open(c.path)
	if openErr != nil {
		return fmt.Errorf("error opening config file %v: %v", c.path, openErr)
	}
	defer file.Close()

	return c.updateFromReader(cfg, file)
}
// updateFromReader parses TOML from reader and copies the recognised string
// options ("debug", "runtime-path", "root", "log-level") from each configured
// section into cfg. Sections are processed in order, so a later section
// overrides an earlier one. Keys that are missing, not strings, or not allowed
// for a section leave the corresponding field untouched.
//
// An error is returned when the contents cannot be read or parsed; cfg is not
// modified in that case.
func (c tomlConfig) updateFromReader(cfg *Config, reader io.Reader) error {
	// Named tree rather than toml so the go-toml package is not shadowed.
	tree, err := toml.LoadReader(reader)
	if err != nil {
		return fmt.Errorf("error reading TOML contents: %v", err)
	}

	for _, section := range c.sections {
		if v, ok := section.GetStringFrom(tree, "debug"); ok {
			cfg.DebugFilePath = v
		}
		if v, ok := section.GetStringFrom(tree, "runtime-path"); ok {
			cfg.RuntimePath = v
		}
		if v, ok := section.GetStringFrom(tree, "root"); ok {
			cfg.Root = v
		}
		// log-level belongs in LogLevel; it was previously written to Root,
		// which clobbered the root path and left the level unset.
		if v, ok := section.GetStringFrom(tree, "log-level"); ok {
			cfg.LogLevel = v
		}
	}
	return nil
}
// GetStringFrom looks up key in this section of the tree and returns it as a
// string. The boolean is false when the key is not permitted for the section,
// is absent, or holds a non-string value.
func (c tomlSection) GetStringFrom(toml *toml.Tree, key string) (string, bool) {
	// A comma-ok assertion on a nil interface simply yields ("", false).
	str, isString := c.GetFrom(toml, key).(string)
	if !isString {
		return "", false
	}
	return str, true
}
// GetFrom returns the raw value stored under key in this section of the tree,
// or nil when the key is not permitted for the section.
func (c tomlSection) GetFrom(toml *toml.Tree, key string) interface{} {
	if c.validKey(key) {
		return toml.Get(c.configKey(key))
	}
	return nil
}
// validKey reports whether key may be read from this section. A section with a
// nil key set accepts every key; otherwise only the listed keys are accepted.
func (c tomlSection) validKey(key string) bool {
	allowed := c.keys
	if allowed != nil {
		_, found := allowed[key]
		return found
	}
	return true
}
// configKey returns the fully qualified lookup key, "<section>.<key>", or just
// key when the section name is empty.
func (c tomlSection) configKey(key string) string {
	if c.section != "" {
		return c.section + "." + key
	}
	return key
}

View File

@@ -0,0 +1,249 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
"testing"
"testing/iotest"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestUpdateFromReader runs updateFromReader over TOML snippets using only the
// nvidia-container-runtime section, starting each case from the default config.
func TestUpdateFromReader(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {
description string
// readerError makes the reader fail instead of serving lines.
readerError bool
// lines are joined with newlines to form the TOML input.
lines []string
expected *Config
expectedError bool
}{
{
description: "reader error returns error",
readerError: true,
expectedError: true,
expected: getDefaultConfig(),
},
{
description: "empty config returns defaults",
lines: []string{},
expected: &Config{
DebugFilePath: "/dev/null",
LogLevel: "info",
},
},
{
description: "debug output is set",
lines: []string{"nvidia-container-runtime.debug=\"nvidia-container-toolkit.log\""},
expected: &Config{
DebugFilePath: "nvidia-container-toolkit.log",
LogLevel: "info",
},
},
{
description: "debug output is set in section",
lines: []string{"[nvidia-container-runtime]", "debug=\"nvidia-container-toolkit.log\""},
expected: &Config{
DebugFilePath: "nvidia-container-toolkit.log",
LogLevel: "info",
},
},
{
description: "blank debug is set",
lines: []string{"nvidia-container-runtime.debug=\"\""},
expected: &Config{
DebugFilePath: "",
LogLevel: "info",
},
},
{
description: "non-string debug is ignored",
lines: []string{"nvidia-container-runtime.debug=2"},
expected: &Config{
DebugFilePath: "/dev/null",
LogLevel: "info",
},
},
}
for i, tc := range testCases {
cfg := getDefaultConfig()
c := tomlConfig{
logger: logger,
sections: []tomlSection{
{section: nvidiaContainerRuntimeConfigSection},
},
}
var reader io.Reader
if tc.readerError {
reader = iotest.ErrReader(fmt.Errorf("error"))
} else {
reader = strings.NewReader(strings.Join(tc.lines, "\n"))
}
err := c.updateFromReader(cfg, reader)
if tc.expectedError {
require.Error(t, err, "%d: %v", i, tc.description)
} else {
require.NoError(t, err, "%d: %v", i, tc.description)
}
// The config is compared even in the error case: it must be unchanged.
require.EqualValues(t, tc.expected, cfg, "%d: %v", i, tc.description)
}
}
// TestUpdateFromReaderExperimental runs updateFromReader using only the
// nvidia-container-runtime.experimental section: values set in the plain
// runtime section must be ignored, values in the experimental section applied.
func TestUpdateFromReaderExperimental(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {
readerError bool
lines []string
expected *Config
expectedError bool
}{
{
lines: []string{"nvidia-container-runtime.debug=\"nvidia-container-toolkit.log\""},
expected: getDefaultConfig(),
},
{
lines: []string{"[nvidia-container-runtime]", "debug=\"nvidia-container-toolkit.log\""},
expected: getDefaultConfig(),
},
{
lines: []string{"nvidia-container-runtime.experimental.debug=\"\""},
expected: &Config{
DebugFilePath: "",
LogLevel: "info",
},
},
{
lines: []string{"nvidia-container-runtime.experimental.debug=2"},
expected: getDefaultConfig(),
},
{
lines: []string{"nvidia-container-runtime.experimental.debug=\"nvidia-container-toolkit.log\""},
expected: &Config{
DebugFilePath: "nvidia-container-toolkit.log",
LogLevel: "info",
},
},
{
lines: []string{"[nvidia-container-runtime.experimental]", "debug=\"nvidia-container-toolkit.log\""},
expected: &Config{
DebugFilePath: "nvidia-container-toolkit.log",
LogLevel: "info",
},
},
{
lines: []string{
"nvidia-container-runtime.debug=\"nvidia-container-toolkit.log\"",
"nvidia-container-runtime.experimental.debug=\"nvidia-container-exp-toolkit.log\"",
},
expected: &Config{
DebugFilePath: "nvidia-container-exp-toolkit.log",
LogLevel: "info",
},
},
}
for i, tc := range testCases {
cfg := getDefaultConfig()
c := tomlConfig{
logger: logger,
sections: []tomlSection{
{section: nvidiaContainerRuntimeExperimentalConfigSection},
},
}
var reader io.Reader
if tc.readerError {
reader = iotest.ErrReader(fmt.Errorf("error"))
} else {
reader = strings.NewReader(strings.Join(tc.lines, "\n"))
}
err := c.updateFromReader(cfg, reader)
if tc.expectedError {
require.Error(t, err, "%d: %v", i, tc)
} else {
require.NoError(t, err, "%d: %v", i, tc)
}
require.EqualValues(t, tc.expected, cfg, "%d: %v", i, tc)
}
}
// TestConfigFromFile writes a config file containing all three sections and
// checks that the experimental debug path and the CLI root are picked up.
func TestConfigFromFile(t *testing.T) {
wd, err := os.Getwd()
require.NoError(t, err)
// By default debug is disabled
lines := []string{
"[nvidia-container-cli]",
"root = \"/run/nvidia/driver\"",
"[nvidia-container-runtime]",
"#debug = \"/nvidia-container-toolkit.log\"",
"",
"[nvidia-container-runtime.experimental]",
"debug = \"/nvidia-container-toolkit.experimental.log\"",
}
contents := []byte(strings.Join(lines, "\n"))
// The file is created under <working dir>/test, which is removed on exit.
testDir := path.Join(wd, "test")
filename := path.Join(testDir, configFileRelativePath)
require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766))
require.NoError(t, ioutil.WriteFile(filename, contents, 0766))
defer func() { require.NoError(t, os.RemoveAll(testDir)) }()
logger, _ := testlog.NewNullLogger()
c := newConfigFromFileWithLogger(logger, filename)
cfg := getDefaultConfig()
err = c.Update(cfg)
require.NoError(t, err)
require.Equal(t, "/nvidia-container-toolkit.experimental.log", cfg.DebugFilePath)
require.Equal(t, "/run/nvidia/driver", cfg.Root)
}
// TestConfigFromNonexistentFileReturnsNoop verifies that a missing config file
// yields the no-op updater rather than a TOML-backed one.
func TestConfigFromNonexistentFileReturnsNoop(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	updater := newConfigFromFileWithLogger(logger, "/does/not/exist")

	asNoop, isNoop := updater.(*noop)
	require.True(t, isNoop)
	require.NotNil(t, asNoop)
}
func TestGetConfigKey(t *testing.T) {
require.Equal(t, "key", tomlSection{}.configKey("key"))
require.Equal(t, "section.key", tomlSection{section: "section"}.configKey("key"))
}

View File

@@ -0,0 +1,144 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package main
import (
"fmt"
"os"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/ensure"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/filter"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/modify"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/runtime"
log "github.com/sirupsen/logrus"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-container-runtime.experimental/config"
)
const (
// visibleDevicesEnvvar is the environment variable looked up in the OCI
// spec by run to decide whether the spec has to be modified.
visibleDevicesEnvvar = "NVIDIA_VISIBLE_DEVICES"
// visibleDevicesVoid is the value of visibleDevicesEnvvar that, like an
// unset or empty variable, makes run hand over to the low-level runtime.
visibleDevicesVoid = "void"
)
// logger is the package-level logrus logger shared by main, run and
// createModifier.
var logger = log.New()
// main loads the runtime configuration, optionally redirects the logger to the
// configured debug file, applies the configured log level and then hands
// os.Args to run. Any failure is logged and the process exits with status 1.
func main() {
	cfg, err := config.GetConfig(logger)
	if err != nil {
		logger.Errorf("Error loading config: %v", err)
		os.Exit(1)
	}

	// An empty path or /dev/null means "no debug log": keep the default output
	// instead of opening a file.
	if cfg.DebugFilePath != "" && cfg.DebugFilePath != "/dev/null" {
		logFile, err := os.OpenFile(cfg.DebugFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
		if err != nil {
			logger.Errorf("Error opening debug log file: %v", err)
			os.Exit(1)
		}
		// NOTE: deferred calls do not run on the os.Exit paths below; the
		// file is then closed by the operating system at process exit.
		defer logFile.Close()
		logger.SetOutput(logFile)
	}

	// An unparsable level is not fatal: the logger keeps its current level.
	logLevel, err := log.ParseLevel(cfg.LogLevel)
	if err == nil {
		logger.SetLevel(logLevel)
	} else {
		logger.Warnf("Invalid log-level '%v'; using '%v'", cfg.LogLevel, logger.Level.String())
	}

	logger.Infof("Starting nvidia-container-runtime: %v", os.Args)
	logger.Debugf("Using config=%+v", cfg)

	if err := run(cfg, os.Args); err != nil {
		logger.Errorf("Error running runtime: %v", err)
		os.Exit(1)
	}
}
// run executes the runtime for the given command line.
//
// The command is forwarded unchanged to the low-level runtime when it has no
// create subcommand, or when NVIDIA_VISIBLE_DEVICES is unset, empty or "void"
// in the OCI spec. Otherwise the low-level runtime is wrapped in a modifying
// runtime that applies the modifier built by createModifier.
//
// Every exec path receives args (rather than the global os.Args), so the
// function acts on the command line it was actually given.
func run(cfg *config.Config, args []string) error {
	logger.Debugf("running with args=%v", args)

	// We create a low-level runtime
	lowLevelRuntime, err := createLowLevelRuntime(logger, cfg)
	if err != nil {
		return fmt.Errorf("error constructing low-level runtime: %v", err)
	}

	if !oci.HasCreateSubcommand(args) {
		logger.Infof("No modification of OCI specification required")
		logger.Infof("Forwarding command to runtime")
		return lowLevelRuntime.Exec(args)
	}

	// We create the OCI spec that is to be modified
	ociSpec, bundleDir, err := oci.NewSpecFromArgs(args)
	if err != nil {
		return fmt.Errorf("error constructing OCI spec: %v", err)
	}

	err = ociSpec.Load()
	if err != nil {
		return fmt.Errorf("error loading OCI specification: %v", err)
	}

	// Without a device request there is nothing to inject.
	visibleDevices, exists := ociSpec.LookupEnv(visibleDevicesEnvvar)
	if !exists || visibleDevices == "" || visibleDevices == visibleDevicesVoid {
		logger.Infof("Using low-level runtime: %v=%v (exists=%v)", visibleDevicesEnvvar, visibleDevices, exists)
		return lowLevelRuntime.Exec(args)
	}

	// We create the modifier that will be applied by the Modifying Runtime Wrapper
	modifier, err := createModifier(cfg.Root, bundleDir, visibleDevices, ociSpec)
	if err != nil {
		return fmt.Errorf("error constructing modifer: %v", err)
	}

	// We construct the Modifying runtime
	r := runtime.NewModifyingRuntimeWrapperWithLogger(logger, lowLevelRuntime, ociSpec, modifier)

	return r.Exec(args)
}
// createLowLevelRuntime returns the runtime that commands are forwarded to.
// A configured RuntimePath is used as is; with no path configured the
// candidates "docker-runc" and "runc" are handed to the oci package.
func createLowLevelRuntime(logger *log.Logger, cfg *config.Config) (oci.Runtime, error) {
	if runtimePath := cfg.RuntimePath; runtimePath != "" {
		logger.Infof("Creating runtime with path %v", runtimePath)
		return oci.NewRuntimeForPathWithLogger(logger, runtimePath)
	}

	return oci.NewLowLevelRuntimeWithLogger(logger, "docker-runc", "runc")
}
// createModifier chains discovery, filtering and ensuring of devices and wraps
// the result in the modifier for the OCI spec in bundleDir. It fails only if
// the NVML-based discovery cannot be created.
func createModifier(root string, bundleDir string, visibleDevices string, env filter.EnvLookup) (modify.Modifier, error) {
	// Discovery of everything available under root.
	nvmlDiscoverer, err := discover.NewNVMLServerWithLogger(logger, root)
	if err != nil {
		return nil, fmt.Errorf("error discovering devices: %v", err)
	}

	// Restrict the discovered devices to those requested.
	filtered := filter.NewSelectDevicesFromWithLogger(logger, nvmlDiscoverer, visibleDevices, env)

	// Make sure the selected devices are available.
	ensured := ensure.NewEnsureDevicesWithLogger(logger, filtered, root)

	// The modifier applies the result to the OCI spec.
	return modify.NewModifierWithLoggerFor(logger, ensured, root, bundleDir), nil
}

20
docker/Dockerfile.devel Normal file
View File

@@ -0,0 +1,20 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# GOLANG_VERSION selects the tag of the golang base image. The x.x.x default
# is a placeholder rather than a usable tag; presumably the build passes the
# real version with --build-arg (TODO confirm against the Makefile).
ARG GOLANG_VERSION=x.x.x
FROM golang:${GOLANG_VERSION}
# Development tooling baked into the image, one layer per tool:
# golint (lint), moq (mock generation), ineffassign and misspell (checks).
RUN go get -u golang.org/x/lint/golint
RUN go get -u github.com/matryer/moq
RUN go get -u github.com/gordonklaus/ineffassign
RUN go get -u github.com/client9/misspell/cmd/misspell

56
examples/discover/main.go Normal file
View File

@@ -0,0 +1,56 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package main
import (
"encoding/json"
"os"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
log "github.com/sirupsen/logrus"
)
// main discovers devices and mounts through NVML and prints both as indented
// JSON on stdout. Any failure, including a failure to encode the output, is
// logged and ends the program.
func main() {
	log.Infof("Starting device discovery with NVML")
	d, err := discover.NewNVMLServer("")
	if err != nil {
		log.Errorf("Error creating NVML Server: %v", err)
		return
	}

	devices, err := d.Devices()
	if err != nil {
		log.Errorf("Error discovering devices: %v", err)
		return
	}

	mounts, err := d.Mounts()
	if err != nil {
		log.Errorf("Error discovering mounts: %v", err)
		return
	}

	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", " ")

	log.Infof("Discovered devices:")
	// Encode can fail (unsupported value, write error); report it instead of
	// silently printing nothing.
	if err := enc.Encode(devices); err != nil {
		log.Errorf("Error encoding devices: %v", err)
		return
	}

	log.Infof("Discovered libraries:")
	if err := enc.Encode(mounts); err != nil {
		log.Errorf("Error encoding mounts: %v", err)
		return
	}
}

65
examples/filter/main.go Normal file
View File

@@ -0,0 +1,65 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package main
import (
"encoding/json"
"os"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/filter"
log "github.com/sirupsen/logrus"
)
// main discovers everything through NVML, selects "all" devices and prints the
// selected devices, mounts and hooks as indented JSON on stdout. Each failing
// step is logged and ends the program.
func main() {
	d, err := discover.NewNVMLServer("")
	if err != nil {
		log.Errorf("Error discovering devices: %v", err)
		// Without a discoverer there is nothing to filter; stop here rather
		// than continuing with an unusable value.
		return
	}

	selected := filter.NewSelectDevicesFrom(d, "all", nil)

	devices, err := selected.Devices()
	if err != nil {
		log.Errorf("Error discovering devices: %v", err)
		return
	}

	mounts, err := selected.Mounts()
	if err != nil {
		log.Errorf("Error discovering mounts: %v", err)
		return
	}

	hooks, err := selected.Hooks()
	if err != nil {
		log.Errorf("Error discovering hooks: %v", err)
		return
	}

	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", " ")

	log.Infof("Discovered devices:")
	enc.Encode(devices)

	log.Infof("Discovered libraries:")
	enc.Encode(mounts)

	log.Infof("Discovered hook:")
	enc.Encode(hooks)
}

View File

@@ -0,0 +1,26 @@
package main
import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/ldcache"
log "github.com/sirupsen/logrus"
)
// logger is the logrus standard logger used throughout this example.
var logger = log.StandardLogger()

// main opens the ld.so.cache below /run/nvidia/driver and logs the 32-bit and
// 64-bit entries whose name starts with "lib".
func main() {
	logger.SetLevel(log.DebugLevel)
	logger.Infof("Starting device discovery with NVML")

	ldCache, loadErr := ldcache.NewLDCacheWithLogger(logger, "/run/nvidia/driver")
	if loadErr != nil {
		logger.Errorf("Error loading ldcache: %v", loadErr)
		return
	}
	defer ldCache.Close()

	found32, found64 := ldCache.Lookup("lib")
	logger.Infof("32-bit: %v", found32)
	logger.Infof("64-bit: %v", found64)
}

4
go.mod
View File

@@ -4,6 +4,10 @@ go 1.14
require (
github.com/BurntSushi/toml v0.3.1
github.com/NVIDIA/go-nvml v0.11.1-0
github.com/opencontainers/runtime-spec v1.0.2
github.com/pelletier/go-toml v1.9.3
github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.7.0
golang.org/x/mod v0.3.0
)

8
go.sum
View File

@@ -1,8 +1,15 @@
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/NVIDIA/go-nvml v0.11.1-0/go.mod h1:hy7HYeQy335x6nEss0Ne3PYqleRa6Ct+VKD9RQ4nyFs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/pelletier/go-toml v1.9.3 h1:zeC5b1GviRUyKYd6OJPvBU/mcVDVoL1OhT17FCt5dSQ=
github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@@ -13,6 +20,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

239
internal/ldcache/ldcache.go Normal file
View File

@@ -0,0 +1,239 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
// Adapted from https://github.com/rai-project/ldcache
package ldcache
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
)
// ldcachePath is the location of the dynamic linker cache; openWithRoot joins
// it onto the root it is given.
const ldcachePath = "/etc/ld.so.cache"
// Magic values identifying the cache formats. parse skips a leading section
// that starts with magicString1 and then requires the following header to
// carry magicString2 and magicVersion.
const (
magicString1 = "ld.so-1.7.0"
magicString2 = "glibc-ld.so.cache"
magicVersion = "1.1"
)
// Bit masks and values applied to Entry2.Flags by Lookup: the type bits must
// include flagTypeELF, and the arch bits decide whether an entry is reported
// as a 32-bit (I386, X32) or a 64-bit (X8664, Ppc64le) library.
const (
flagTypeMask = 0x00ff
flagTypeELF = 0x0001
flagArchMask = 0xff00
flagArchI386 = 0x0000
flagArchX8664 = 0x0300
flagArchX32 = 0x0800
flagArchPpc64le = 0x0500
)
// ErrInvalidCache is returned by parse when the data is too short or the
// magic/version of the header does not match.
var ErrInvalidCache = errors.New("invalid ld.so.cache file")
// Header1 is the header of the old cache format. It is decoded with
// binary.Read, so field order and sizes are part of the on-disk layout.
type Header1 struct {
Magic [len(magicString1) + 1]byte // include null delimiter
// NLibs is the number of Entry1 records following the header; parse uses
// it to skip over them.
NLibs uint32
}
// Entry1 is one record of the old cache format. In this package only its
// size is used (by parse, to skip the old entries).
type Entry1 struct {
Flags int32
Key, Value uint32
}
// Header2 is the header of the new cache format, checked by parse against
// magicString2 and magicVersion. It is decoded with binary.Read, so field
// order and sizes are part of the on-disk layout.
type Header2 struct {
Magic [len(magicString2)]byte
Version [len(magicVersion)]byte
// NLibs is the number of Entry2 records that follow the header.
NLibs uint32
// TableSize is decoded but not used by the code in this file.
TableSize uint32
_ [3]uint32 // unused
_ uint64 // force 8 byte alignment
}
// Entry2 is one record of the new cache format, decoded with binary.Read.
type Entry2 struct {
// Flags carries the type and architecture bits (see the flag* constants).
Flags int32
// Key and Value are byte offsets into LDCache.libs: Lookup matches the
// requested prefixes against the bytes at Key and reads the
// NUL-terminated path at Value.
Key, Value uint32
// OSVersion and HWCap are decoded but not used by the code in this file.
OSVersion uint32
HWCap uint64
}
// LDCache is a parsed ld.so.cache file backed by a memory mapping.
type LDCache struct {
// Reader reads from data; parse consumes it while decoding the headers
// and entries.
*bytes.Reader
// data is the mmap'ed file content (released by Close); libs is the
// sub-slice of data that Entry2.Key/Value offsets are relative to.
data, libs []byte
header Header2
entries []Entry2
// root is prepended to every path read from the cache before it is
// resolved in Lookup.
root string
logger *log.Logger
}
// NewLDCacheWithLogger opens and parses the ld.so.cache found below root,
// logging through the given logger.
func NewLDCacheWithLogger(logger *log.Logger, root string) (*LDCache, error) {
	cache, err := openWithRoot(logger, root)
	return cache, err
}
// Open opens and parses the ld.so.cache using an empty root and the logrus
// standard logger.
func Open() (*LDCache, error) {
	const noRoot = ""
	return openWithRoot(log.StandardLogger(), noRoot)
}
// openWithRoot maps the ld.so.cache located below root into memory and parses
// it.
//
// On success the returned cache owns the mapping and must be released with
// Close. If parsing fails the mapping is released here and (nil, err) is
// returned, so callers that bail out on error do not leak it.
func openWithRoot(logger *log.Logger, root string) (*LDCache, error) {
	path := filepath.Join(root, ldcachePath)

	logger.Debugf("Opening ld.conf at %v", path)
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// The mapping stays valid after the descriptor is closed.
	defer f.Close()

	fi, err := f.Stat()
	if err != nil {
		return nil, err
	}

	d, err := syscall.Mmap(int(f.Fd()), 0, int(fi.Size()),
		syscall.PROT_READ, syscall.MAP_PRIVATE)
	if err != nil {
		return nil, err
	}

	cache := &LDCache{
		data:   d,
		Reader: bytes.NewReader(d),
		root:   root,
		logger: logger,
	}

	if err := cache.parse(); err != nil {
		// Nobody will call Close on a cache that failed to parse.
		if unmapErr := syscall.Munmap(d); unmapErr != nil {
			logger.Debugf("Failed to unmap %v: %v", path, unmapErr)
		}
		return nil, err
	}

	return cache, nil
}
// Close unmaps the memory-mapped cache data.
func (c *LDCache) Close() error {
	if err := syscall.Munmap(c.data); err != nil {
		return err
	}
	return nil
}
// Magic returns the magic bytes of the parsed header as a string.
func (c *LDCache) Magic() string {
	magic := c.header.Magic[:]
	return string(magic)
}
// Version returns the version bytes of the parsed header as a string.
func (c *LDCache) Version() string {
	version := c.header.Version[:]
	return string(version)
}
// strn returns the first n bytes of b as a string. It panics if n is larger
// than the capacity of b.
func strn(b []byte, n int) string {
	head := b[:n]
	return string(head)
}
// parse decodes the cache headers and entry table from the embedded reader.
//
// If the data starts with the old-format magic, the old header, its entries
// and the alignment padding after them are skipped first. The new-format
// header that follows must match magicString2 and magicVersion, otherwise
// ErrInvalidCache is returned. Read errors are returned as is.
func (c *LDCache) parse() error {
var header Header1
// Check for the old format (< glibc-2.2)
// The data must be longer than an old-format header; this also keeps the
// magic comparison below within bounds.
if c.Len() <= int(unsafe.Sizeof(header)) {
return ErrInvalidCache
}
if strn(c.data, len(magicString1)) == magicString1 {
if err := binary.Read(c, binary.LittleEndian, &header); err != nil {
return err
}
n := int64(header.NLibs) * int64(unsafe.Sizeof(Entry1{}))
offset, err := c.Seek(n, 1) // skip old entries
if err != nil {
return err
}
// Distance from offset up to the next multiple of Header2's alignment.
n = (-offset) & int64(unsafe.Alignof(c.header)-1)
_, err = c.Seek(n, 1) // skip padding
if err != nil {
return err
}
}
// Size()-Len() is the current read position, i.e. where the new-format
// header begins.
c.libs = c.data[c.Size()-int64(c.Len()):] // kv offsets start here
if err := binary.Read(c, binary.LittleEndian, &c.header); err != nil {
return err
}
if c.Magic() != magicString2 || c.Version() != magicVersion {
return ErrInvalidCache
}
// The entry count comes straight from the file header.
c.entries = make([]Entry2, c.header.NLibs)
if err := binary.Read(c, binary.LittleEndian, &c.entries); err != nil {
return err
}
return nil
}
// Lookup returns the resolved paths of all cached ELF libraries whose name
// starts with one of the given prefixes, split into 32-bit and 64-bit results.
//
// Paths read from the cache are prefixed with the cache root and passed
// through filepath.EvalSymlinks; entries that cannot be resolved are skipped,
// and each resolved path is reported only once.
func (c *LDCache) Lookup(libs ...string) (paths32, paths64 []string) {
c.logger.Debugf("Looking up %v in cache", libs)
type void struct{}
// paths points at whichever named result the current entry belongs to.
var paths *[]string
// set records the resolved paths that were already reported.
set := make(map[string]void)
prefix := make([][]byte, len(libs))
for i := range libs {
prefix[i] = []byte(libs[i])
}
for _, e := range c.entries {
// Only ELF entries are considered.
if ((e.Flags & flagTypeMask) & flagTypeELF) == 0 {
continue
}
// Route the entry by architecture; unknown architectures are ignored.
switch e.Flags & flagArchMask {
case flagArchX8664:
fallthrough
case flagArchPpc64le:
paths = &paths64
case flagArchX32:
fallthrough
case flagArchI386:
paths = &paths32
default:
continue
}
// Skip entries whose offsets point outside the string data.
if e.Key > uint32(len(c.libs)) || e.Value > uint32(len(c.libs)) {
continue
}
lib := c.libs[e.Key:]
value := c.libs[e.Value:]
// The first matching prefix decides the entry; every branch below ends
// the prefix loop and moves on to the next entry.
for _, p := range prefix {
if bytes.HasPrefix(lib, p) {
// The path is NUL-terminated inside value.
n := bytes.IndexByte(value, 0)
if n < 0 {
break
}
name := filepath.Join(c.root, strn(value, n))
c.logger.Debugf("checking %v", string(name))
path, err := filepath.EvalSymlinks(name)
if err != nil {
c.logger.Debugf("could not resolve symlink for %v", name)
break
}
if _, ok := set[path]; ok {
break
}
set[path] = void{}
*paths = append(*paths, path)
break
}
}
}
return
}

78
internal/lookup/file.go Normal file
View File

@@ -0,0 +1,78 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package lookup
import (
"fmt"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
)
// file is a Locator that looks for a name below a list of prefixes.
type file struct {
logger *log.Logger
// prefixes are the directories a requested filename is joined onto.
prefixes []string
// filter decides whether a candidate path is acceptable; a non-nil error
// rejects the candidate.
filter func(string) error
}
// NewFileLocator creates a file locator for root that logs to the logrus
// standard logger.
func NewFileLocator(root string) Locator {
	stdLogger := log.StandardLogger()
	return NewFileLocatorWithLogger(stdLogger, root)
}
// NewFileLocatorWithLogger creates a locator that searches directly below root
// and accepts every candidate that passes assertFile.
func NewFileLocatorWithLogger(logger *log.Logger, root string) Locator {
	return &file{
		logger:   logger,
		prefixes: []string{root},
		filter:   assertFile,
	}
}
var _ Locator = (*file)(nil)

// Locate joins filename onto every configured prefix and returns the
// candidates accepted by the locator's filter, in prefix order.
//
// An error is returned when no candidate is accepted. The check is made on the
// collected results, not on the requested name, so a missing file is reported
// instead of yielding an empty list with a nil error.
func (p file) Locate(filename string) ([]string, error) {
	var filenames []string
	for _, prefix := range p.prefixes {
		candidate := filepath.Join(prefix, filename)
		p.logger.Debugf("Checking candidate '%v'", candidate)
		err := p.filter(candidate)
		if err != nil {
			p.logger.Debugf("Candidate '%v' does not meet requirements: %v", candidate, err)
			continue
		}
		filenames = append(filenames, candidate)
	}

	if len(filenames) == 0 {
		return nil, fmt.Errorf("file %v not found", filename)
	}

	return filenames, nil
}
func assertFile(filename string) error {
info, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("error getting info for %v: %v", filename, err)
}
if info.IsDir() {
return fmt.Errorf("specified path '%v' is a directory", filename)
}
return nil
}

View File

@@ -0,0 +1,65 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package lookup
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/ldcache"
log "github.com/sirupsen/logrus"
)
// library is a Locator that resolves library names through a parsed
// ld.so.cache.
type library struct {
logger *log.Logger
// cache is queried by Locate; nothing in this type calls its Close.
cache *ldcache.LDCache
}
// Compile-time check that *library satisfies Locator.
var _ Locator = (*library)(nil)
// NewLibraryLocator creates a library locator for root that logs to the logrus
// standard logger.
func NewLibraryLocator(root string) (Locator, error) {
	stdLogger := log.StandardLogger()
	return NewLibraryLocatorWithLogger(stdLogger, root)
}
// NewLibraryLocatorWithLogger loads the ld.so.cache below root and returns a
// locator backed by it, logging through the specified logger. It fails if the
// cache cannot be loaded.
func NewLibraryLocatorWithLogger(logger *log.Logger, root string) (Locator, error) {
	logger.Infof("Reading ldcache at %v", root)

	ldCache, loadErr := ldcache.NewLDCacheWithLogger(logger, root)
	if loadErr != nil {
		return nil, fmt.Errorf("error loading ldcache: %v", loadErr)
	}

	return &library{
		logger: logger,
		cache:  ldCache,
	}, nil
}
// Locate returns the 64-bit cache entries matching libname. Matching 32-bit
// entries are only reported in a warning; finding no 64-bit entry is an error.
func (l library) Locate(libname string) ([]string, error) {
	libs32, libs64 := l.cache.Lookup(libname)

	if len(libs32) != 0 {
		l.logger.Warnf("Ignoring 32-bit libraries for %v: %v", libname, libs32)
	}

	if len(libs64) > 0 {
		return libs64, nil
	}
	return nil, fmt.Errorf("64-bit library %v not found", libname)
}

24
internal/lookup/lookup.go Normal file
View File

@@ -0,0 +1,24 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package lookup
//go:generate moq -stub -out lookup_mock.go . Locator
// Locator defines the interface for locating files on a system.
type Locator interface {
// Locate takes the name to look for and returns the matching paths, or an
// error if the lookup fails.
Locate(string) ([]string, error)
}

View File

@@ -0,0 +1,77 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package lookup
import (
"sync"
)
// Ensure, that LocatorMock does implement Locator.
// If this is not the case, regenerate this file with moq.
var _ Locator = &LocatorMock{}
// LocatorMock is a mock implementation of Locator.
//
// func TestSomethingThatUsesLocator(t *testing.T) {
//
// // make and configure a mocked Locator
// mockedLocator := &LocatorMock{
// LocateFunc: func(s string) ([]string, error) {
// panic("mock out the Locate method")
// },
// }
//
// // use mockedLocator in code that requires Locator
// // and then make assertions.
//
// }
type LocatorMock struct {
// LocateFunc mocks the Locate method.
LocateFunc func(s string) ([]string, error)
// calls tracks calls to the methods.
calls struct {
// Locate holds details about calls to the Locate method.
Locate []struct {
// S is the s argument value.
S string
}
}
// lockLocate guards calls.Locate.
lockLocate sync.RWMutex
}
// Locate calls LocateFunc.
// The call is recorded first; when LocateFunc is nil the zero values
// (nil slice, nil error) are returned instead.
func (mock *LocatorMock) Locate(s string) ([]string, error) {
callInfo := struct {
S string
}{
S: s,
}
// Record the call under the write lock so concurrent callers are safe.
mock.lockLocate.Lock()
mock.calls.Locate = append(mock.calls.Locate, callInfo)
mock.lockLocate.Unlock()
if mock.LocateFunc == nil {
var (
stringsOut []string
errOut error
)
return stringsOut, errOut
}
return mock.LocateFunc(s)
}
// LocateCalls gets all the calls that were made to Locate.
// Check the length with:
// len(mockedLocator.LocateCalls())
// The recorded slice is read under the read lock and returned without being
// copied.
func (mock *LocatorMock) LocateCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockLocate.RLock()
calls = mock.calls.Locate
mock.lockLocate.RUnlock()
return calls
}

94
internal/lookup/path.go Normal file
View File

@@ -0,0 +1,94 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package lookup
import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
const (
// envPath is the environment variable read by NewPathLocatorWithLogger to
// obtain the list of search directories.
envPath = "PATH"
)
// defaultPaths are appended to the directories from PATH when a non-empty
// root is given to NewPathLocatorWithLogger.
var defaultPaths = []string{"/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"}
// path is a Locator for executables. It embeds the file locator and overrides
// Locate to handle names that contain a slash.
type path struct {
file
}
// NewPathLocator creates a path locator for root that logs to the logrus
// standard logger.
func NewPathLocator(root string) Locator {
	stdLogger := log.StandardLogger()
	return NewPathLocatorWithLogger(stdLogger, root)
}
// NewPathLocatorWithLogger creates a locator for executables. The search
// directories are the entries of PATH, followed by defaultPaths when root is
// not empty; every directory is prefixed with root.
func NewPathLocatorWithLogger(logger *log.Logger, root string) Locator {
	searchDirs := filepath.SplitList(os.Getenv(envPath))
	if root != "" {
		searchDirs = append(searchDirs, defaultPaths...)
	}

	var prefixes []string
	for i := range searchDirs {
		prefixes = append(prefixes, filepath.Join(root, searchDirs[i]))
	}

	return &path{
		file: file{
			logger:   logger,
			prefixes: prefixes,
			filter:   assertExecutable,
		},
	}
}
var _ Locator = (*path)(nil)

// Locate finds the executable called filename. A name containing a slash is
// not searched for: it is checked with assertExecutable and returned as is.
// Any other name is looked up below the configured prefixes.
func (p path) Locate(filename string) ([]string, error) {
	if !strings.Contains(filename, "/") {
		return p.file.Locate(filename)
	}

	if err := assertExecutable(filename); err != nil {
		return nil, fmt.Errorf("absolute path %v is not an executable file: %v", filename, err)
	}
	return []string{filename}, nil
}
// assertExecutable returns nil if filename passes assertFile and has at least
// one execute permission bit set.
func assertExecutable(filename string) error {
	if err := assertFile(filename); err != nil {
		return err
	}

	info, err := os.Stat(filename)
	switch {
	case err != nil:
		return err
	case info.Mode()&0111 == 0:
		return fmt.Errorf("specified file '%v' is not executable", filename)
	default:
		return nil
	}
}

141
internal/nvcaps/nvcaps.go Normal file
View File

@@ -0,0 +1,141 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package nvcaps
import (
"bufio"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strconv"
"strings"
)
const (
// nvidiaCapabilitiesPath is the base directory of the paths returned by
// MigCap.ProcPath.
nvidiaProcDriverPath = "/proc/driver/nvidia"
nvidiaCapabilitiesPath = nvidiaProcDriverPath + "/capabilities"
// nvcapsMigMinorsPath is the file read by LoadMigMinors.
nvcapsProcDriverPath = "/proc/driver/nvidia-caps"
nvcapsMigMinorsPath = nvcapsProcDriverPath + "/mig-minors"
// nvcapsDevicePath is the directory of the device nodes returned by
// MigMinor.DevicePath.
nvcapsDevicePath = "/dev/nvidia-caps"
)
// MigMinor represents the minor number of a MIG device
type MigMinor int
// MigCap represents the path to a MIG cap file
// (for example "config" or "gpu0/gi0/ci0/access").
type MigCap string
// LoadMigMinors reads the MIG minors file and returns its contents as a map
// from capability to minor number.
//
// A missing file is not an error: it means the machine is not MIG capable, and
// (nil, nil) is returned. The format of the file is discussed in:
// https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#unique_1576522674
func LoadMigMinors() (map[MigCap]MigMinor, error) {
	minorsFile, err := os.Open(nvcapsMigMinorsPath)
	switch {
	case os.IsNotExist(err):
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("error opening MIG minors file: %v", err)
	}
	defer minorsFile.Close()

	return processMinorsFile(minorsFile), nil
}
// processMinorsFile reads minorsFile line by line and builds the mapping from
// MIG capability to device minor. Lines that cannot be processed are logged
// and skipped.
func processMinorsFile(minorsFile io.Reader) map[MigCap]MigMinor {
	result := make(map[MigCap]MigMinor)

	for lines := bufio.NewScanner(minorsFile); lines.Scan(); {
		migCap, migMinor, err := processMigMinorsLine(lines.Text())
		if err == nil {
			result[migCap] = migMinor
			continue
		}
		log.Printf("Skipping line in MIG minors file: %v", err)
	}

	return result
}
// processMigMinorsLine splits a line of the form "<capability> <minor>" into
// its two parts. It fails if the line does not consist of exactly two
// space-separated fields, if the capability is not valid, or if the minor is
// not an integer.
func processMigMinorsLine(line string) (MigCap, MigMinor, error) {
	fields := strings.Split(line, " ")
	if len(fields) != 2 {
		return "", 0, fmt.Errorf("error processing line: %v", line)
	}
	capField, minorField := fields[0], fields[1]

	migCap := MigCap(capField)
	if !migCap.isValid() {
		return "", 0, fmt.Errorf("invalid MIG minors line: '%v'", line)
	}

	minorNumber, err := strconv.Atoi(minorField)
	if err != nil {
		return "", 0, fmt.Errorf("error reading MIG minor from '%v': %v", line, err)
	}

	return migCap, MigMinor(minorNumber), nil
}
// isValid reports whether m names a known MIG capability: "config",
// "monitor", a compute-instance access file ("gpu<N>/gi<N>/ci<N>/access") or a
// GPU-instance access file ("gpu<N>/gi<N>/access").
//
// Each pattern must be matched up to and including the literal "/access"
// suffix; a scan that stops early (for example "gpu0/gi0/foo") is rejected.
func (m MigCap) isValid() bool {
	id := string(m)
	switch id {
	case "config", "monitor":
		return true
	}

	var gpu, gi, ci int

	// Look for a CI access file
	if n, err := fmt.Sscanf(id, "gpu%d/gi%d/ci%d/access", &gpu, &gi, &ci); err == nil && n == 3 {
		return true
	}

	// Look for a GI access file. The format has exactly one verb per
	// argument; the capability never contains the minor number.
	if n, err := fmt.Sscanf(id, "gpu%d/gi%d/access", &gpu, &gi); err == nil && n == 2 {
		return true
	}

	return false
}
// ProcPath returns the proc path associated with the MIG capability.
// "config" and "monitor" live directly below the mig directory; for any other
// capability "mig" is inserted after the first path element. Such a capability
// must contain a "/", otherwise indexing the split result panics.
func (m MigCap) ProcPath() string {
	id := string(m)
	if id == "config" || id == "monitor" {
		return filepath.Join(nvidiaCapabilitiesPath, "mig/"+id)
	}

	parts := strings.SplitN(id, "/", 2)
	return filepath.Join(nvidiaCapabilitiesPath, parts[0]+"/mig/"+parts[1])
}
// DevicePath returns the path of the nvidia-caps device node that has the
// minor number m.
func (m MigMinor) DevicePath() string {
	return fmt.Sprintf("%s/nvidia-cap%d", nvcapsDevicePath, m)
}

View File

@@ -0,0 +1,94 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package nvcaps
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestProcessMinorsFile feeds whole files (as joined lines) to
// processMinorsFile and checks that invalid lines are dropped while valid ones
// end up in the map.
func TestProcessMinorsFile(t *testing.T) {
	testCases := []struct {
		lines    []string
		expected map[MigCap]MigMinor
	}{
		{lines: []string{}, expected: map[MigCap]MigMinor{}},
		{lines: []string{"invalidLine"}, expected: map[MigCap]MigMinor{}},
		{lines: []string{"config 1"}, expected: map[MigCap]MigMinor{"config": 1}},
		{lines: []string{"gpu0/gi0/ci0/access 4"}, expected: map[MigCap]MigMinor{"gpu0/gi0/ci0/access": 4}},
		{lines: []string{"config 1", "invalidLine"}, expected: map[MigCap]MigMinor{"config": 1}},
		{lines: []string{"config 1", "gpu0/gi0/ci0/access 4"}, expected: map[MigCap]MigMinor{"config": 1, "gpu0/gi0/ci0/access": 4}},
	}

	for _, tc := range testCases {
		input := strings.NewReader(strings.Join(tc.lines, "\n"))

		require.Equalf(t, tc.expected, processMinorsFile(input), "testCase: %v", tc)
	}
}
// TestProcessMigMinorsLine checks the capability, minor and error returned by
// processMigMinorsLine for valid and invalid single lines.
func TestProcessMigMinorsLine(t *testing.T) {
	testCases := []struct {
		line  string
		cap   MigCap
		minor MigMinor
		err   bool
	}{
		{line: "config 1", cap: "config", minor: 1, err: false},
		{line: "monitor 2", cap: "monitor", minor: 2, err: false},
		{line: "gpu0/gi0/access 3", cap: "gpu0/gi0/access", minor: 3, err: false},
		{line: "gpu0/gi0/ci0/access 4", cap: "gpu0/gi0/ci0/access", minor: 4, err: false},
		{line: "notconfig 99", cap: "", minor: 0, err: true},
		{line: "config notanint", cap: "", minor: 0, err: true},
		{line: "", cap: "", minor: 0, err: true},
	}

	for _, tc := range testCases {
		gotCap, gotMinor, gotErr := processMigMinorsLine(tc.line)

		require.Equalf(t, tc.cap, gotCap, "testCase: %v", tc)
		require.Equalf(t, tc.minor, gotMinor, "testCase: %v", tc)

		if !tc.err {
			require.NoErrorf(t, gotErr, "testCase: %v", tc)
			continue
		}
		require.Errorf(t, gotErr, "testCase: %v", tc)
	}
}
// TestMigCapProcPaths checks the proc path derived from each kind of MIG
// capability.
func TestMigCapProcPaths(t *testing.T) {
	testCases := []struct {
		input    string
		expected string
	}{
		{input: "config", expected: "/proc/driver/nvidia/capabilities/mig/config"},
		{input: "monitor", expected: "/proc/driver/nvidia/capabilities/mig/monitor"},
		{input: "gpu0/gi0/access", expected: "/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"},
		{input: "gpu0/gi0/ci0/access", expected: "/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"},
	}

	for _, tc := range testCases {
		require.Equal(t, tc.expected, MigCap(tc.input).ProcPath())
	}
}
// TestMigMinorDevicePath checks the device node path for minor number 0.
func TestMigMinorDevicePath(t *testing.T) {
	const expected = "/dev/nvidia-caps/nvidia-cap0"

	require.Equal(t, expected, MigMinor(0).DevicePath())
}

79
internal/nvml/consts.go Normal file
View File

@@ -0,0 +1,79 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
const (
SUCCESS = nvml.SUCCESS
ERROR_UNINITIALIZED = nvml.ERROR_UNINITIALIZED
ERROR_INVALID_ARGUMENT = nvml.ERROR_INVALID_ARGUMENT
ERROR_NOT_SUPPORTED = nvml.ERROR_NOT_SUPPORTED
ERROR_NO_PERMISSION = nvml.ERROR_NO_PERMISSION
ERROR_ALREADY_INITIALIZED = nvml.ERROR_ALREADY_INITIALIZED
ERROR_NOT_FOUND = nvml.ERROR_NOT_FOUND
ERROR_INSUFFICIENT_SIZE = nvml.ERROR_INSUFFICIENT_SIZE
ERROR_INSUFFICIENT_POWER = nvml.ERROR_INSUFFICIENT_POWER
ERROR_DRIVER_NOT_LOADED = nvml.ERROR_DRIVER_NOT_LOADED
ERROR_TIMEOUT = nvml.ERROR_TIMEOUT
ERROR_IRQ_ISSUE = nvml.ERROR_IRQ_ISSUE
ERROR_LIBRARY_NOT_FOUND = nvml.ERROR_LIBRARY_NOT_FOUND
ERROR_FUNCTION_NOT_FOUND = nvml.ERROR_FUNCTION_NOT_FOUND
ERROR_CORRUPTED_INFOROM = nvml.ERROR_CORRUPTED_INFOROM
ERROR_GPU_IS_LOST = nvml.ERROR_GPU_IS_LOST
ERROR_RESET_REQUIRED = nvml.ERROR_RESET_REQUIRED
ERROR_OPERATING_SYSTEM = nvml.ERROR_OPERATING_SYSTEM
ERROR_LIB_RM_VERSION_MISMATCH = nvml.ERROR_LIB_RM_VERSION_MISMATCH
ERROR_IN_USE = nvml.ERROR_IN_USE
ERROR_MEMORY = nvml.ERROR_MEMORY
ERROR_NO_DATA = nvml.ERROR_NO_DATA
ERROR_VGPU_ECC_NOT_SUPPORTED = nvml.ERROR_VGPU_ECC_NOT_SUPPORTED
ERROR_INSUFFICIENT_RESOURCES = nvml.ERROR_INSUFFICIENT_RESOURCES
ERROR_UNKNOWN = nvml.ERROR_UNKNOWN
)
// MIG mode values, aliased from go-nvml. These are the values passed to
// Device.SetMigMode and returned by Device.GetMigMode.
const (
DEVICE_MIG_ENABLE = nvml.DEVICE_MIG_ENABLE
DEVICE_MIG_DISABLE = nvml.DEVICE_MIG_DISABLE
)
// GPU instance profile identifiers, aliased from go-nvml.
// GPU_INSTANCE_PROFILE_COUNT is used in this package as the exclusive upper
// bound when validating a profile id.
const (
GPU_INSTANCE_PROFILE_1_SLICE = nvml.GPU_INSTANCE_PROFILE_1_SLICE
GPU_INSTANCE_PROFILE_2_SLICE = nvml.GPU_INSTANCE_PROFILE_2_SLICE
GPU_INSTANCE_PROFILE_3_SLICE = nvml.GPU_INSTANCE_PROFILE_3_SLICE
GPU_INSTANCE_PROFILE_4_SLICE = nvml.GPU_INSTANCE_PROFILE_4_SLICE
GPU_INSTANCE_PROFILE_7_SLICE = nvml.GPU_INSTANCE_PROFILE_7_SLICE
GPU_INSTANCE_PROFILE_8_SLICE = nvml.GPU_INSTANCE_PROFILE_8_SLICE
GPU_INSTANCE_PROFILE_COUNT = nvml.GPU_INSTANCE_PROFILE_COUNT
)
// Compute instance profile identifiers, aliased from go-nvml.
// COMPUTE_INSTANCE_PROFILE_COUNT is used in this package as the exclusive
// upper bound when validating a profile id.
const (
COMPUTE_INSTANCE_PROFILE_1_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE
COMPUTE_INSTANCE_PROFILE_2_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_2_SLICE
COMPUTE_INSTANCE_PROFILE_3_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_3_SLICE
COMPUTE_INSTANCE_PROFILE_4_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_4_SLICE
COMPUTE_INSTANCE_PROFILE_7_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_7_SLICE
COMPUTE_INSTANCE_PROFILE_8_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_8_SLICE
COMPUTE_INSTANCE_PROFILE_COUNT = nvml.COMPUTE_INSTANCE_PROFILE_COUNT
)
// Compute instance engine profile identifiers, aliased from go-nvml.
const (
COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED
COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT
)

594
internal/nvml/mock.go Normal file
View File

@@ -0,0 +1,594 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import "fmt"
// MockServer is an in-memory implementation of Interface backed by a fixed
// list of devices.
type MockServer struct {
// Devices are the handles returned by DeviceGetHandleByIndex, addressed by
// their position in this slice.
Devices []Device
}
// MockLunaServer embeds MockServer and adds nothing of its own.
// NOTE(review): NewMockNVMLOnLunaServer returns a plain *MockServer rather
// than this type; confirm whether this type is still needed.
type MockLunaServer struct {
MockServer
}
// MockA100Device is an in-memory implementation of Device that tracks MIG
// state (mode and GPU instances) without touching real hardware.
type MockA100Device struct {
// Index is returned by GetIndex and embedded in the UUID from GetUUID.
Index int
// MinorNumber is returned by GetMinorNumber.
MinorNumber int
// MigMode is the last value stored by SetMigMode.
MigMode int
// GpuInstances is the set of GPU instances that currently exist on the device.
GpuInstances map[*MockA100GpuInstance]struct{}
// GpuInstanceCounter is the Id given to the next created GPU instance; it
// is incremented on creation and never decremented.
GpuInstanceCounter uint32
}
// MockA100GpuInstance is an in-memory implementation of GpuInstance.
type MockA100GpuInstance struct {
// Info is returned verbatim by GetInfo.
Info GpuInstanceInfo
// ComputeInstances is the set of compute instances that currently exist in
// this GPU instance.
ComputeInstances map[*MockA100ComputeInstance]struct{}
// ComputeInstanceCounter is the Id given to the next created compute
// instance; it is incremented on creation and never decremented.
ComputeInstanceCounter uint32
}
// MockA100ComputeInstance is an in-memory implementation of ComputeInstance.
// It also implements Device so that it can be handed out as a MIG device
// handle.
type MockA100ComputeInstance struct {
// Info is returned verbatim by GetInfo.
Info ComputeInstanceInfo
}
// Compile-time checks that the mock types satisfy the package interfaces.
var _ Interface = (*MockLunaServer)(nil)
var _ Device = (*MockA100Device)(nil)
var _ GpuInstance = (*MockA100GpuInstance)(nil)
var _ ComputeInstance = (*MockA100ComputeInstance)(nil)
// MockA100MIGProfiles holds the profile tables served by the mock device and
// GPU instance types.
//
// GpuInstanceProfiles is keyed by GPU instance profile id.
// ComputeInstanceProfiles is keyed first by the parent GPU instance profile id
// and then by compute instance profile id. Profile ids without an entry are
// reported as ERROR_NOT_SUPPORTED by the mock lookup methods.
var MockA100MIGProfiles = struct {
GpuInstanceProfiles map[int]GpuInstanceProfileInfo
ComputeInstanceProfiles map[int]map[int]ComputeInstanceProfileInfo
}{
GpuInstanceProfiles: map[int]GpuInstanceProfileInfo{
GPU_INSTANCE_PROFILE_1_SLICE: {
Id: GPU_INSTANCE_PROFILE_1_SLICE,
IsP2pSupported: 0,
SliceCount: 1,
InstanceCount: 7,
MultiprocessorCount: 1,
CopyEngineCount: 1,
DecoderCount: 0,
EncoderCount: 0,
JpegCount: 0,
OfaCount: 0,
MemorySizeMB: 5120,
},
GPU_INSTANCE_PROFILE_2_SLICE: {
Id: GPU_INSTANCE_PROFILE_2_SLICE,
IsP2pSupported: 0,
SliceCount: 2,
InstanceCount: 3,
MultiprocessorCount: 2,
CopyEngineCount: 2,
DecoderCount: 1,
EncoderCount: 1,
JpegCount: 0,
OfaCount: 0,
MemorySizeMB: 10240,
},
GPU_INSTANCE_PROFILE_3_SLICE: {
Id: GPU_INSTANCE_PROFILE_3_SLICE,
IsP2pSupported: 0,
SliceCount: 3,
InstanceCount: 2,
MultiprocessorCount: 3,
CopyEngineCount: 4,
DecoderCount: 2,
EncoderCount: 2,
JpegCount: 0,
OfaCount: 0,
MemorySizeMB: 20480,
},
GPU_INSTANCE_PROFILE_4_SLICE: {
Id: GPU_INSTANCE_PROFILE_4_SLICE,
IsP2pSupported: 0,
SliceCount: 4,
InstanceCount: 1,
MultiprocessorCount: 4,
CopyEngineCount: 4,
DecoderCount: 2,
EncoderCount: 2,
JpegCount: 0,
OfaCount: 0,
MemorySizeMB: 20480,
},
GPU_INSTANCE_PROFILE_7_SLICE: {
Id: GPU_INSTANCE_PROFILE_7_SLICE,
IsP2pSupported: 0,
SliceCount: 7,
InstanceCount: 1,
MultiprocessorCount: 7,
CopyEngineCount: 8,
DecoderCount: 5,
EncoderCount: 5,
JpegCount: 1,
OfaCount: 1,
MemorySizeMB: 40960,
},
},
ComputeInstanceProfiles: map[int]map[int]ComputeInstanceProfileInfo{
GPU_INSTANCE_PROFILE_1_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 1,
MultiprocessorCount: 1,
SharedCopyEngineCount: 1,
SharedDecoderCount: 0,
SharedEncoderCount: 0,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
},
GPU_INSTANCE_PROFILE_2_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 2,
MultiprocessorCount: 1,
SharedCopyEngineCount: 2,
SharedDecoderCount: 1,
SharedEncoderCount: 1,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_2_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_2_SLICE,
SliceCount: 2,
InstanceCount: 1,
MultiprocessorCount: 2,
SharedCopyEngineCount: 2,
SharedDecoderCount: 1,
SharedEncoderCount: 1,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
},
GPU_INSTANCE_PROFILE_3_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 3,
MultiprocessorCount: 1,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
// NOTE(review): every other entry in this table repeats the parent GPU
// instance profile's EncoderCount (2 here); this one says 1 - confirm.
SharedEncoderCount: 1,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_2_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_2_SLICE,
SliceCount: 2,
InstanceCount: 1,
MultiprocessorCount: 2,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 2,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_3_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_3_SLICE,
SliceCount: 3,
InstanceCount: 1,
MultiprocessorCount: 3,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
// NOTE(review): the parent GPU instance profile has EncoderCount 2 and the
// sibling 2-slice entry says 2; this one says 0 - confirm.
SharedEncoderCount: 0,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
},
GPU_INSTANCE_PROFILE_4_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 4,
MultiprocessorCount: 1,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 2,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_2_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_2_SLICE,
SliceCount: 2,
InstanceCount: 2,
MultiprocessorCount: 2,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 2,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_4_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_4_SLICE,
SliceCount: 4,
InstanceCount: 1,
MultiprocessorCount: 4,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 2,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
},
GPU_INSTANCE_PROFILE_7_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 7,
MultiprocessorCount: 1,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
COMPUTE_INSTANCE_PROFILE_2_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_2_SLICE,
SliceCount: 2,
InstanceCount: 3,
MultiprocessorCount: 2,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
COMPUTE_INSTANCE_PROFILE_3_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_3_SLICE,
SliceCount: 3,
InstanceCount: 2,
MultiprocessorCount: 3,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
COMPUTE_INSTANCE_PROFILE_4_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_4_SLICE,
SliceCount: 4,
InstanceCount: 1,
MultiprocessorCount: 4,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
COMPUTE_INSTANCE_PROFILE_7_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_7_SLICE,
SliceCount: 7,
InstanceCount: 1,
MultiprocessorCount: 7,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
},
},
}
// NewMockNVMLServer returns a mock Interface (a *MockServer) that serves the
// given devices in the order supplied.
func NewMockNVMLServer(devices ...Device) Interface {
return &MockServer{
Devices: devices,
}
}
// NewMockNVMLOnLunaServer returns a mock Interface populated with eight mock
// A100 devices whose indices run from 0 to 7.
func NewMockNVMLOnLunaServer() Interface {
	const deviceCount = 8

	devices := make([]Device, 0, deviceCount)
	for index := 0; index < deviceCount; index++ {
		devices = append(devices, NewMockA100Device(index))
	}
	return NewMockNVMLServer(devices...)
}
// NewMockA100Device returns a mock A100 device with the given index and no
// GPU instances. MinorNumber and MigMode are left at their zero values.
func NewMockA100Device(index int) Device {
return &MockA100Device{
Index: index,
GpuInstances: make(map[*MockA100GpuInstance]struct{}),
GpuInstanceCounter: 0,
}
}
// NewMockA100GpuInstance returns a mock GPU instance carrying info and
// containing no compute instances. It does not register the instance with any
// device; CreateGpuInstance does that.
func NewMockA100GpuInstance(info GpuInstanceInfo) GpuInstance {
return &MockA100GpuInstance{
Info: info,
ComputeInstances: make(map[*MockA100ComputeInstance]struct{}),
ComputeInstanceCounter: 0,
}
}
// NewMockA100ComputeInstance returns a mock compute instance carrying info.
// It does not register the instance with any GPU instance;
// CreateComputeInstance does that.
func NewMockA100ComputeInstance(info ComputeInstanceInfo) ComputeInstance {
return &MockA100ComputeInstance{
Info: info,
}
}
// Init is a no-op that always reports SUCCESS.
func (n *MockServer) Init() Return {
return MockReturn(SUCCESS)
}
// Shutdown is a no-op that always reports SUCCESS.
func (n *MockServer) Shutdown() Return {
return MockReturn(SUCCESS)
}
// DeviceGetCount returns the number of devices held by the mock server.
func (n *MockServer) DeviceGetCount() (int, Return) {
return len(n.Devices), MockReturn(SUCCESS)
}
// DeviceGetHandleByIndex returns the device stored at position index in
// Devices. An index outside [0, len(Devices)) yields a nil device and
// ERROR_INVALID_ARGUMENT.
func (n *MockServer) DeviceGetHandleByIndex(index int) (Device, Return) {
	inRange := index >= 0 && index < len(n.Devices)
	if inRange {
		return n.Devices[index], MockReturn(SUCCESS)
	}
	return nil, MockReturn(ERROR_INVALID_ARGUMENT)
}
// SystemGetDriverVersion returns the fixed placeholder version "999.99".
func (n *MockServer) SystemGetDriverVersion() (string, Return) {
return "999.99", MockReturn(SUCCESS)
}
// GetIndex returns the Index the device was constructed with.
func (d *MockA100Device) GetIndex() (int, Return) {
return d.Index, MockReturn(SUCCESS)
}
// GetPciInfo returns fixed PCI information: the bus id string
// "0000FFFF:FF:FF.F" copied byte by byte into BusId (remaining elements stay
// zero) and PciDeviceId 0x20B010DE. The same value is returned for every
// mock device.
func (d *MockA100Device) GetPciInfo() (PciInfo, Return) {
	const mockBusID = "0000FFFF:FF:FF.F"

	info := PciInfo{PciDeviceId: 0x20B010DE}
	for pos := 0; pos < len(mockBusID); pos++ {
		info.BusId[pos] = int8(mockBusID[pos])
	}
	return info, MockReturn(SUCCESS)
}
// GetUUID returns a synthetic UUID of the form "GPU-<Index>".
func (d *MockA100Device) GetUUID() (string, Return) {
return fmt.Sprintf("GPU-%d", d.Index), MockReturn(SUCCESS)
}
// GetMinorNumber returns the MinorNumber field.
// NOTE(review): NewMockA100Device never sets MinorNumber, so devices built
// through it all report 0 - confirm that is acceptable for callers.
func (d *MockA100Device) GetMinorNumber() (int, Return) {
return d.MinorNumber, MockReturn(SUCCESS)
}
// SetMigMode stores mode on the device without validating it and reports
// SUCCESS for both return values.
func (d *MockA100Device) SetMigMode(mode int) (Return, Return) {
d.MigMode = mode
return MockReturn(SUCCESS), MockReturn(SUCCESS)
}
// GetMigMode returns the stored MigMode as both integer results; the mock
// keeps a single mode value.
func (d *MockA100Device) GetMigMode() (int, int, Return) {
return d.MigMode, d.MigMode, MockReturn(SUCCESS)
}
// GetGpuInstanceProfileInfo looks up giProfileId in
// MockA100MIGProfiles.GpuInstanceProfiles.
//
// An id outside [0, GPU_INSTANCE_PROFILE_COUNT) yields ERROR_INVALID_ARGUMENT;
// an id in range but absent from the table yields ERROR_NOT_SUPPORTED. In both
// cases the returned profile is the zero value.
func (d *MockA100Device) GetGpuInstanceProfileInfo(giProfileId int) (GpuInstanceProfileInfo, Return) {
	validID := giProfileId >= 0 && giProfileId < GPU_INSTANCE_PROFILE_COUNT
	if !validID {
		return GpuInstanceProfileInfo{}, MockReturn(ERROR_INVALID_ARGUMENT)
	}

	profile, known := MockA100MIGProfiles.GpuInstanceProfiles[giProfileId]
	if !known {
		return GpuInstanceProfileInfo{}, MockReturn(ERROR_NOT_SUPPORTED)
	}
	return profile, MockReturn(SUCCESS)
}
// CreateGpuInstance creates a GPU instance for the given profile, registers it
// in d.GpuInstances and returns it. The new instance takes the current value
// of GpuInstanceCounter as its Id, after which the counter is incremented.
// No capacity or placement checks are performed. info must be non-nil.
func (d *MockA100Device) CreateGpuInstance(info *GpuInstanceProfileInfo) (GpuInstance, Return) {
	var details GpuInstanceInfo
	details.Device = d
	details.Id = d.GpuInstanceCounter
	details.ProfileId = info.Id
	d.GpuInstanceCounter++

	created := NewMockA100GpuInstance(details)
	// The constructor returns the interface type; the set is keyed by the
	// concrete pointer.
	key := created.(*MockA100GpuInstance)
	d.GpuInstances[key] = struct{}{}
	return created, MockReturn(SUCCESS)
}
// GetGpuInstances returns the GPU instances on the device whose ProfileId
// equals info.Id. The result is nil when nothing matches, and its order is
// unspecified because the instances are kept in a map.
func (d *MockA100Device) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance, Return) {
	var matches []GpuInstance
	for candidate := range d.GpuInstances {
		if candidate.Info.ProfileId != info.Id {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches, MockReturn(SUCCESS)
}
// GetMaxMigDeviceCount returns the number of MIG devices (compute instances)
// that currently exist across all GPU instances of the device.
//
// The count is taken from the live ComputeInstances sets. It must not be
// derived from ComputeInstanceCounter: that field is an Id allocator that is
// never decremented, so it would keep counting instances that have since been
// destroyed.
func (d *MockA100Device) GetMaxMigDeviceCount() (int, Return) {
	var count int
	for gi := range d.GpuInstances {
		count += len(gi.ComputeInstances)
	}
	return count, MockReturn(SUCCESS)
}
// GetMigDeviceHandleByIndex returns the Index-th MIG device (compute instance)
// of the device, counting across all GPU instances.
//
// A negative Index yields ERROR_INVALID_ARGUMENT; an Index at or beyond the
// number of existing compute instances yields ERROR_NOT_FOUND.
//
// GPU instances are skipped based on how many compute instances they currently
// hold, not on ComputeInstanceCounter, which never decreases and would
// miscount after a compute instance has been destroyed.
//
// Instances are stored in maps, so the mapping from Index to handle is not
// guaranteed to be stable between calls.
func (d *MockA100Device) GetMigDeviceHandleByIndex(Index int) (Device, Return) {
	if Index < 0 {
		return nil, MockReturn(ERROR_INVALID_ARGUMENT)
	}

	// remaining is how many compute instances still have to be passed over
	// before reaching the requested one.
	remaining := Index
	for gi := range d.GpuInstances {
		live := len(gi.ComputeInstances)
		if remaining >= live {
			// The requested instance is not in this GPU instance.
			remaining -= live
			continue
		}
		for ci := range gi.ComputeInstances {
			if remaining == 0 {
				return ci, MockReturn(SUCCESS)
			}
			remaining--
		}
	}
	return nil, MockReturn(ERROR_NOT_FOUND)
}
// GetDeviceHandleFromMigDeviceHandle always returns ERROR_NOT_SUPPORTED: a
// MockA100Device is a full device, not a MIG device handle.
func (d *MockA100Device) GetDeviceHandleFromMigDeviceHandle() (Device, Return) {
return nil, MockReturn(ERROR_NOT_SUPPORTED)
}
// IsMigDeviceHandle always reports false for a full device.
func (d *MockA100Device) IsMigDeviceHandle() (bool, Return) {
return false, MockReturn(SUCCESS)
}
// GetComputeInstanceId is not implemented for the full-device mock and panics
// when called.
func (d *MockA100Device) GetComputeInstanceId() (int, Return) {
panic("Not implemented: GetComputeInstanceId")
}
// GetGPUInstanceId is not implemented for the full-device mock and panics when
// called.
func (d *MockA100Device) GetGPUInstanceId() (int, Return) {
panic("Not implemented: GetGPUInstanceId")
}
// GetInfo returns a copy of the GPU instance's Info.
func (gi *MockA100GpuInstance) GetInfo() (GpuInstanceInfo, Return) {
return gi.Info, MockReturn(SUCCESS)
}
// GetComputeInstanceProfileInfo looks up the compute instance profile
// ciProfileId that applies to this GPU instance's own profile, using
// MockA100MIGProfiles.ComputeInstanceProfiles.
//
// A ciProfileId outside [0, COMPUTE_INSTANCE_PROFILE_COUNT) yields
// ERROR_INVALID_ARGUMENT. An engine profile other than
// COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED, or a combination missing from the
// table, yields ERROR_NOT_SUPPORTED. On any error the returned profile is the
// zero value.
func (gi *MockA100GpuInstance) GetComputeInstanceProfileInfo(ciProfileId int, ciEngProfileId int) (ComputeInstanceProfileInfo, Return) {
	validID := ciProfileId >= 0 && ciProfileId < COMPUTE_INSTANCE_PROFILE_COUNT
	if !validID {
		return ComputeInstanceProfileInfo{}, MockReturn(ERROR_INVALID_ARGUMENT)
	}
	if ciEngProfileId != COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED {
		return ComputeInstanceProfileInfo{}, MockReturn(ERROR_NOT_SUPPORTED)
	}

	forGpuProfile, known := MockA100MIGProfiles.ComputeInstanceProfiles[int(gi.Info.ProfileId)]
	if !known {
		return ComputeInstanceProfileInfo{}, MockReturn(ERROR_NOT_SUPPORTED)
	}
	profile, known := forGpuProfile[ciProfileId]
	if !known {
		return ComputeInstanceProfileInfo{}, MockReturn(ERROR_NOT_SUPPORTED)
	}
	return profile, MockReturn(SUCCESS)
}
// CreateComputeInstance creates a compute instance for the given profile,
// registers it in gi.ComputeInstances and returns it. The new instance takes
// the current value of ComputeInstanceCounter as its Id, after which the
// counter is incremented. No capacity checks are performed. info must be
// non-nil.
func (gi *MockA100GpuInstance) CreateComputeInstance(info *ComputeInstanceProfileInfo) (ComputeInstance, Return) {
	var details ComputeInstanceInfo
	details.Device = gi.Info.Device
	details.GpuInstance = gi
	details.Id = gi.ComputeInstanceCounter
	details.ProfileId = info.Id
	gi.ComputeInstanceCounter++

	created := NewMockA100ComputeInstance(details)
	// The constructor returns the interface type; the set is keyed by the
	// concrete pointer.
	key := created.(*MockA100ComputeInstance)
	gi.ComputeInstances[key] = struct{}{}
	return created, MockReturn(SUCCESS)
}
// GetComputeInstances returns the compute instances in this GPU instance whose
// ProfileId equals info.Id. The result is nil when nothing matches, and its
// order is unspecified because the instances are kept in a map.
func (gi *MockA100GpuInstance) GetComputeInstances(info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) {
	var matches []ComputeInstance
	for candidate := range gi.ComputeInstances {
		if candidate.Info.ProfileId != info.Id {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches, MockReturn(SUCCESS)
}
// Destroy removes this GPU instance from its parent device's GpuInstances set.
// It panics if Info.Device is not a *MockA100Device. Compute instances still
// held by the GPU instance are not checked.
func (gi *MockA100GpuInstance) Destroy() Return {
	parent := gi.Info.Device.(*MockA100Device)
	delete(parent.GpuInstances, gi)
	return MockReturn(SUCCESS)
}
// GetInfo returns a copy of the compute instance's Info.
func (ci *MockA100ComputeInstance) GetInfo() (ComputeInstanceInfo, Return) {
return ci.Info, MockReturn(SUCCESS)
}
// Destroy removes this compute instance from its parent GPU instance's
// ComputeInstances set. It panics if Info.GpuInstance is not a
// *MockA100GpuInstance.
func (ci *MockA100ComputeInstance) Destroy() Return {
	parent := ci.Info.GpuInstance.(*MockA100GpuInstance)
	delete(parent.ComputeInstances, ci)
	return MockReturn(SUCCESS)
}
// Since a compute instance can be used as a MIG device handle, it must also
// implement the Device interface. The assertion below checks this at compile
// time.
var _ Device = (*MockA100ComputeInstance)(nil)
// GetIndex returns the compute instance Id as the MIG device index.
func (c *MockA100ComputeInstance) GetIndex() (int, Return) {
return int(c.Info.Id), MockReturn(SUCCESS)
}
// GetPciInfo is not implemented for MIG device handles and panics when called.
func (c *MockA100ComputeInstance) GetPciInfo() (PciInfo, Return) {
// TODO: How does this behave on an actual MIG system?
panic("Not implemented: GetPciInfo")
}
// GetUUID returns a synthetic UUID of the form "MIG-<compute instance Id>".
// NOTE(review): compute instance Ids are allocated per GPU instance, so two
// MIG devices under different GPU instances can share the same UUID.
func (c *MockA100ComputeInstance) GetUUID() (string, Return) {
return fmt.Sprintf("MIG-%d", c.Info.Id), MockReturn(SUCCESS)
}
// GetMinorNumber is not implemented for MIG device handles and panics when
// called.
func (c *MockA100ComputeInstance) GetMinorNumber() (int, Return) {
// TODO: This depends on the content of the mig-minors file and the (gpu, gi, ci) tuple
panic("Not implemented: GetMinorNumber")
}
// SetMigMode is not implemented for MIG device handles and panics when called.
func (c *MockA100ComputeInstance) SetMigMode(Mode int) (Return, Return) {
panic("Not implemented: SetMigMode")
}
// GetMigMode is not implemented for MIG device handles and panics when called.
func (c *MockA100ComputeInstance) GetMigMode() (int, int, Return) {
panic("Not implemented: GetMigMode")
}
// GetGpuInstanceProfileInfo is not implemented for MIG device handles and
// panics when called.
func (c *MockA100ComputeInstance) GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return) {
panic("Not implemented: GetGpuInstanceProfileInfo")
}
// CreateGpuInstance is not implemented for MIG device handles and panics when
// called.
func (c *MockA100ComputeInstance) CreateGpuInstance(Info *GpuInstanceProfileInfo) (GpuInstance, Return) {
panic("Not implemented: CreateGpuInstance")
}
// GetGpuInstances is not implemented for MIG device handles and panics when
// called.
func (c *MockA100ComputeInstance) GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return) {
panic("Not implemented: GetGpuInstances")
}
// GetMaxMigDeviceCount is not implemented for MIG device handles and panics
// when called.
func (c *MockA100ComputeInstance) GetMaxMigDeviceCount() (int, Return) {
panic("Not implemented: GetMaxMigDeviceCount")
}
// GetMigDeviceHandleByIndex is not implemented for MIG device handles and
// panics when called.
func (c *MockA100ComputeInstance) GetMigDeviceHandleByIndex(Index int) (Device, Return) {
panic("Not implemented: GetMigDeviceHandleByIndex")
}
// GetDeviceHandleFromMigDeviceHandle returns the parent device recorded in the
// compute instance's Info.
func (c *MockA100ComputeInstance) GetDeviceHandleFromMigDeviceHandle() (Device, Return) {
return c.Info.Device, MockReturn(SUCCESS)
}
// IsMigDeviceHandle always reports true for a compute instance.
func (c *MockA100ComputeInstance) IsMigDeviceHandle() (bool, Return) {
return true, MockReturn(SUCCESS)
}
// GetComputeInstanceId returns the compute instance's Id.
func (c *MockA100ComputeInstance) GetComputeInstanceId() (int, Return) {
return int(c.Info.Id), MockReturn(SUCCESS)
}
// GetGPUInstanceId returns the Id of the GPU instance that owns this compute
// instance. If the parent's GetInfo does not report SUCCESS, the Id is 0 and
// the parent's return code is passed back as a MockReturn.
func (c *MockA100ComputeInstance) GetGPUInstanceId() (int, Return) {
	parentInfo, ret := c.Info.GpuInstance.GetInfo()
	if code := ret.Value(); code != SUCCESS {
		return 0, MockReturn(code)
	}
	return int(parentInfo.Id), MockReturn(SUCCESS)
}

188
internal/nvml/nvml.go Normal file
View File

@@ -0,0 +1,188 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// nvmlLib implements Interface by forwarding to the go-nvml package-level
// functions.
type nvmlLib struct{}
// nvmlDevice, nvmlGpuInstance and nvmlComputeInstance are named types over the
// corresponding go-nvml handles; their methods convert back to the go-nvml
// type, forward the call, and wrap the results in this package's types.
type nvmlDevice nvml.Device
type nvmlGpuInstance nvml.GpuInstance
type nvmlComputeInstance nvml.ComputeInstance
// Compile-time checks that the wrappers satisfy the package interfaces.
var _ Interface = (*nvmlLib)(nil)
var _ Device = (*nvmlDevice)(nil)
var _ GpuInstance = (*nvmlGpuInstance)(nil)
var _ ComputeInstance = (*nvmlComputeInstance)(nil)
// New returns an Interface backed by the go-nvml bindings.
func New() Interface {
return &nvmlLib{}
}
// Init forwards to nvml.Init.
func (n *nvmlLib) Init() Return {
return nvmlReturn(nvml.Init())
}
// Shutdown forwards to nvml.Shutdown.
func (n *nvmlLib) Shutdown() Return {
return nvmlReturn(nvml.Shutdown())
}
// DeviceGetCount forwards to nvml.DeviceGetCount.
func (n *nvmlLib) DeviceGetCount() (int, Return) {
c, r := nvml.DeviceGetCount()
return c, nvmlReturn(r)
}
// DeviceGetHandleByIndex forwards to nvml.DeviceGetHandleByIndex and wraps the
// returned handle as an nvmlDevice. The handle is wrapped regardless of the
// return code, so callers must check the Return before using the Device.
func (n *nvmlLib) DeviceGetHandleByIndex(index int) (Device, Return) {
d, r := nvml.DeviceGetHandleByIndex(index)
return nvmlDevice(d), nvmlReturn(r)
}
// SystemGetDriverVersion forwards to nvml.SystemGetDriverVersion.
func (n *nvmlLib) SystemGetDriverVersion() (string, Return) {
v, r := nvml.SystemGetDriverVersion()
return v, nvmlReturn(r)
}
// GetIndex forwards to the go-nvml Device.GetIndex.
func (d nvmlDevice) GetIndex() (int, Return) {
i, r := nvml.Device(d).GetIndex()
return i, nvmlReturn(r)
}
// GetPciInfo forwards to the go-nvml Device.GetPciInfo and converts the result
// to this package's PciInfo type.
func (d nvmlDevice) GetPciInfo() (PciInfo, Return) {
p, r := nvml.Device(d).GetPciInfo()
return PciInfo(p), nvmlReturn(r)
}
// GetUUID forwards to the go-nvml Device.GetUUID.
func (d nvmlDevice) GetUUID() (string, Return) {
u, r := nvml.Device(d).GetUUID()
return u, nvmlReturn(r)
}
// GetMinorNumber forwards to the go-nvml Device.GetMinorNumber.
func (d nvmlDevice) GetMinorNumber() (int, Return) {
m, r := nvml.Device(d).GetMinorNumber()
return m, nvmlReturn(r)
}
// IsMigDeviceHandle forwards to the go-nvml Device.IsMigDeviceHandle.
func (d nvmlDevice) IsMigDeviceHandle() (bool, Return) {
b, r := nvml.Device(d).IsMigDeviceHandle()
return b, nvmlReturn(r)
}
// GetDeviceHandleFromMigDeviceHandle forwards to the go-nvml method of the
// same name and wraps the returned parent handle as an nvmlDevice.
func (d nvmlDevice) GetDeviceHandleFromMigDeviceHandle() (Device, Return) {
p, r := nvml.Device(d).GetDeviceHandleFromMigDeviceHandle()
return nvmlDevice(p), nvmlReturn(r)
}
// GetGPUInstanceId forwards to the go-nvml Device.GetGpuInstanceId. Note the
// different capitalisation ("GPU" here, "Gpu" in go-nvml).
func (d nvmlDevice) GetGPUInstanceId() (int, Return) {
gi, r := nvml.Device(d).GetGpuInstanceId()
return gi, nvmlReturn(r)
}
// GetComputeInstanceId forwards to the go-nvml Device.GetComputeInstanceId.
func (d nvmlDevice) GetComputeInstanceId() (int, Return) {
ci, r := nvml.Device(d).GetComputeInstanceId()
return ci, nvmlReturn(r)
}
// SetMigMode forwards to the go-nvml Device.SetMigMode and wraps both of its
// return codes, in the same order.
func (d nvmlDevice) SetMigMode(mode int) (Return, Return) {
r1, r2 := nvml.Device(d).SetMigMode(mode)
return nvmlReturn(r1), nvmlReturn(r2)
}
// GetMigMode forwards to the go-nvml Device.GetMigMode and returns its two
// mode values unchanged, in the same order.
func (d nvmlDevice) GetMigMode() (int, int, Return) {
s1, s2, r := nvml.Device(d).GetMigMode()
return s1, s2, nvmlReturn(r)
}
// GetGpuInstanceProfileInfo forwards to the go-nvml method of the same name
// and converts the result to this package's GpuInstanceProfileInfo type.
func (d nvmlDevice) GetGpuInstanceProfileInfo(profile int) (GpuInstanceProfileInfo, Return) {
p, r := nvml.Device(d).GetGpuInstanceProfileInfo(profile)
return GpuInstanceProfileInfo(p), nvmlReturn(r)
}
// CreateGpuInstance forwards to the go-nvml Device.CreateGpuInstance, passing
// info converted to the go-nvml pointer type, and wraps the new handle as an
// nvmlGpuInstance.
func (d nvmlDevice) CreateGpuInstance(info *GpuInstanceProfileInfo) (GpuInstance, Return) {
gi, r := nvml.Device(d).CreateGpuInstance((*nvml.GpuInstanceProfileInfo)(info))
return nvmlGpuInstance(gi), nvmlReturn(r)
}
// GetGpuInstances forwards to the go-nvml Device.GetGpuInstances and wraps
// each returned handle as an nvmlGpuInstance, preserving order. The result is
// nil when go-nvml returns no handles.
func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance, Return) {
	handles, ret := nvml.Device(d).GetGpuInstances((*nvml.GpuInstanceProfileInfo)(info))

	var wrapped []GpuInstance
	for i := range handles {
		wrapped = append(wrapped, nvmlGpuInstance(handles[i]))
	}
	return wrapped, nvmlReturn(ret)
}
// GetMaxMigDeviceCount forwards to the go-nvml Device.GetMaxMigDeviceCount.
func (d nvmlDevice) GetMaxMigDeviceCount() (int, Return) {
m, r := nvml.Device(d).GetMaxMigDeviceCount()
return m, nvmlReturn(r)
}
// GetMigDeviceHandleByIndex forwards to the go-nvml method of the same name
// and wraps the returned MIG device handle as an nvmlDevice.
func (d nvmlDevice) GetMigDeviceHandleByIndex(Index int) (Device, Return) {
h, r := nvml.Device(d).GetMigDeviceHandleByIndex(Index)
return nvmlDevice(h), nvmlReturn(r)
}
// GetInfo forwards to the go-nvml GpuInstance.GetInfo and copies the result
// into this package's GpuInstanceInfo, wrapping the device handle as an
// nvmlDevice. The struct is populated regardless of the return code, so
// callers must check the Return before trusting its contents.
func (gi nvmlGpuInstance) GetInfo() (GpuInstanceInfo, Return) {
	raw, ret := nvml.GpuInstance(gi).GetInfo()

	var info GpuInstanceInfo
	info.Device = nvmlDevice(raw.Device)
	info.Id = raw.Id
	info.ProfileId = raw.ProfileId
	info.Placement = raw.Placement
	return info, nvmlReturn(ret)
}
// GetComputeInstanceProfileInfo forwards to the go-nvml method of the same
// name and converts the result to this package's ComputeInstanceProfileInfo
// type.
func (gi nvmlGpuInstance) GetComputeInstanceProfileInfo(profile int, engProfile int) (ComputeInstanceProfileInfo, Return) {
p, r := nvml.GpuInstance(gi).GetComputeInstanceProfileInfo(profile, engProfile)
return ComputeInstanceProfileInfo(p), nvmlReturn(r)
}
// CreateComputeInstance forwards to the go-nvml
// GpuInstance.CreateComputeInstance, passing info converted to the go-nvml
// pointer type, and wraps the new handle as an nvmlComputeInstance.
func (gi nvmlGpuInstance) CreateComputeInstance(info *ComputeInstanceProfileInfo) (ComputeInstance, Return) {
ci, r := nvml.GpuInstance(gi).CreateComputeInstance((*nvml.ComputeInstanceProfileInfo)(info))
return nvmlComputeInstance(ci), nvmlReturn(r)
}
// GetComputeInstances forwards to the go-nvml GpuInstance.GetComputeInstances
// and wraps each returned handle as an nvmlComputeInstance, preserving order.
// The result is nil when go-nvml returns no handles.
func (gi nvmlGpuInstance) GetComputeInstances(info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) {
	handles, ret := nvml.GpuInstance(gi).GetComputeInstances((*nvml.ComputeInstanceProfileInfo)(info))

	var wrapped []ComputeInstance
	for i := range handles {
		wrapped = append(wrapped, nvmlComputeInstance(handles[i]))
	}
	return wrapped, nvmlReturn(ret)
}
// Destroy forwards to the go-nvml GpuInstance.Destroy.
func (gi nvmlGpuInstance) Destroy() Return {
r := nvml.GpuInstance(gi).Destroy()
return nvmlReturn(r)
}
// GetInfo forwards to the go-nvml ComputeInstance.GetInfo and copies the
// result into this package's ComputeInstanceInfo, wrapping the device and GPU
// instance handles in this package's types. The struct is populated regardless
// of the return code, so callers must check the Return before trusting its
// contents.
func (ci nvmlComputeInstance) GetInfo() (ComputeInstanceInfo, Return) {
	raw, ret := nvml.ComputeInstance(ci).GetInfo()

	var info ComputeInstanceInfo
	info.Device = nvmlDevice(raw.Device)
	info.GpuInstance = nvmlGpuInstance(raw.GpuInstance)
	info.Id = raw.Id
	info.ProfileId = raw.ProfileId
	return info, nvmlReturn(ret)
}
// Destroy forwards to the go-nvml ComputeInstance.Destroy.
func (ci nvmlComputeInstance) Destroy() Return {
r := nvml.ComputeInstance(ci).Destroy()
return nvmlReturn(r)
}

110
internal/nvml/return.go Normal file
View File

@@ -0,0 +1,110 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// Return abstracts an NVML return code so that real and mock implementations
// can be used interchangeably.
//
// Because it declares Error() string, every Return also satisfies Go's error
// interface - including values that represent SUCCESS. Compare Value() against
// SUCCESS rather than testing a Return for nil.
type Return interface {
// Value returns the underlying go-nvml return code.
Value() nvml.Return
// String returns a textual form of the return code.
String() string
// Error returns a textual form of the return code.
Error() string
}
// nvmlReturn is the Return used by the go-nvml backed implementation; its text
// comes from nvml.ErrorString.
type nvmlReturn nvml.Return
// MockReturn is the Return used by the mock implementation; its text is
// produced locally without calling into NVML.
type MockReturn nvml.Return
// Compile-time checks that both types satisfy Return.
var _ Return = (*nvmlReturn)(nil)
var _ Return = (*MockReturn)(nil)
// Value returns r converted back to the go-nvml Return type.
func (r nvmlReturn) Value() nvml.Return {
return nvml.Return(r)
}
// String returns the same text as Error.
func (r nvmlReturn) String() string {
return r.Error()
}
// Error returns the text that nvml.ErrorString produces for this return code.
func (r nvmlReturn) Error() string {
return nvml.ErrorString(nvml.Return(r))
}
// Value returns r converted to the go-nvml Return type.
func (r MockReturn) Value() nvml.Return {
return nvml.Return(r)
}
// String returns the same text as Error.
func (r MockReturn) String() string {
return r.Error()
}
// Error returns the name of the return-code constant that r corresponds to
// (for example "ERROR_NOT_FOUND"), or "Unknown return value" when r matches
// none of the codes aliased in this package.
func (r MockReturn) Error() string {
	names := map[nvml.Return]string{
		SUCCESS:                       "SUCCESS",
		ERROR_UNINITIALIZED:           "ERROR_UNINITIALIZED",
		ERROR_INVALID_ARGUMENT:        "ERROR_INVALID_ARGUMENT",
		ERROR_NOT_SUPPORTED:           "ERROR_NOT_SUPPORTED",
		ERROR_NO_PERMISSION:           "ERROR_NO_PERMISSION",
		ERROR_ALREADY_INITIALIZED:     "ERROR_ALREADY_INITIALIZED",
		ERROR_NOT_FOUND:               "ERROR_NOT_FOUND",
		ERROR_INSUFFICIENT_SIZE:       "ERROR_INSUFFICIENT_SIZE",
		ERROR_INSUFFICIENT_POWER:      "ERROR_INSUFFICIENT_POWER",
		ERROR_DRIVER_NOT_LOADED:       "ERROR_DRIVER_NOT_LOADED",
		ERROR_TIMEOUT:                 "ERROR_TIMEOUT",
		ERROR_IRQ_ISSUE:               "ERROR_IRQ_ISSUE",
		ERROR_LIBRARY_NOT_FOUND:       "ERROR_LIBRARY_NOT_FOUND",
		ERROR_FUNCTION_NOT_FOUND:      "ERROR_FUNCTION_NOT_FOUND",
		ERROR_CORRUPTED_INFOROM:       "ERROR_CORRUPTED_INFOROM",
		ERROR_GPU_IS_LOST:             "ERROR_GPU_IS_LOST",
		ERROR_RESET_REQUIRED:          "ERROR_RESET_REQUIRED",
		ERROR_OPERATING_SYSTEM:        "ERROR_OPERATING_SYSTEM",
		ERROR_LIB_RM_VERSION_MISMATCH: "ERROR_LIB_RM_VERSION_MISMATCH",
		ERROR_IN_USE:                  "ERROR_IN_USE",
		ERROR_MEMORY:                  "ERROR_MEMORY",
		ERROR_NO_DATA:                 "ERROR_NO_DATA",
		ERROR_VGPU_ECC_NOT_SUPPORTED:  "ERROR_VGPU_ECC_NOT_SUPPORTED",
		ERROR_INSUFFICIENT_RESOURCES:  "ERROR_INSUFFICIENT_RESOURCES",
		ERROR_UNKNOWN:                 "ERROR_UNKNOWN",
	}

	if name, known := names[nvml.Return(r)]; known {
		return name
	}
	return "Unknown return value"
}

78
internal/nvml/types.go Normal file
View File

@@ -0,0 +1,78 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// Interface is the subset of library-level NVML calls used by this project.
// It is implemented by the go-nvml backed wrapper returned from New and by the
// mock server.
type Interface interface {
Init() Return
Shutdown() Return
DeviceGetCount() (int, Return)
DeviceGetHandleByIndex(Index int) (Device, Return)
SystemGetDriverVersion() (string, Return)
}
// Device is the subset of per-device NVML calls used by this project. A Device
// may be either a full GPU or a MIG device handle; IsMigDeviceHandle
// distinguishes the two.
type Device interface {
GetIndex() (int, Return)
GetPciInfo() (PciInfo, Return)
GetUUID() (string, Return)
GetMinorNumber() (int, Return)
IsMigDeviceHandle() (bool, Return)
GetDeviceHandleFromMigDeviceHandle() (Device, Return)
SetMigMode(Mode int) (Return, Return)
GetMigMode() (int, int, Return)
GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return)
CreateGpuInstance(Info *GpuInstanceProfileInfo) (GpuInstance, Return)
GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return)
GetMaxMigDeviceCount() (int, Return)
GetMigDeviceHandleByIndex(Index int) (Device, Return)
GetGPUInstanceId() (int, Return)
GetComputeInstanceId() (int, Return)
}
// GpuInstance is the subset of NVML GPU instance calls used by this project.
type GpuInstance interface {
GetInfo() (GpuInstanceInfo, Return)
GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return)
CreateComputeInstance(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return)
GetComputeInstances(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return)
Destroy() Return
}
// ComputeInstance is the subset of NVML compute instance calls used by this
// project.
type ComputeInstance interface {
GetInfo() (ComputeInstanceInfo, Return)
Destroy() Return
}
// GpuInstanceInfo describes a GPU instance. Unlike the go-nvml structure it is
// built from, Device is this package's Device interface rather than a raw
// handle.
type GpuInstanceInfo struct {
// Device is the device the GPU instance belongs to.
Device Device
// Id identifies the GPU instance.
Id uint32
// ProfileId is the GPU instance profile the instance was created from.
ProfileId uint32
// Placement is carried over unchanged from go-nvml.
Placement nvml.GpuInstancePlacement
}
// ComputeInstanceInfo describes a compute instance. Device and GpuInstance are
// this package's interfaces rather than raw go-nvml handles.
type ComputeInstanceInfo struct {
// Device is the device the compute instance belongs to.
Device Device
// GpuInstance is the GPU instance that contains the compute instance.
GpuInstance GpuInstance
// Id identifies the compute instance.
Id uint32
// ProfileId is the compute instance profile the instance was created from.
ProfileId uint32
}
// PciInfo, GpuInstanceProfileInfo and ComputeInstanceProfileInfo are named
// types with the same underlying structure as their go-nvml counterparts, so
// values (and pointers to them) can be converted directly between the two.
type PciInfo nvml.PciInfo
type GpuInstanceProfileInfo nvml.GpuInstanceProfileInfo
type ComputeInstanceProfileInfo nvml.ComputeInstanceProfileInfo

116
internal/proc/devices.go Normal file
View File

@@ -0,0 +1,116 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package proc
import (
"bufio"
"fmt"
"io"
"log"
"os"
"strings"
)
const (
// procDevicesPath is the file read by GetNvidiaDevices.
procDevicesPath = "/proc/devices"
// nvidiaDevicePrefix is the device-name prefix that processDeviceFile uses
// to decide which entries to keep.
nvidiaDevicePrefix = "nvidia"
)
// Device represents a device as specified under /proc/devices
type Device struct {
// Name is the device name from the line, e.g. "nvidia-frontend".
Name string
// Major is the major number that precedes the name on the line.
Major int
}
// NvidiaDevices represents the set of nvidia owned devices under /proc/devices
type NvidiaDevices interface {
// Exists reports whether a device with the given name is in the set.
Exists(name string) bool
// Get returns the device with the given name and whether it was found.
Get(name string) (Device, bool)
}
// nvidiaDevices is the map-backed implementation of NvidiaDevices, keyed by
// device name.
type nvidiaDevices map[string]Device
// Compile-time check that nvidiaDevices satisfies NvidiaDevices.
var _ NvidiaDevices = nvidiaDevices(nil)
// Exists checks if a Device with a given name exists or not.
// It is safe to call on a nil nvidiaDevices map, in which case it reports
// false.
func (d nvidiaDevices) Exists(name string) bool {
_, exists := d[name]
return exists
}
// Get a Device from NvidiaDevices. The boolean reports whether the name was
// present; when it is false the returned Device is the zero value.
func (d nvidiaDevices) Get(name string) (Device, bool) {
device, exists := d[name]
return device, exists
}
// add inserts the given devices into the set, keyed by name. A device whose
// name is already present replaces the existing entry. The receiver must be a
// non-nil map.
func (d nvidiaDevices) add(devices ...Device) {
	for i := range devices {
		entry := devices[i]
		d[entry.Name] = entry
	}
}
// NewMockNvidiaDevices builds an NvidiaDevices set containing exactly the
// devices passed in. With no arguments the result is an empty, non-nil set.
func NewMockNvidiaDevices(devices ...Device) NvidiaDevices {
	set := nvidiaDevices{}
	for _, device := range devices {
		set[device.Name] = device
	}
	return set
}
// GetNvidiaDevices reads /proc/devices and returns the NVIDIA devices listed
// there.
//
// If the file does not exist, both results are nil: callers must check for a
// nil NvidiaDevices before calling methods on it. Any other failure to open
// the file is returned as an error.
func GetNvidiaDevices() (NvidiaDevices, error) {
	file, err := os.Open(procDevicesPath)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, fmt.Errorf("error opening devices file: %v", err)
	}
	defer file.Close()

	devices := processDeviceFile(file)
	return devices, nil
}
// processDeviceFile scans devicesFile line by line and returns the set of
// devices whose name starts with the "nvidia" prefix.
//
// Lines that cannot be parsed as "<major> <name>" are logged and skipped. If
// reading fails part-way through, the failure is logged and the devices
// collected so far are returned, so the set may be incomplete; the function
// has no error result to report this through. The returned set is always
// non-nil.
func processDeviceFile(devicesFile io.Reader) NvidiaDevices {
	found := make(nvidiaDevices)

	scanner := bufio.NewScanner(devicesFile)
	for scanner.Scan() {
		device, major, err := processProcDeviceLine(scanner.Text())
		if err != nil {
			log.Printf("Skipping line in devices file: %v", err)
			continue
		}
		if strings.HasPrefix(device, nvidiaDevicePrefix) {
			found.add(Device{device, major})
		}
	}
	// Scan returns false both at EOF and on a read error (or an over-long
	// line); Err tells the two apart so a failure does not pass silently.
	if err := scanner.Err(); err != nil {
		log.Printf("Error reading devices file; device list may be incomplete: %v", err)
	}
	return found
}
// processProcDeviceLine parses one line of /proc/devices of the form
// "<major> <name>", ignoring leading and trailing whitespace.
//
// It returns the device name and major number. If the line does not yield
// both fields, it returns "", 0 and an error that quotes the original line.
func processProcDeviceLine(line string) (string, int, error) {
	var (
		major int
		name  string
	)

	// The scan error is not inspected; success is judged by how many of the
	// two fields were filled in.
	matched, _ := fmt.Sscanf(strings.TrimSpace(line), "%d %s", &major, &name)
	if matched != 2 {
		return "", 0, fmt.Errorf("unparsable line: %v", line)
	}
	return name, major, nil
}

View File

@@ -0,0 +1,92 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package proc
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestNvidiaDevices checks that every device placed in a mock NvidiaDevices
// set can be retrieved by name with its fields intact, and that a name that
// was never added is reported as missing.
func TestNvidiaDevices(t *testing.T) {
	devices := []Device{
		{"nvidia-frontend", 195},
		{"nvidia-nvlink", 234},
		{"nvidia-caps", 235},
		{"nvidia-uvm", 510},
		{"nvidia-nvswitch", 511},
	}

	nvidiaDevices := NewMockNvidiaDevices(devices...)
	for _, d := range devices {
		device, exists := nvidiaDevices.Get(d.Name)
		require.True(t, exists, "Unexpected missing device")
		// require.Equal takes (expected, actual); d is what was inserted and
		// device is what came back, so failures are labelled correctly.
		require.Equal(t, d.Name, device.Name, "Unexpected device name")
		require.Equal(t, d.Major, device.Major, "Unexpected device major")
	}

	_, exists := nvidiaDevices.Get("bogus")
	require.False(t, exists, "Unexpected 'bogus' device found")
}
// TestProcessDeviceFile feeds processDeviceFile newline-joined input and
// checks that exactly the expected nvidia-prefixed devices are returned.
// The cases cover empty input, unparsable lines, leading whitespace, blank
// lines, and a device whose name does not start with the nvidia prefix.
func TestProcessDeviceFile(t *testing.T) {
	type testCase struct {
		lines    []string
		expected []Device
	}

	testCases := []testCase{
		{[]string{}, []Device{}},
		{[]string{"Not a valid line:"}, []Device{}},
		{[]string{"195 nvidia-frontend"}, []Device{{"nvidia-frontend", 195}}},
		{[]string{"195 nvidia-frontend", "235 nvidia-caps"}, []Device{{"nvidia-frontend", 195}, {"nvidia-caps", 235}}},
		{[]string{" 195 nvidia-frontend"}, []Device{{"nvidia-frontend", 195}}},
		{[]string{"Not a valid line:", "", "195 nvidia-frontend"}, []Device{{"nvidia-frontend", 195}}},
		{[]string{"195 not-nvidia-frontend"}, []Device{}},
	}

	for _, tc := range testCases {
		input := strings.Join(tc.lines, "\n")
		want := NewMockNvidiaDevices(tc.expected...)

		got := processDeviceFile(strings.NewReader(input))

		require.Equalf(t, want, got, "testCase: %v", tc)
	}
}
// TestProcessDeviceFileLine checks the name, major number and error state that
// processProcDeviceLine returns for individual input lines.
func TestProcessDeviceFileLine(t *testing.T) {
	testCases := []struct {
		line        string
		name        string
		major       int
		expectError bool
	}{
		{line: "", expectError: true},
		{line: "0", expectError: true},
		{line: "notint nvidia-frontend", expectError: true},
		{line: "195 nvidia-frontend", name: "nvidia-frontend", major: 195},
		{line: " 195 nvidia-frontend", name: "nvidia-frontend", major: 195},
	}

	for _, tc := range testCases {
		name, major, err := processProcDeviceLine(tc.line)

		require.Equal(t, tc.name, name)
		require.Equal(t, tc.major, major)

		if !tc.expectError {
			require.NoError(t, err)
			continue
		}
		require.Error(t, err)
	}
}

51
pkg/discover/binaries.go Normal file
View File

@@ -0,0 +1,51 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
log "github.com/sirupsen/logrus"
)
// NewBinaryMounts creates a discoverer for binaries using the specified root.
// It delegates to NewBinaryMountsWithLogger, passing the logrus standard logger.
func NewBinaryMounts(root string) Discover {
	return NewBinaryMountsWithLogger(log.StandardLogger(), root)
}
// NewBinaryMountsWithLogger creates a Mounts discoverer as with NewBinaryMounts,
// but logging through the specified logger. The discoverer is configured with a
// path locator constructed for root and with the requiredBinaries set.
func NewBinaryMountsWithLogger(logger *log.Logger, root string) Discover {
	return &mounts{
		logger:   logger,
		lookup:   lookup.NewPathLocatorWithLogger(logger, root),
		required: requiredBinaries,
	}
}
// requiredBinaries defines a set of binaries and their labels. Each key is a
// label and its value lists the names of the binaries grouped under that label.
var requiredBinaries = map[string][]string{
	"utility": {
		"nvidia-smi", /* System management interface */
		"nvidia-debugdump", /* GPU coredump utility */
		"nvidia-persistenced", /* Persistence mode utility */
	},
	"compute": {
		"nvidia-cuda-mps-control", /* Multi process service CLI */
		"nvidia-cuda-mps-server", /* Multi process service server */
	},
}

View File

@@ -0,0 +1,73 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestBinaries replaces the binary locator with a map-backed mock and checks
// that the discoverer reports one mount per distinct located path. Two of the
// binaries resolve to the same "test-duplicates" path, so four successful
// lookups are expected to yield three mounts. The discoverer must report no
// devices and no hooks.
func TestBinaries(t *testing.T) {
	binaryLookup := map[string]string{
		"nvidia-smi":              "/usr/bin/nvidia-smi",
		"nvidia-persistenced":     "/usr/bin/nvidia-persistenced",
		"nvidia-debugdump":        "test-duplicates",
		"nvidia-cuda-mps-control": "test-duplicates",
	}

	logger, _ := testlog.NewNullLogger()
	d := NewBinaryMountsWithLogger(logger, "")

	// Override lookup for testing
	mockLocator := NewLocatorMockFromMap(binaryLookup)
	d.(*mounts).lookup = mockLocator

	// The local is deliberately not named "mounts" so that it does not shadow
	// the mounts type used in the assertion above.
	discovered, err := d.Mounts()
	require.NoError(t, err)

	expectedPaths := []string{
		"/usr/bin/nvidia-smi",
		"/usr/bin/nvidia-persistenced",
		"test-duplicates",
	}
	require.Equal(t, len(expectedPaths), len(discovered))

	// Compare the actual paths (order-independent) rather than only the count,
	// so that a wrong or non-deduplicated path is detected.
	var paths []string
	for _, m := range discovered {
		paths = append(paths, m.Path)
	}
	require.ElementsMatch(t, expectedPaths, paths)

	devices, err := d.Devices()
	require.NoError(t, err)
	require.Empty(t, devices)

	hooks, err := d.Hooks()
	require.NoError(t, err)
	require.Empty(t, hooks)
}
// TestNewBinariesConstructor checks that NewBinaryMounts returns a *mounts
// whose logger and lookup fields are both populated.
func TestNewBinariesConstructor(t *testing.T) {
	discoverer := NewBinaryMounts("")

	m := discoverer.(*mounts)
	require.NotNil(t, m.logger)
	require.NotNil(t, m.lookup)
}

74
pkg/discover/composite.go Normal file
View File

@@ -0,0 +1,74 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import "fmt"
// composite is a discoverer that contains a list of Discoverers. The output of the
// Devices, Mounts, and Hooks functions is the concatenation of the output for each of the
// elements in the list.
type composite struct {
	// discoverers holds the wrapped discoverers in the order in which they
	// were passed to add.
	discoverers []Discover
}

// Compile-time check that *composite satisfies the Discover interface.
var _ Discover = (*composite)(nil)
// Devices returns the devices reported by every wrapped discoverer,
// concatenated in list order. The first discoverer that fails aborts the call;
// its error is returned together with the discoverer's index in the list.
func (d composite) Devices() ([]Device, error) {
	var combined []Device

	for idx := range d.discoverers {
		found, err := d.discoverers[idx].Devices()
		if err != nil {
			return nil, fmt.Errorf("error discovering devices for discoverer %v: %v", idx, err)
		}
		combined = append(combined, found...)
	}

	return combined, nil
}
// Mounts returns the mounts reported by every wrapped discoverer, concatenated
// in list order. The first discoverer that fails aborts the call; its error is
// returned together with the discoverer's index in the list.
func (d composite) Mounts() ([]Mount, error) {
	var combined []Mount

	for idx := range d.discoverers {
		found, err := d.discoverers[idx].Mounts()
		if err != nil {
			return nil, fmt.Errorf("error discovering mounts for discoverer %v: %v", idx, err)
		}
		combined = append(combined, found...)
	}

	return combined, nil
}
// Hooks returns the hooks reported by every wrapped discoverer, concatenated in
// list order. The first discoverer that fails aborts the call; its error is
// returned together with the discoverer's index in the list.
func (d composite) Hooks() ([]Hook, error) {
	var combined []Hook

	for idx := range d.discoverers {
		found, err := d.discoverers[idx].Hooks()
		if err != nil {
			return nil, fmt.Errorf("error discovering hooks for discoverer %v: %v", idx, err)
		}
		combined = append(combined, found...)
	}

	return combined, nil
}
// add appends the given discoverers to the end of the composite's list. It has
// a pointer receiver because it modifies the list in place.
func (d *composite) add(di ...Discover) {
	d.discoverers = append(d.discoverers, di...)
}

65
pkg/discover/discover.go Normal file
View File

@@ -0,0 +1,65 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
// DevicePath is a path in /dev associated with a device
type DevicePath string

// ProcPath is a path in /proc associated with a device
type ProcPath string

// PCIBusID is the ID on the PCI bus of a device
type PCIBusID string
// DeviceNode represents a device on the file system
type DeviceNode struct {
	// Path is the location of the device node in /dev.
	Path DevicePath
	// Major is the major number of the device node.
	Major int
	// Minor is the minor number of the device node.
	Minor int
}
// Device represents a discovered device including identifiers (Index, UUID, PCI bus ID)
// for selection and paths in /dev and /proc associated with the device
type Device struct {
	// Index identifies the device by position, as a string.
	Index string
	// UUID identifies the device by its unique ID.
	UUID string
	// PCIBusID identifies the device by its ID on the PCI bus.
	PCIBusID PCIBusID
	// DeviceNodes lists the nodes in /dev associated with the device.
	DeviceNodes []DeviceNode
	// ProcPaths lists the paths in /proc associated with the device.
	ProcPaths []ProcPath
}
// Mount represents a discovered mount. This includes a set of labels
// for selection and the mount path
type Mount struct {
	// Path is the path of the mount.
	Path string
	// Labels holds key/value labels that can be used to select the mount.
	Labels map[string]string
}
// Hook represents a discovered hook
type Hook struct {
	// Path is the path of the executable that the hook runs.
	Path string
	// Args holds the arguments for the hook executable.
	Args []string
	// HookName names the kind of hook, e.g. "create-container".
	HookName string
	// Labels holds key/value labels attached to the hook.
	Labels map[string]string
}
// Discover defines an interface for discovering the devices and mounts available on a system
type Discover interface {
	// Devices returns the discovered devices, or an error if discovery failed.
	Devices() ([]Device, error)
	// Mounts returns the discovered mounts, or an error if discovery failed.
	Mounts() ([]Mount, error)
	// Hooks returns the discovered hooks, or an error if discovery failed.
	Hooks() ([]Hook, error)
}

View File

@@ -0,0 +1,424 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
"path/filepath"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/proc"
log "github.com/sirupsen/logrus"
)
const (
	// ControlDeviceUUID is used as the UUID for control devices such as nvidiactl or nvidia-modeset
	ControlDeviceUUID = "CONTROL"
	// MIGConfigDeviceUUID is used to indicate the MIG config control device
	MIGConfigDeviceUUID = "CONFIG"
	// MIGMonitorDeviceUUID is used to indicate the MIG monitor control device
	MIGMonitorDeviceUUID = "MONITOR"

	// The names below are the keys used to look up devices in
	// proc.NvidiaDevices in order to obtain their major numbers.
	nvidiaGPUDeviceName  = "nvidia-frontend"
	nvidiaCapsDeviceName = "nvidia-caps"
	nvidiaUVMDeviceName  = "nvidia-uvm"
)
// nvmlDiscover is a discoverer that uses NVML to find devices. It embeds None
// and defines its own Devices method.
// NOTE(review): None is presumably a no-op Discover implementation supplying
// Mounts and Hooks; confirm against its definition.
type nvmlDiscover struct {
	None
	// logger receives warnings raised during discovery.
	logger *log.Logger
	// nvml is the NVML interface used to enumerate and query devices.
	nvml nvml.Interface
	// migCaps maps MIG capability paths in /proc to their device nodes. MIG
	// support is considered unavailable when it is nil or empty.
	migCaps map[ProcPath]DeviceNode
	// nvidiaDevices is used to look up device major numbers by device name.
	nvidiaDevices proc.NvidiaDevices
}

// Compile-time check that *nvmlDiscover satisfies the Discover interface.
var _ Discover = (*nvmlDiscover)(nil)
// NewNVMLDiscover constructs a discoverer that uses NVML to find the devices
// available on a system. It delegates to NewNVMLDiscoverWithLogger, passing the
// logrus standard logger.
func NewNVMLDiscover(nvml nvml.Interface) (Discover, error) {
	return NewNVMLDiscoverWithLogger(log.StandardLogger(), nvml)
}
// NewNVMLDiscoverWithLogger constructs a discoverer as with NewNVMLDiscover, using the
// specified logger.
//
// Failing to load the NVIDIA devices is returned as an error. A missing nvidia-caps device,
// or a failure to load the MIG capabilities, is only logged as a warning; in both cases the
// discoverer is created without MIG capabilities.
func NewNVMLDiscoverWithLogger(logger *log.Logger, nvml nvml.Interface) (Discover, error) {
	nvidiaDevices, err := proc.GetNvidiaDevices()
	if err != nil {
		return nil, fmt.Errorf("error loading NVIDIA devices: %v", err)
	}

	// migCaps stays nil unless the capabilities can be loaded successfully.
	var migCaps map[ProcPath]DeviceNode
	if nvcapsDevice, exists := nvidiaDevices.Get(nvidiaCapsDeviceName); exists {
		loaded, capsErr := getMigCaps(nvcapsDevice.Major)
		if capsErr != nil {
			logger.Warnf("Could not load MIG capability devices: %v", capsErr)
		} else {
			migCaps = loaded
		}
	} else {
		logger.Warnf("%v nvcaps device could not be found", nvidiaCapsDeviceName)
	}

	return &nvmlDiscover{
		logger:        logger,
		nvml:          nvml,
		migCaps:       migCaps,
		nvidiaDevices: nvidiaDevices,
	}, nil
}
// hasMigSupport checks if MIG device discovery is supported.
// Cases where this will be disabled include where no MIG minors file is
// present.
// It returns true only when at least one MIG capability has been loaded into
// migCaps.
func (d nvmlDiscover) hasMigSupport() bool {
	return len(d.migCaps) > 0
}
// Devices initializes NVML, enumerates the device handles it reports and
// converts them to Devices using devicesByHandle. NVML is shut down again
// before the function returns.
//
// Without MIG support every device handle is used as is. With MIG support, a
// device that has MIG device handles contributes those handles instead of its
// own; a device without any contributes its own handle.
//
// An error is returned if NVML cannot be initialized or any NVML query fails.
func (d *nvmlDiscover) Devices() ([]Device, error) {
	ret := d.nvml.Init()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error initializing NVML: %v", ret.Error())
	}
	defer d.tryShutdownNVML()

	c, ret := d.nvml.DeviceGetCount()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error getting device count: %v", ret.Error())
	}

	var handles []nvml.Device
	for i := 0; i < c; i++ {
		handle, ret := d.nvml.DeviceGetHandleByIndex(i)
		if ret.Value() != nvml.SUCCESS {
			return nil, fmt.Errorf("error getting device handle for device %v: %v", i, ret.Error())
		}

		if !d.hasMigSupport() {
			handles = append(handles, handle)
			continue
		}

		migHandles, err := getMIGHandlesForDevice(handle)
		if err != nil {
			return nil, fmt.Errorf("error getting MIG handles for device: %v", err)
		}
		// The full device is only listed when it exposes no MIG devices.
		if len(migHandles) == 0 {
			handles = append(handles, handle)
		}
		handles = append(handles, migHandles...)
	}

	return d.devicesByHandle(handles)
}
// devicesByHandle converts the given NVML handles to Devices and appends the
// control devices. When MIG is supported the MIG control devices are appended
// as well; they are requested with the largest minor number found on the first
// device node of the converted devices.
func (d *nvmlDiscover) devicesByHandle(handles []nvml.Device) ([]Device, error) {
	var (
		discovered []Device
		maxMinor   int
	)

	for _, handle := range handles {
		device, err := d.deviceFromNVMLHandle(handle)
		if err != nil {
			return nil, fmt.Errorf("error constructing device from handle %v: %v", handle, err)
		}
		discovered = append(discovered, device)

		if minor := device.DeviceNodes[0].Minor; minor > maxMinor {
			maxMinor = minor
		}
	}

	controlDevices, err := d.getControlDevices()
	if err != nil {
		return nil, fmt.Errorf("error getting control devices: %v", err)
	}
	discovered = append(discovered, controlDevices...)

	if !d.hasMigSupport() {
		return discovered, nil
	}

	migControlDevices, err := d.getMigControlDevices(maxMinor)
	if err != nil {
		return nil, fmt.Errorf("error getting MIG control devices: %v", err)
	}

	return append(discovered, migControlDevices...), nil
}
// deviceFromNVMLHandle converts an NVML handle to a Device. The handle is only
// checked for being a MIG device handle when MIG is supported; in every other
// case it is treated as a full device.
func (d *nvmlDiscover) deviceFromNVMLHandle(handle nvml.Device) (Device, error) {
	if !d.hasMigSupport() {
		return d.deviceFromFullDeviceHandle(handle)
	}

	isMigDevice, ret := handle.IsMigDeviceHandle()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error checking device handle: %v", ret.Error())
	}
	if !isMigDevice {
		return d.deviceFromFullDeviceHandle(handle)
	}

	return d.deviceFromMIGDeviceHandle(handle)
}
// deviceFromFullDeviceHandle builds a Device for a full (non-MIG) device
// handle. It queries the index, UUID, PCI info and minor number from NVML and
// looks up the major number of the nvidia-frontend device. The resulting
// Device has a single /dev/nvidia<minor> node and a single proc path derived
// from the PCI bus ID.
func (d *nvmlDiscover) deviceFromFullDeviceHandle(handle nvml.Device) (Device, error) {
	index, ret := handle.GetIndex()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting device index: %v", ret.Error())
	}

	uuid, ret := handle.GetUUID()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting device UUID: %v", ret.Error())
	}

	pciInfo, ret := handle.GetPciInfo()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting PCI info: %v", ret.Error())
	}

	minor, ret := handle.GetMinorNumber()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting minor number: %v", ret.Error())
	}

	gpuDevice, found := d.nvidiaDevices.Get(nvidiaGPUDeviceName)
	if !found {
		return Device{}, fmt.Errorf("device for '%v' does not exist", nvidiaGPUDeviceName)
	}

	busID := NewPCIBusID(pciInfo)

	return Device{
		Index:    fmt.Sprintf("%d", index),
		PCIBusID: busID,
		UUID:     uuid,
		DeviceNodes: []DeviceNode{
			{
				Path:  DevicePath(fmt.Sprintf("/dev/nvidia%d", minor)),
				Major: gpuDevice.Major,
				Minor: minor,
			},
		},
		ProcPaths: []ProcPath{busID.GetProcPath()},
	}, nil
}
// deviceFromMIGDeviceHandle builds a Device for a MIG device handle. The
// result combines the parent GPU's device nodes and proc paths with the MIG
// capability proc paths and their device nodes. The Index has the form
// "<parent index>:<MIG index>". An error is returned if any NVML query fails
// or if a capability path has no entry in migCaps.
func (d *nvmlDiscover) deviceFromMIGDeviceHandle(handle nvml.Device) (Device, error) {
	parent, ret := handle.GetDeviceHandleFromMigDeviceHandle()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting parent device handle: %v", ret.Error())
	}

	gpuMinor, ret := parent.GetMinorNumber()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting GPU minor number: %v", ret.Error())
	}

	parentDevice, err := d.deviceFromFullDeviceHandle(parent)
	if err != nil {
		return Device{}, fmt.Errorf("error getting parent device: %v", err)
	}

	migIndex, ret := handle.GetIndex()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting device index: %v", ret.Error())
	}

	migUUID, ret := handle.GetUUID()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting device UUID: %v", ret.Error())
	}

	procPaths, err := getProcPathsForMigDevice(gpuMinor, handle)
	if err != nil {
		return Device{}, fmt.Errorf("error getting proc paths for MIG device: %v", err)
	}

	// Every capability proc path must map to a known capability device node.
	capDeviceNodes := make([]DeviceNode, 0, len(procPaths))
	for _, procPath := range procPaths {
		node, found := d.migCaps[procPath]
		if !found {
			return Device{}, fmt.Errorf("could not determine cap device path for %v", procPath)
		}
		capDeviceNodes = append(capDeviceNodes, node)
	}

	return Device{
		Index:       fmt.Sprintf("%s:%d", parentDevice.Index, migIndex),
		UUID:        migUUID,
		DeviceNodes: append(parentDevice.DeviceNodes, capDeviceNodes...),
		ProcPaths:   append(parentDevice.ProcPaths, procPaths...),
	}, nil
}
// getControlDevices returns one Device per known control device node. Each has
// the ControlDeviceUUID, a single device node and no proc paths. The major
// number comes from the named entry in nvidiaDevices; a control device whose
// name is not defined there is skipped with a warning. The returned error is
// always nil.
func (d *nvmlDiscover) getControlDevices() ([]Device, error) {
	type controlNode struct {
		name  string
		path  string
		minor int
	}

	// TODO: Where is the best place to find these device minors programmatically?
	nodes := []controlNode{
		{name: nvidiaGPUDeviceName, path: "/dev/nvidia-modeset", minor: 254},
		{name: nvidiaGPUDeviceName, path: "/dev/nvidiactl", minor: 255},
		{name: nvidiaUVMDeviceName, path: "/dev/nvidia-uvm", minor: 0},
		{name: nvidiaUVMDeviceName, path: "/dev/nvidia-uvm-tools", minor: 1},
	}

	var controlDevices []Device
	for _, node := range nodes {
		owner, found := d.nvidiaDevices.Get(node.name)
		if !found {
			d.logger.Warnf("device name %v not defined; skipping control devices %v", node.name, node.path)
			continue
		}

		controlDevices = append(controlDevices, Device{
			UUID: ControlDeviceUUID,
			DeviceNodes: []DeviceNode{
				{
					Path:  DevicePath(node.path),
					Major: owner.Major,
					Minor: node.minor,
				},
			},
			ProcPaths: []ProcPath{},
		})
	}

	return controlDevices, nil
}
// getMigControlDevices returns the MIG config and MIG monitor control devices,
// always in that order.
//
// Each device has the capability device node registered in migCaps for its
// proc path. Its proc paths are the per-GPU MIG capability directories for
// GPUs 0 through numGpus, followed by the capability's own proc path.
//
// Despite its name, numGpus is treated as an inclusive upper bound: the caller
// passes the largest GPU minor number, so numGpus+1 per-GPU paths are
// generated.
//
// An error is returned if migCaps has no device node for a capability.
func (d *nvmlDiscover) getMigControlDevices(numGpus int) ([]Device, error) {
	// A slice rather than a map keeps the order of the returned devices, and
	// of the first reported error, deterministic.
	targets := []struct {
		uuid     string
		procPath ProcPath
	}{
		{MIGConfigDeviceUUID, ProcPath("/proc/driver/nvidia/capabilities/mig/config")},
		{MIGMonitorDeviceUUID, ProcPath("/proc/driver/nvidia/capabilities/mig/monitor")},
	}

	// The per-GPU paths are the same for every target, so build them once.
	var gpuProcPaths []ProcPath
	for gpu := 0; gpu <= numGpus; gpu++ {
		gpuProcPaths = append(gpuProcPaths, ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig", gpu)))
	}

	var devices []Device
	for _, target := range targets {
		deviceNode, exists := d.migCaps[target.procPath]
		if !exists {
			return nil, fmt.Errorf("device node for '%v' is undefined", target.procPath)
		}

		// Copy the shared paths so that the devices do not alias one backing array.
		procPaths := make([]ProcPath, 0, len(gpuProcPaths)+1)
		procPaths = append(procPaths, gpuProcPaths...)
		procPaths = append(procPaths, target.procPath)

		devices = append(devices, Device{
			UUID:        target.uuid,
			DeviceNodes: []DeviceNode{deviceNode},
			ProcPaths:   procPaths,
		})
	}

	return devices, nil
}
// getProcPathsForMigDevice returns the two capability access paths in /proc
// for a MIG device: the one for its GPU instance and the one for its compute
// instance. gpu is the number used in the gpu<N> path component; the GPU and
// compute instance IDs are queried from the handle. An error is returned if
// either ID cannot be queried.
func getProcPathsForMigDevice(gpu int, handle nvml.Device) ([]ProcPath, error) {
	gi, ret := handle.GetGPUInstanceId()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error getting GPU instance ID: %v", ret.Error())
	}

	ci, ret := handle.GetComputeInstanceId()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error getting compute instance ID: %v", ret.Error())
	}

	procPaths := []ProcPath{
		ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig/gi%d/access", gpu, gi)),
		ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig/gi%d/ci%d/access", gpu, gi, ci)),
	}

	return procPaths, nil
}
// getMIGHandlesForDevice returns the handles of the MIG devices present on the
// given device. It returns nil without an error when the device does not
// support querying the MIG mode or when MIG mode is disabled. MIG indices for
// which NVML reports "not found" are skipped; any other failure is returned
// as an error.
func getMIGHandlesForDevice(handle nvml.Device) ([]nvml.Device, error) {
	currentMigMode, _, ret := handle.GetMigMode()
	switch ret.Value() {
	case nvml.SUCCESS:
		// The MIG mode was read; it is evaluated below.
	case nvml.ERROR_NOT_SUPPORTED:
		return nil, nil
	default:
		return nil, fmt.Errorf("error getting MIG mode for device: %v", ret.Error())
	}

	if currentMigMode == nvml.DEVICE_MIG_DISABLE {
		return nil, nil
	}

	maxMigDeviceCount, ret := handle.GetMaxMigDeviceCount()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error getting number of MIG devices: %v", ret.Error())
	}

	var migHandles []nvml.Device
	for mi := 0; mi < maxMigDeviceCount; mi++ {
		migHandle, ret := handle.GetMigDeviceHandleByIndex(mi)
		switch ret.Value() {
		case nvml.SUCCESS:
			migHandles = append(migHandles, migHandle)
		case nvml.ERROR_NOT_FOUND:
			// No MIG device exists at this index.
			continue
		default:
			return nil, fmt.Errorf("error getting MIG device %v: %v", mi, ret.Error())
		}
	}

	return migHandles, nil
}
// tryShutdownNVML shuts NVML down on a best-effort basis: a failure is logged
// as a warning and is not reported to the caller.
func (d *nvmlDiscover) tryShutdownNVML() {
	if ret := d.nvml.Shutdown(); ret.Value() != nvml.SUCCESS {
		d.logger.Warnf("Could not shut down NVML: %v", ret.Error())
	}
}
// NewPCIBusID converts the BusId field of the given PCI info to a PCIBusID.
// The elements of BusId are copied as bytes up to, but not including, the
// first zero element.
func NewPCIBusID(p nvml.PciInfo) PCIBusID {
	raw := make([]byte, 0, len(p.BusId))
	for _, element := range p.BusId {
		c := byte(element)
		if c == '\x00' {
			break
		}
		raw = append(raw, c)
	}

	return PCIBusID(string(raw))
}
// GetProcPath returns the path in /proc associated with the PCI bus ID. The ID
// is lower-cased, one leading "0000" is removed if present, and the result is
// joined onto /proc/driver/nvidia/gpus.
func (p PCIBusID) GetProcPath() ProcPath {
	id := strings.TrimPrefix(strings.ToLower(p.String()), "0000")

	return ProcPath(filepath.Join("/proc/driver/nvidia/gpus", id))
}
// String returns the PCI bus ID as a plain string, unchanged.
func (p PCIBusID) String() string {
	return string(p)
}

View File

@@ -0,0 +1,44 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
)
// getMigCaps loads the MIG minors and returns a mapping of MIG capability path
// to device node. Every returned device node uses capDeviceMajor as its major
// number. An error is returned if the MIG minors cannot be loaded.
func getMigCaps(capDeviceMajor int) (map[ProcPath]DeviceNode, error) {
	minors, err := nvcaps.LoadMigMinors()
	if err != nil {
		return nil, fmt.Errorf("error loading MIG minors: %v", err)
	}

	caps := getMigCapsFromMigMinors(minors, capDeviceMajor)
	return caps, nil
}
// getMigCapsFromMigMinors converts a mapping of MIG capability to minor number
// into a mapping of the capability's proc path to its device node. Each device
// node takes its path and minor number from the MIG minor and uses
// capDeviceMajor as its major number.
func getMigCapsFromMigMinors(migCaps map[nvcaps.MigCap]nvcaps.MigMinor, capDeviceMajor int) map[ProcPath]DeviceNode {
	// The result has at most one entry per input capability.
	capsDevicePaths := make(map[ProcPath]DeviceNode, len(migCaps))

	// The loop variable is not called "cap" so that it does not shadow the builtin.
	for migCap, minor := range migCaps {
		capsDevicePaths[ProcPath(migCap.ProcPath())] = DeviceNode{
			Path:  DevicePath(minor.DevicePath()),
			Major: capDeviceMajor,
			Minor: int(minor),
		}
	}

	return capsDevicePaths
}

View File

@@ -0,0 +1,41 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
"github.com/stretchr/testify/require"
)
// TestGetMigCaps checks that getMigCapsFromMigMinors produces one device node
// per MIG minor and that every device node carries the requested major number.
func TestGetMigCaps(t *testing.T) {
	const migCapMajor = 999

	migMinors := map[nvcaps.MigCap]nvcaps.MigMinor{
		"config":              1,
		"monitor":             2,
		"gpu0/gi0/access":     3,
		"gpu0/gi0/ci0/access": 4,
	}

	migCaps := getMigCapsFromMigMinors(migMinors, migCapMajor)
	require.Len(t, migCaps, len(migMinors))

	for _, node := range migCaps {
		require.Equal(t, migCapMajor, node.Major)
	}
}

View File

@@ -0,0 +1,219 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/proc"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// Major numbers assigned to the mock nvidia-frontend and nvidia-caps devices
// in the tests of this file.
const (
	nvidiaGPUDeviceMajorDefault  = 195
	nvidiaCapsDeviceMajorDefault = 235
)
// newTestDiscover returns an nvmlDiscover backed by a mock NVML server with a
// single mock A100 device. It has no MIG capabilities, and only the
// nvidia-frontend device is registered.
func newTestDiscover() nvmlDiscover {
	logger, _ := testlog.NewNullLogger()

	server := nvml.NewMockNVMLServer(nvml.NewMockA100Device(0))
	devices := proc.NewMockNvidiaDevices(
		proc.Device{Name: nvidiaGPUDeviceName, Major: nvidiaGPUDeviceMajorDefault},
	)

	return nvmlDiscover{
		logger:        logger,
		nvml:          server,
		nvidiaDevices: devices,
	}
}
// newTestDiscoverWithMIG returns an nvmlDiscover backed by a mock NVML server
// with a single MIG-enabled mock A100 device on which one GPU instance and one
// compute instance have been created. migCaps holds the capability nodes for
// that instance pair and for the MIG config and monitor capabilities. Both the
// nvidia-frontend and the nvidia-caps devices are registered.
func newTestDiscoverWithMIG() nvmlDiscover {
	device := &nvml.MockA100Device{
		Index:              0,
		MigMode:            nvml.DEVICE_MIG_ENABLE,
		GpuInstances:       make(map[*nvml.MockA100GpuInstance]struct{}),
		GpuInstanceCounter: 0,
	}

	// Create a single gi and ci on the device
	gpuInstanceProfileInfo := &nvml.GpuInstanceProfileInfo{
		Id: nvml.GPU_INSTANCE_PROFILE_7_SLICE,
	}
	gi, _ := device.CreateGpuInstance(gpuInstanceProfileInfo)

	computeInstanceProfileInfo := &nvml.ComputeInstanceProfileInfo{
		Id: nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE,
	}
	_, _ = gi.CreateComputeInstance(computeInstanceProfileInfo)

	logger, _ := testlog.NewNullLogger()

	return nvmlDiscover{
		logger: logger,
		nvml:   nvml.NewMockNVMLServer(device),
		// All capability nodes use the major number that is registered for the
		// nvidia-caps device below.
		migCaps: map[ProcPath]DeviceNode{
			ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"): {
				Path:  DevicePath("/dev/nvidia-caps/nvidia-cap3"),
				Minor: 3,
				Major: nvidiaCapsDeviceMajorDefault,
			},
			ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"): {
				Path:  DevicePath("/dev/nvidia-caps/nvidia-cap4"),
				Minor: 4,
				Major: nvidiaCapsDeviceMajorDefault,
			},
			ProcPath("/proc/driver/nvidia/capabilities/mig/config"): {
				Path:  DevicePath("/dev/nvidia-caps/nvidia-cap1"),
				Minor: 1,
				Major: nvidiaCapsDeviceMajorDefault,
			},
			ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"): {
				Path:  DevicePath("/dev/nvidia-caps/nvidia-cap2"),
				Minor: 2,
				Major: nvidiaCapsDeviceMajorDefault,
			},
		},
		nvidiaDevices: proc.NewMockNvidiaDevices(
			proc.Device{Name: nvidiaGPUDeviceName, Major: nvidiaGPUDeviceMajorDefault},
			proc.Device{Name: nvidiaCapsDeviceName, Major: nvidiaCapsDeviceMajorDefault},
		),
	}
}
// TestDiscoverNvmlDevices discovers devices without MIG support and checks the
// total number of devices as well as the identifiers, device nodes and proc
// paths of the first one.
func TestDiscoverNvmlDevices(t *testing.T) {
	discoverer := newTestDiscover()

	devices, err := discoverer.Devices()
	require.NoError(t, err)
	require.Len(t, devices, 3)

	gpu := devices[0]
	require.Equal(t, "0", gpu.Index)
	require.Equal(t, "GPU-0", gpu.UUID)
	require.Equal(t, "0000FFFF:FF:FF.F", gpu.PCIBusID.String())

	require.Equal(t,
		[]DeviceNode{
			{
				Path:  DevicePath("/dev/nvidia0"),
				Minor: 0,
				Major: nvidiaGPUDeviceMajorDefault,
			},
		},
		gpu.DeviceNodes,
	)

	require.Equal(t,
		[]ProcPath{ProcPath("/proc/driver/nvidia/gpus/ffff:ff:ff.f")},
		gpu.ProcPaths,
	)
}
// TestDiscoverNvmlMigDevices discovers devices with MIG support enabled. It
// checks the total number of devices, the identifiers, device nodes and proc
// paths of the first (MIG) device, and the device nodes and proc paths of the
// MIG config and monitor control devices.
func TestDiscoverNvmlMigDevices(t *testing.T) {
	discoverer := newTestDiscoverWithMIG()

	devices, err := discoverer.Devices()
	require.NoError(t, err)
	require.Len(t, devices, 5)

	// The first device is the MIG device.
	migDevice := devices[0]
	require.Equal(t, "0:0", migDevice.Index)
	require.Equal(t, "MIG-0", migDevice.UUID)
	require.Empty(t, migDevice.PCIBusID)

	expectedMigDeviceNodes := []DeviceNode{
		{
			Path:  DevicePath("/dev/nvidia0"),
			Minor: 0,
			Major: nvidiaGPUDeviceMajorDefault,
		},
		{
			Path:  DevicePath("/dev/nvidia-caps/nvidia-cap3"),
			Minor: 3,
			Major: nvidiaCapsDeviceMajorDefault,
		},
		{
			Path:  DevicePath("/dev/nvidia-caps/nvidia-cap4"),
			Minor: 4,
			Major: nvidiaCapsDeviceMajorDefault,
		},
	}
	require.Equal(t, expectedMigDeviceNodes, migDevice.DeviceNodes)

	expectedMigProcPaths := []ProcPath{
		ProcPath("/proc/driver/nvidia/gpus/ffff:ff:ff.f"),
		ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"),
		ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"),
	}
	require.Equal(t, expectedMigProcPaths, migDevice.ProcPaths)

	// Locate the MIG control devices by UUID; their position is not assumed.
	var config *Device
	var monitor *Device
	for i := range devices {
		switch devices[i].UUID {
		case "CONFIG":
			config = &devices[i]
		case "MONITOR":
			monitor = &devices[i]
		}
	}
	require.NotNil(t, config)
	require.NotNil(t, monitor)

	require.Equal(t, "CONFIG", config.UUID)
	expectedConfigDeviceNodes := []DeviceNode{
		{
			Path:  DevicePath("/dev/nvidia-caps/nvidia-cap1"),
			Minor: 1,
			Major: nvidiaCapsDeviceMajorDefault,
		},
	}
	require.Equal(t, expectedConfigDeviceNodes, config.DeviceNodes)
	require.Contains(t, config.ProcPaths, ProcPath("/proc/driver/nvidia/capabilities/mig/config"))
	require.Len(t, config.ProcPaths, 2)

	require.Equal(t, "MONITOR", monitor.UUID)
	expectedMonitorDeviceNodes := []DeviceNode{
		{
			Path:  DevicePath("/dev/nvidia-caps/nvidia-cap2"),
			Minor: 2,
			Major: nvidiaCapsDeviceMajorDefault,
		},
	}
	require.Equal(t, expectedMonitorDeviceNodes, monitor.DeviceNodes)
	require.Contains(t, monitor.ProcPaths, ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"))
	require.Len(t, monitor.ProcPaths, 2)
}
// TestPCIBusID checks that a PCIBusID returns its original string from String
// and the expected /proc path from GetProcPath, both for an ID with a leading
// "0000" and for one without.
func TestPCIBusID(t *testing.T) {
	testCases := []struct {
		busID    string
		procPath ProcPath
	}{
		{busID: "0000FFFF:FF:FF.F", procPath: "/proc/driver/nvidia/gpus/ffff:ff:ff.f"},
		{busID: "FFFFFFFF:FF:FF.F", procPath: "/proc/driver/nvidia/gpus/ffffffff:ff:ff.f"},
	}

	for _, tc := range testCases {
		id := PCIBusID(tc.busID)

		require.Equal(t, tc.busID, id.String())
		require.Equal(t, tc.procPath, id.GetProcPath())
	}
}

71
pkg/discover/hooks.go Normal file
View File

@@ -0,0 +1,71 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
log "github.com/sirupsen/logrus"
)
// hooks is a discoverer that reports hooks. It embeds None and defines its own
// Hooks method.
// NOTE(review): None is presumably a no-op Discover implementation supplying
// Devices and Mounts; confirm against its definition.
type hooks struct {
	None
	// logger is the logger supplied when the discoverer was constructed.
	logger *log.Logger
}

// Compile-time check that *hooks satisfies the Discover interface.
var _ Discover = (*hooks)(nil)
// NewHooks creates a discoverer for linux containers. It delegates to
// NewHooksWithLogger, passing the logrus standard logger.
func NewHooks() Discover {
	return NewHooksWithLogger(log.StandardLogger())
}
// NewHooksWithLogger creates a discoverer as with NewHooks, using the
// specified logger.
func NewHooksWithLogger(logger *log.Logger) Discover {
	return &hooks{logger: logger}
}
// Hooks returns the discovered hooks. This is always the single ldconfig hook
// created by newLdconfigHook; the returned error is always nil.
func (h hooks) Hooks() ([]Hook, error) {
	return []Hook{newLdconfigHook()}, nil
}
// newLdconfigHook returns a create-container hook that runs /sbin/ldconfig
// with the -v flag and with -r set to a root path placeholder. The hook is
// labelled with the minimum OCI version of 1.0.2.
func newLdconfigHook() Hook {
	// NOTE(review): this placeholder is presumably replaced with the container
	// root path before the hook runs; confirm where the substitution happens.
	const rootPattern = "@Root.Path@"

	// TODO: Testing seems to indicate that the -v flag is required
	args := []string{
		"-v",
		"-r", rootPattern,
	}

	return Hook{
		Path: "/sbin/ldconfig",
		Args: args,
		// TODO: CreateContainer hooks were only added to a later OCI spec version
		// We will have to find a way to deal with OCI versions before 1.0.2
		HookName: "create-container",
		Labels: map[string]string{
			"min-oci-version": "1.0.2",
		},
	}
}

52
pkg/discover/ipc.go Normal file
View File

@@ -0,0 +1,52 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
log "github.com/sirupsen/logrus"
)
// NewIPCMounts creates a discoverer for IPC sockets. It delegates to
// NewIPCMountsWithLogger, passing the logrus standard logger.
func NewIPCMounts(root string) Discover {
	return NewIPCMountsWithLogger(log.StandardLogger(), root)
}
// NewIPCMountsWithLogger creates a discoverer as with NewIPCMounts, using the
// specified logger. The discoverer is configured with a file locator
// constructed for root and with the requiredIPCs set.
func NewIPCMountsWithLogger(logger *log.Logger, root string) Discover {
	return &mounts{
		logger:   logger,
		lookup:   lookup.NewFileLocatorWithLogger(logger, root),
		required: requiredIPCs,
	}
}
// requiredIPCs defines the IPC paths that the IPC mounts discoverer is
// configured with. Each key groups the paths listed under it.
var requiredIPCs = map[string][]string{
	"nvidia-persistenced": {
		"/var/run/nvidia-persistenced/socket",
	},
	"nvidia-fabricmanager": {
		"/var/run/nvidia-fabricmanager/socket",
	},
	// TODO: This can be controlled by the NV_MPS_PIPE_DIR envvar
	"nvidia-mps": {
		"/tmp/nvidia-mps",
	},
}

56
pkg/discover/ipc_test.go Normal file
View File

@@ -0,0 +1,56 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestIPCDiscover replaces the IPC locator with a map-backed mock that knows
// two of the required paths. It checks that exactly those two are reported as
// mounts, that the locator was called three times, and that the discoverer
// reports no devices and no hooks.
func TestIPCDiscover(t *testing.T) {
	ipcLookup := map[string]string{
		"/var/run/nvidia-persistenced/socket":  "/var/run/nvidia-persistenced/socket",
		"/var/run/nvidia-fabricmanager/socket": "fm-socket",
	}

	logger, _ := testlog.NewNullLogger()
	d := NewIPCMountsWithLogger(logger, "")

	// Override lookup for testing
	mockLocator := NewLocatorMockFromMap(ipcLookup)
	d.(*mounts).lookup = mockLocator

	discovered, err := d.Mounts()
	require.NoError(t, err)

	expected := []Mount{
		{Path: "/var/run/nvidia-persistenced/socket", Labels: map[string]string{}},
		{Path: "fm-socket", Labels: map[string]string{}},
	}
	require.ElementsMatch(t, expected, discovered)
	require.Len(t, mockLocator.LocateCalls(), 3)

	devices, err := d.Devices()
	require.NoError(t, err)
	require.Empty(t, devices)

	hooks, err := d.Hooks()
	require.NoError(t, err)
	require.Empty(t, hooks)
}

100
pkg/discover/libraries.go Normal file
View File

@@ -0,0 +1,100 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
log "github.com/sirupsen/logrus"
)
// NewLibraries constructs a discoverer for the libraries listed in
// requiredLibraries. It behaves like NewLibrariesWithLogger using the logrus
// standard logger.
func NewLibraries(root string) (Discover, error) {
	stdLogger := log.StandardLogger()
	return NewLibrariesWithLogger(stdLogger, root)
}
// NewLibrariesWithLogger constructs a discoverer for libraries with the
// specified logger. An error is returned if the library locator for root
// cannot be constructed.
func NewLibrariesWithLogger(logger *log.Logger, root string) (Discover, error) {
	locator, err := lookup.NewLibraryLocatorWithLogger(logger, root)
	if err != nil {
		return nil, fmt.Errorf("error constructing locator: %v", err)
	}

	return &mounts{
		logger:   logger,
		lookup:   locator,
		required: requiredLibraries,
	}, nil
}
// requiredLibraries defines the set of libraries that the library discoverer
// attempts to locate. The map key is an identifier for the group of libraries
// and the values are the library names passed to the locator.
// NOTE(review): the identifiers look like driver capability names; the mounts
// discoverer currently only uses them in log messages (labels are still a TODO
// in uniqueMounts) -- confirm before relying on them as labels.
var requiredLibraries = map[string][]string{
"utility": {
"libnvidia-ml.so", /* Management library */
"libnvidia-cfg.so", /* GPU configuration */
},
"compute": {
"libcuda.so", /* CUDA driver library */
"libnvidia-opencl.so", /* NVIDIA OpenCL ICD */
"libnvidia-ptxjitcompiler.so", /* PTX-SASS JIT compiler (used by libcuda) */
"libnvidia-fatbinaryloader.so", /* fatbin loader (used by libcuda) */
"libnvidia-allocator.so", /* NVIDIA allocator runtime library */
"libnvidia-compiler.so", /* NVVM-PTX compiler for OpenCL (used by libnvidia-opencl) */
},
"video": {
"libvdpau_nvidia.so", /* NVIDIA VDPAU ICD */
"libnvidia-encode.so", /* Video encoder */
"libnvidia-opticalflow.so", /* NVIDIA Opticalflow library */
"libnvcuvid.so", /* Video decoder */
},
"graphics": {
//"libnvidia-egl-wayland.so", /* EGL wayland platform extension (used by libEGL_nvidia) */
"libnvidia-eglcore.so", /* EGL core (used by libGLES*[_nvidia] and libEGL_nvidia) */
"libnvidia-glcore.so", /* OpenGL core (used by libGL or libGLX_nvidia) */
"libnvidia-tls.so", /* Thread local storage (used by libGL or libGLX_nvidia) */
"libnvidia-glsi.so", /* OpenGL system interaction (used by libEGL_nvidia) */
"libnvidia-fbc.so", /* Framebuffer capture */
"libnvidia-ifr.so", /* OpenGL framebuffer capture */
"libnvidia-rtcore.so", /* Optix */
"libnvoptix.so", /* Optix */
},
"glvnd": {
//"libGLX.so", /* GLX ICD loader */
//"libOpenGL.so", /* OpenGL ICD loader */
//"libGLdispatch.so", /* OpenGL dispatch (used by libOpenGL, libEGL and libGLES*) */
"libGLX_nvidia.so", /* OpenGL/GLX ICD */
"libEGL_nvidia.so", /* EGL ICD */
"libGLESv2_nvidia.so", /* OpenGL ES v2 ICD */
"libGLESv1_CM_nvidia.so", /* OpenGL ES v1 common profile ICD */
"libnvidia-glvkspirv.so", /* SPIR-V Lib for Vulkan */
"libnvidia-cbl.so", /* VK_NV_ray_tracing */
},
"compat32": {
"libGL.so", /* OpenGL/GLX legacy _or_ compatibility wrapper (GLVND) */
"libEGL.so", /* EGL legacy _or_ ICD loader (GLVND) */
"libGLESv1_CM.so", /* OpenGL ES v1 common profile legacy _or_ ICD loader (GLVND) */
"libGLESv2.so", /* OpenGL ES v2 legacy _or_ ICD loader (GLVND) */
},
"ngx": {
"libnvidia-ngx.so", /* NGX library */
},
"dxcore": {
"libdxcore.so", /* Core library for dxcore support */
},
}

View File

@@ -0,0 +1,56 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestLibraries verifies that only required libraries that the locator can
// resolve are returned as mounts, and that the library discoverer reports no
// devices or hooks.
func TestLibraries(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()
	d, err := NewLibrariesWithLogger(nullLogger, "")
	require.NoError(t, err)

	// Override lookup for testing. Of these entries only libcuda.so is in
	// the required set, so it is the only expected mount.
	locator := NewLocatorMockFromMap(map[string]string{
		"libcuda.so":    "/lib/libcuda.so.999.99",
		"libversion.so": "/lib/libversion.so.111.11",
		"libother.so":   "/lib/libother.so.999.99",
	})
	d.(*mounts).lookup = locator

	found, err := d.Mounts()
	require.NoError(t, err)
	expected := []Mount{{Path: "/lib/libcuda.so.999.99", Labels: map[string]string{}}}
	require.ElementsMatch(t, expected, found)

	devs, err := d.Devices()
	require.NoError(t, err)
	require.Empty(t, devs)

	hks, err := d.Hooks()
	require.NoError(t, err)
	require.Empty(t, hks)
}

125
pkg/discover/mounts.go Normal file
View File

@@ -0,0 +1,125 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
log "github.com/sirupsen/logrus"
)
// Label keys used when constructing mounts with NewMountForCapability and
// NewMountForVersion respectively.
const (
capabilityLabel = "capability"
versionLabel = "version"
)
// mounts is a generic discoverer for Mounts. It is customized by specifying the
// required entities as a key-value pair as well as a Locator that is used to
// identify the mounts that are to be included.
type mounts struct {
// None is embedded so that Devices and Hooks return empty lists; only
// Mounts is overridden by this type.
None
// logger receives the debug/info/warn messages emitted while locating.
logger *log.Logger
// lookup resolves each required key to zero or more paths. Mounts returns
// an error if it is nil.
lookup lookup.Locator
// required maps a group identifier to the keys passed to lookup.
required map[string][]string
}
// Compile-time check that *mounts implements Discover.
var _ Discover = (*mounts)(nil)
// Mounts returns the mounts located for the required entities, with at most
// one entry per path. The order of the returned slice is unspecified.
func (d mounts) Mounts() ([]Mount, error) {
	byPath, err := d.uniqueMounts()
	if err != nil {
		return nil, fmt.Errorf("error discovering mounts: %v", err)
	}

	return byPath.Slice(), nil
}
// uniqueMounts locates every required key and collects the resulting paths
// into a mountsByPath so that each path appears only once. Keys that cannot be
// located are logged as warnings and skipped. An error is returned only when
// no locator is configured.
func (d mounts) uniqueMounts() (mountsByPath, error) {
	locator := d.lookup
	if locator == nil {
		return nil, fmt.Errorf("no lookup defined")
	}

	found := mountsByPath{}
	for group, candidates := range d.required {
		for _, candidate := range candidates {
			d.logger.Debugf("Locating %v [%v]", candidate, group)

			paths, err := locator.Locate(candidate)
			if err != nil {
				d.logger.Warnf("Could not locate %v [%v]: %v", candidate, group, err)
				continue
			}
			d.logger.Infof("Located %v [%v]: %v", candidate, group, paths)

			for i := range paths {
				// TODO: We need to add labels
				found.Put(newMount(paths[i]))
			}
		}
	}

	return found, nil
}
// mountsByPath stores mounts keyed by their Path.
type mountsByPath map[string]Mount

// Slice returns the stored mounts as a slice. The order follows map iteration
// and is therefore unspecified; the result is nil when the map is empty.
func (m mountsByPath) Slice() []Mount {
	var out []Mount
	for path := range m {
		out = append(out, m[path])
	}
	return out
}
// Put adds value to the collection keyed by its Path. If a mount with the same
// path is already present, the labels of value are merged into the existing
// entry; on a key collision the label from value wins.
//
// The underlying map is allocated on demand, so Put may be called through a
// pointer to a nil mountsByPath. An existing entry with a nil Labels map is
// given a fresh map before merging instead of panicking.
func (m *mountsByPath) Put(value Mount) {
	if *m == nil {
		*m = make(mountsByPath)
	}

	key := value.Path
	existing, found := (*m)[key]
	if !found {
		(*m)[key] = value
		return
	}

	// Writing to a nil map panics; only allocate when there is something
	// to merge so that an unlabelled entry stays unchanged.
	if existing.Labels == nil && len(value.Labels) > 0 {
		existing.Labels = make(map[string]string, len(value.Labels))
	}
	for k, v := range value.Labels {
		existing.Labels[k] = v
	}
	(*m)[key] = existing
}
// NewMountForCapability creates a mount for path that carries a single label
// with key capabilityLabel and the specified capability as its value.
func NewMountForCapability(path string, capability string) Mount {
	labels := []string{capabilityLabel, capability}
	return newMount(path, labels...)
}
// NewMountForVersion creates a mount for path that carries a single label
// with key versionLabel and the specified version as its value.
func NewMountForVersion(path string, version string) Mount {
	labels := []string{versionLabel, version}
	return newMount(path, labels...)
}
// newMount creates a Mount for path. The variadic labels are interpreted as
// alternating key, value pairs; a trailing key without a value is ignored.
// The returned Labels map is always non-nil.
func newMount(path string, labels ...string) Mount {
	byKey := map[string]string{}
	// Consume the arguments two at a time: key followed by value.
	for len(labels) >= 2 {
		byKey[labels[0]] = labels[1]
		labels = labels[2:]
	}

	return Mount{Path: path, Labels: byKey}
}

View File

@@ -0,0 +1,53 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/stretchr/testify/require"
)
// TestMountsReturnsErrorForNoLookup verifies that a mounts discoverer without
// a locator fails in Mounts, while Devices and Hooks (provided by the embedded
// None) still succeed with empty results.
func TestMountsReturnsErrorForNoLookup(t *testing.T) {
	discoverer := mounts{}

	found, err := discoverer.Mounts()
	require.Error(t, err)
	require.Len(t, found, 0)

	devs, err := discoverer.Devices()
	require.NoError(t, err)
	require.Empty(t, devs)

	hks, err := discoverer.Hooks()
	require.NoError(t, err)
	require.Empty(t, hks)
}
// NewLocatorMockFromMap returns a locator mock backed by lookupMap. Locating a
// key present in the map yields a single-element slice holding the mapped
// value; any other key yields an error.
func NewLocatorMockFromMap(lookupMap map[string]string) *lookup.LocatorMock {
	locate := func(key string) ([]string, error) {
		if value, ok := lookupMap[key]; ok {
			return []string{value}, nil
		}
		return nil, fmt.Errorf("key %v not found", key)
	}

	return &lookup.LocatorMock{LocateFunc: locate}
}

38
pkg/discover/none.go Normal file
View File

@@ -0,0 +1,38 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
// None is a null discoverer: each of its methods returns an empty, non-nil
// list and a nil error. It can be embedded to provide defaults for the
// Discover methods that a discoverer does not override.
type None struct{}

// Compile-time check that *None implements Discover.
var _ Discover = (*None)(nil)

// Devices returns an empty list of devices
func (None) Devices() ([]Device, error) {
	return make([]Device, 0), nil
}

// Mounts returns an empty list of mounts
func (None) Mounts() ([]Mount, error) {
	return make([]Mount, 0), nil
}

// Hooks returns an empty list of hooks
func (None) Hooks() ([]Hook, error) {
	return make([]Hook, 0), nil
}

39
pkg/discover/none_test.go Normal file
View File

@@ -0,0 +1,39 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestNone verifies that the null discoverer returns empty results without
// errors for devices, mounts, and hooks.
func TestNone(t *testing.T) {
	var discoverer None

	devs, err := discoverer.Devices()
	require.NoError(t, err)
	require.Empty(t, devs)

	mnts, err := discoverer.Mounts()
	require.NoError(t, err)
	require.Empty(t, mnts)

	hks, err := discoverer.Hooks()
	require.NoError(t, err)
	require.Empty(t, hks)
}

View File

@@ -0,0 +1,71 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvml"
log "github.com/sirupsen/logrus"
)
// nvmlServer is a discoverer that aggregates the device, mount, and hook
// discoverers registered by createNVMLServer.
type nvmlServer struct {
logger *log.Logger
// composite is declared elsewhere in this package; createNVMLServer calls
// its add method to register the aggregated discoverers.
composite
}
// Compile-time check that *nvmlServer implements Discover.
var _ Discover = (*nvmlServer)(nil)
// NewNVMLServer constructs a discoverer for server systems using NVML to
// discover devices. It behaves like NewNVMLServerWithLogger using the logrus
// standard logger.
func NewNVMLServer(root string) (Discover, error) {
	stdLogger := log.StandardLogger()
	return NewNVMLServerWithLogger(stdLogger, root)
}
// NewNVMLServerWithLogger constructs a discoverer for server systems using
// NVML to discover devices with the specified logger. The NVML implementation
// is the one returned by nvml.New.
func NewNVMLServerWithLogger(logger *log.Logger, root string) (Discover, error) {
	nvmlLib := nvml.New()
	return createNVMLServer(logger, nvmlLib, root)
}
// createNVMLServer builds an nvmlServer from the specified NVML implementation.
// It registers, in this order: the NVML device discoverer, the library, binary
// and IPC mount discoverers, and the hook discoverer. An error is returned if
// the device or library discoverer cannot be constructed.
func createNVMLServer(logger *log.Logger, nvml nvml.Interface, root string) (Discover, error) {
	deviceDiscoverer, err := NewNVMLDiscoverWithLogger(logger, nvml)
	if err != nil {
		return nil, fmt.Errorf("error constructing NVML device discoverer: %v", err)
	}

	libraryDiscoverer, err := NewLibrariesWithLogger(logger, root)
	if err != nil {
		return nil, fmt.Errorf("error constructing library discoverer: %v", err)
	}

	server := &nvmlServer{logger: logger}
	server.add(
		// Device discovery
		deviceDiscoverer,
		// Mounts discovery
		libraryDiscoverer,
		NewBinaryMountsWithLogger(logger, root),
		NewIPCMountsWithLogger(logger, root),
		// Hook discovery
		NewHooksWithLogger(logger),
	)

	return server, nil
}

View File

@@ -0,0 +1,66 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/proc"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// testMajor is the device major number reported by mockNvidiaDevices.Get.
const (
testMajor = 999
)
// TestNVMLServerConstructor verifies that createNVMLServer registers the
// expected number of discoverers and that devices and mounts can be queried
// from the result when backed by a mocked NVML implementation.
func TestNVMLServerConstructor(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()
	mockNVML := nvml.NewMockNVMLOnLunaServer()

	d, err := createNVMLServer(nullLogger, mockNVML, "")
	require.NoError(t, err)

	server := d.(*nvmlServer)
	// One device discoverer, three mount discoverers, one hook discoverer.
	require.Len(t, server.discoverers, 1+3+1)

	// We need to override the nvidiaDevices member of the nvmlDiscovery
	// TODO: Use a mock instead, or allow for injection into a constructor
	deviceDiscoverer := server.discoverers[0].(*nvmlDiscover)
	deviceDiscoverer.nvidiaDevices = &mockNvidiaDevices{}

	devs, err := d.Devices()
	require.NoError(t, err)
	require.NotEmpty(t, devs)

	_, err = d.Mounts()
	require.NoError(t, err)
}
// mockNvidiaDevices is a test double for proc.NvidiaDevices.
type mockNvidiaDevices struct{}

// Compile-time check that *mockNvidiaDevices implements proc.NvidiaDevices.
var _ proc.NvidiaDevices = (*mockNvidiaDevices)(nil)

// Get reports every requested name as found, with major number testMajor.
func (mockNvidiaDevices) Get(name string) (proc.Device, bool) {
	device := proc.Device{Name: name, Major: testMajor}
	return device, true
}

// Exists always reports false.
func (mockNvidiaDevices) Exists(string) bool {
	return false
}

129
pkg/ensure/devices.go Normal file
View File

@@ -0,0 +1,129 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package ensure
import (
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
log "github.com/sirupsen/logrus"
)
// ensureDevices wraps a discoverer and, when Devices is called, attempts to
// create the device nodes of the discovered devices that do not exist yet.
type ensureDevices struct {
logger *log.Logger
// Discover is the wrapped discoverer; all methods other than Devices are
// promoted from it unchanged.
discover.Discover
// lookup locates the executables (currently mknod) run to create nodes.
lookup lookup.Locator
// root, when neither empty nor "/", is an additional prefix under which
// each device node is also created.
root string
}
// NewEnsureDevices creates a discoverer that wraps the specified discoverer and
// ensures that the device nodes for the discoverer are created. If a root is
// specified, the device nodes rooted there are also created. The logrus
// standard logger is used.
func NewEnsureDevices(d discover.Discover, root string) discover.Discover {
	stdLogger := log.StandardLogger()
	return NewEnsureDevicesWithLogger(stdLogger, d, root)
}
// NewEnsureDevicesWithLogger creates a discoverer that wraps the specified
// discoverer and ensures that the device nodes for the discoverer are created.
// If a root is specified, the device nodes rooted there are also created. The
// specified logger is used, also for the path locator constructed for root.
func NewEnsureDevicesWithLogger(logger *log.Logger, d discover.Discover, root string) discover.Discover {
	return &ensureDevices{
		Discover: d,
		logger:   logger,
		lookup:   lookup.NewPathLocatorWithLogger(logger, root),
		root:     root,
	}
}
// Devices returns the devices of the wrapped discoverer after attempting to
// create each of their device nodes. Node creation is best effort: failures
// are not propagated to the caller.
func (d ensureDevices) Devices() ([]discover.Device, error) {
	discovered, err := d.Discover.Devices()
	if err != nil {
		return nil, fmt.Errorf("error discovering devices: %v", err)
	}

	for i := range discovered {
		nodes := discovered[i].DeviceNodes
		for j := range nodes {
			// The result is intentionally discarded.
			_ = d.deviceNode(nodes[j])
		}
	}

	return discovered, nil
}
// deviceNode attempts to create the specified device node at its own path and,
// when a root other than "" or "/" is configured, also at that path joined
// onto the root.
//
// Creation is best effort: failures are logged and the function always returns
// nil. The log entry for a rooted failure names the rooted path so that it can
// be told apart from a failure at the un-rooted path.
func (d ensureDevices) deviceNode(dn discover.DeviceNode) error {
	if err := d.device(dn.Path, dn.Major, dn.Minor); err != nil {
		d.logger.Errorf("Error creating device node %+v: %v", dn, err)
	}

	if d.root == "" || d.root == "/" {
		return nil
	}

	rootedPath := discover.DevicePath(filepath.Join(d.root, string(dn.Path)))
	if err := d.device(rootedPath, dn.Major, dn.Minor); err != nil {
		d.logger.Errorf("Error creating rooted device node %v (%+v): %v", rootedPath, dn, err)
	}

	return nil
}
// device ensures that a character device node exists at path. If os.Stat
// succeeds the node is left untouched; if os.Stat fails with anything other
// than "not exist" that error is logged and returned wrapped; otherwise mknod
// is run with mode 666 and the given major and minor numbers.
func (d ensureDevices) device(path discover.DevicePath, major int, minor int) error {
	nodePath := string(path)

	// TODO: We may want to check that the device node has the required permissions
	_, statErr := os.Stat(nodePath)
	switch {
	case statErr == nil:
		d.logger.Infof("Device node %v already exists", path)
		return nil
	case !errors.Is(statErr, os.ErrNotExist):
		d.logger.Errorf("Error getting info for device node %v: %v", path, statErr)
		return fmt.Errorf("error getting device node info: %w", statErr)
	}

	// See: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#runfile-verifications
	// TODO: We should use nvidia-modprobe or a tool based off that instead
	return d.run(
		"mknod",
		"-m", "666",
		nodePath, "c",
		fmt.Sprint(major),
		fmt.Sprint(minor),
	)
}
// run locates cmd using the configured locator and executes the first match
// with the specified arguments, waiting for it to complete.
//
// An error is returned if the command cannot be located, if the locator
// returns no paths, or if the command fails. Underlying errors are wrapped
// with %w (as in device) so that callers can inspect them with errors.Is or
// errors.As; the error text is unchanged.
func (d ensureDevices) run(cmd string, args ...string) error {
	paths, err := d.lookup.Locate(cmd)
	if err != nil {
		return fmt.Errorf("error finding command %v: %w", cmd, err)
	}
	if len(paths) == 0 {
		return fmt.Errorf("command %v not found in path", cmd)
	}

	path := paths[0]
	// Full command line, built once for both the debug and the error message.
	argv := append([]string{path}, args...)

	d.logger.Debugf("Running %v", argv)
	if err := exec.Command(path, args...).Run(); err != nil {
		return fmt.Errorf("error running %v: %w", argv, err)
	}

	return nil
}

22
pkg/ensure/ensure.go Normal file
View File

@@ -0,0 +1,22 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package ensure
import "github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
// Ensure is a named type whose underlying type is the discover.Discover
// interface. Note that this is a type definition and not a Go type alias
// (which would be written `type Ensure = discover.Discover`).
type Ensure discover.Discover

41
pkg/filter/all.go Normal file
View File

@@ -0,0 +1,41 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import "github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
// all is a Selector that combines a list of selectors with a logical AND.
type all struct {
	selectors []Selector
}

// All returns a selector that evaluates true if EACH of the specified
// selectors is selected. With no selectors the result always evaluates true.
func All(selectors ...Selector) Selector {
	return &all{selectors: selectors}
}

// Selected reports whether device is selected by every configured selector.
// Evaluation stops at the first selector that rejects the device.
func (s all) Selected(device discover.Device) bool {
	for i := range s.selectors {
		if s.selectors[i].Selected(device) {
			continue
		}
		return false
	}
	return true
}

76
pkg/filter/all_test.go Normal file
View File

@@ -0,0 +1,76 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/stretchr/testify/require"
)
// TestAll verifies the truth table of the All selector, including the empty
// case, which is expected to select every device.
func TestAll(t *testing.T) {
	constant := func(result bool) *SelectorMock {
		return &SelectorMock{
			SelectedFunc: func(discover.Device) bool {
				return result
			},
		}
	}
	yes, no := constant(true), constant(false)

	device := discover.Device{}

	// Ensure that the mocks are set up correctly:
	require.True(t, yes.Selected(device))
	require.False(t, no.Selected(device))

	require.True(t, All().Selected(device))

	require.False(t, All(no, no).Selected(device))
	require.False(t, All(no, yes).Selected(device))
	require.False(t, All(yes, no).Selected(device))
	require.True(t, All(yes, yes).Selected(device))
}
// discoverMock is a test double for discover.Discover that returns canned
// devices and mounts (with optional errors). Hooks is provided by the embedded
// discover.None.
type discoverMock struct {
	discover.None
	devices      []discover.Device
	devicesError error

	mounts      []discover.Mount
	mountsError error
}

// Compile-time check that *discoverMock implements discover.Discover.
var _ discover.Discover = (*discoverMock)(nil)

// Devices returns the configured devices and devices error.
func (mock discoverMock) Devices() ([]discover.Device, error) {
	return mock.devices, mock.devicesError
}

// Mounts returns the configured mounts and mounts error.
func (mock discoverMock) Mounts() ([]discover.Mount, error) {
	return mock.mounts, mock.mountsError
}

41
pkg/filter/any.go Normal file
View File

@@ -0,0 +1,41 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import "github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
// any is a Selector that combines a list of selectors with a logical OR.
type any struct {
	selectors []Selector
}

// Any returns a selector that evaluates true if ANY of the specified selectors
// is selected. With no selectors the result always evaluates false.
func Any(selectors ...Selector) Selector {
	return &any{selectors: selectors}
}

// Selected reports whether device is selected by at least one configured
// selector. Evaluation stops at the first selector that accepts the device.
func (s any) Selected(device discover.Device) bool {
	for i := range s.selectors {
		if !s.selectors[i].Selected(device) {
			continue
		}
		return true
	}
	return false
}

58
pkg/filter/any_test.go Normal file
View File

@@ -0,0 +1,58 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/stretchr/testify/require"
)
// TestAny verifies the truth table of the Any selector, including the empty
// case, which is expected to select no device.
func TestAny(t *testing.T) {
	constant := func(result bool) *SelectorMock {
		return &SelectorMock{
			SelectedFunc: func(discover.Device) bool {
				return result
			},
		}
	}
	yes, no := constant(true), constant(false)

	device := discover.Device{}

	// Ensure that the mocks are set up correctly:
	require.True(t, yes.Selected(device))
	require.False(t, no.Selected(device))

	require.False(t, Any().Selected(device))

	require.False(t, Any(no, no).Selected(device))
	require.True(t, Any(no, yes).Selected(device))
	require.True(t, Any(yes, no).Selected(device))
	require.True(t, Any(yes, yes).Selected(device))
}

60
pkg/filter/by_id.go Normal file
View File

@@ -0,0 +1,60 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
)
// devicesByID is a set of device identifiers, stored as map keys with empty
// struct values.
type devicesByID map[string]struct{}
// Compile-time check that *devicesByID implements Selector.
var _ Selector = (*devicesByID)(nil)
// NewDeviceSelector creates a selector for devices based on the specified IDs.
// A device is selected if its UUID, Index, or PCIBusID is one of the IDs.
//
// Empty IDs are ignored. Such an entry (for example from a trailing comma in a
// comma-separated device list) would otherwise match every device that has any
// of its identifying fields unset.
func NewDeviceSelector(ids ...string) Selector {
	deviceIDs := make(devicesByID, len(ids))
	for _, id := range ids {
		if id == "" {
			continue
		}
		deviceIDs[id] = struct{}{}
	}

	return deviceIDs
}
// Selected checks whether a specific device is included in the set of device
// IDs. The device is checked by UUID, then Index, then PCIBusID, and is
// considered selected as soon as one of these matches.
func (d devicesByID) Selected(device discover.Device) bool {
	if _, ok := d[device.UUID]; ok {
		return true
	}
	if _, ok := d[device.Index]; ok {
		return true
	}

	_, ok := d[device.PCIBusID.String()]
	return ok
}

47
pkg/filter/by_id_test.go Normal file
View File

@@ -0,0 +1,47 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/stretchr/testify/require"
)
// TestDeviceByID verifies that a device selector matches a device by index,
// UUID, or PCI bus ID, and does not match when none of the IDs apply.
func TestDeviceByID(t *testing.T) {
	device := discover.Device{
		Index:    "index",
		UUID:     "uuid",
		PCIBusID: discover.PCIBusID("pcibusid"),
	}

	testCases := []struct {
		ids      []string
		selected bool
	}{
		{ids: nil, selected: false},
		{ids: []string{"notindex", "notuuid", "notpcibusid"}, selected: false},
		{ids: []string{"index"}, selected: true},
		{ids: []string{"notindex", "index"}, selected: true},
		{ids: []string{"uuid"}, selected: true},
		{ids: []string{"notuuid", "uuid"}, selected: true},
		{ids: []string{"pcibusid"}, selected: true},
		{ids: []string{"notpcibusid", "pcibusid"}, selected: true},
		{ids: []string{"index", "uuid", "pcibusid"}, selected: true},
	}

	for _, tc := range testCases {
		got := NewDeviceSelector(tc.ids...).Selected(device)
		if tc.selected {
			require.True(t, got)
		} else {
			require.False(t, got)
		}
	}
}

90
pkg/filter/devices.go Normal file
View File

@@ -0,0 +1,90 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"fmt"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
log "github.com/sirupsen/logrus"
)
// Special values of the visible devices string handled by
// NewSelectDevicesFromWithLogger: "all" selects every standard device, while
// "none" and "void" (like the empty string) result in a discover.None.
const (
visibleDevicesAll = "all"
visibleDevicesNone = "none"
visibleDevicesVoid = "void"
)
// devices is a discoverer that wraps another discoverer and filters the
// devices it returns using a selector.
type devices struct {
// Discover is the wrapped discoverer; all methods other than Devices are
// promoted from it unchanged.
discover.Discover
logger *log.Logger
// selector decides which of the wrapped discoverer's devices are kept.
selector Selector
}
// Compile-time check that *devices implements discover.Discover.
var _ discover.Discover = (*devices)(nil)
// NewSelectDevicesFrom creates a filter that selects devices based on the
// value of the visible devices string. It behaves like
// NewSelectDevicesFromWithLogger using the logrus standard logger.
func NewSelectDevicesFrom(d discover.Discover, visibleDevices string, env EnvLookup) discover.Discover {
	stdLogger := log.StandardLogger()
	return NewSelectDevicesFromWithLogger(stdLogger, d, visibleDevices, env)
}
// NewSelectDevicesFromWithLogger creates a filter as for NewSelectDevicesFrom
// with the specified logger.
//
// An empty, "none", or "void" visibleDevices yields a discover.None. Otherwise
// a device is selected if it is a standard device that is either covered by
// "all" or listed in the comma-separated visibleDevices, or if it is a control
// device whose ID is among those obtained from env.
func NewSelectDevicesFromWithLogger(logger *log.Logger, d discover.Discover, visibleDevices string, env EnvLookup) discover.Discover {
	switch visibleDevices {
	case "", visibleDevicesNone, visibleDevicesVoid:
		return &discover.None{}
	}

	visibleDeviceSelector := StandardDevice()
	if visibleDevices != visibleDevicesAll {
		requestedIDs := strings.Split(visibleDevices, ",")
		visibleDeviceSelector = All(StandardDevice(), NewDeviceSelector(requestedIDs...))
	}

	controlDeviceIDs := getControlDeviceIDsFromEnvWithLogger(logger, env)
	controlDeviceSelector := All(ControlDevice(), NewDeviceSelector(controlDeviceIDs...))

	return &devices{
		Discover: d,
		logger:   logger,
		selector: Any(visibleDeviceSelector, controlDeviceSelector),
	}
}
// Devices returns the devices of the wrapped discoverer that are accepted by
// the configured selector, preserving their order. Each selected device is
// logged. The result is nil when no device is selected.
func (d devices) Devices() ([]discover.Device, error) {
	candidates, err := d.Discover.Devices()
	if err != nil {
		return nil, fmt.Errorf("error discovering devices: %v", err)
	}

	var selected []discover.Device
	for i := range candidates {
		candidate := candidates[i]
		if !d.selector.Selected(candidate) {
			continue
		}
		d.logger.Infof("selecting device=%v", candidate)
		selected = append(selected, candidate)
	}

	return selected, nil
}

104
pkg/filter/devices_test.go Normal file
View File

@@ -0,0 +1,104 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
log "github.com/sirupsen/logrus"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestConstructor checks the type returned by the device filter constructors for the
// special visible devices values, and the devices selected for "all" and for explicit
// device lists.
func TestConstructor(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	// Four standard devices (index, UUID and PCI bus ID all set) ...
	device0 := discover.Device{
		Index:    "0",
		UUID:     "0",
		PCIBusID: discover.PCIBusID("0"),
	}
	device1 := discover.Device{
		Index:    "1",
		UUID:     "1",
		PCIBusID: discover.PCIBusID("1"),
	}
	device2 := discover.Device{
		Index:    "2",
		UUID:     "2",
		PCIBusID: discover.PCIBusID("2"),
	}
	device3 := discover.Device{
		Index:    "3",
		UUID:     "3",
		PCIBusID: discover.PCIBusID("3"),
	}
	// ... and one control device (UUID only).
	controlDevice := discover.Device{
		UUID: "CONTROL",
	}
	mockDevices := []discover.Device{
		device0,
		device1,
		device2,
		device3,
		controlDevice,
	}
	d := discoverMock{
		devices: mockDevices,
	}

	// The constructor without an explicit logger uses the standard logrus logger.
	var ok bool
	withDefaultLogger, ok := NewSelectDevicesFrom(d, "all", nil).(*devices)
	require.True(t, ok)
	require.Same(t, log.StandardLogger(), withDefaultLogger.logger)

	// "", "void" and "none" all produce a discover.None.
	_, ok = NewSelectDevicesFromWithLogger(logger, d, "", nil).(*discover.None)
	require.True(t, ok)
	_, ok = NewSelectDevicesFromWithLogger(logger, d, "void", nil).(*discover.None)
	require.True(t, ok)
	_, ok = NewSelectDevicesFromWithLogger(logger, d, "none", nil).(*discover.None)
	require.True(t, ok)

	// "all" selects every standard device as well as the control device.
	all, ok := NewSelectDevicesFromWithLogger(logger, d, "all", nil).(*devices)
	require.True(t, ok)
	devs, err := all.Devices()
	require.NoError(t, err)
	require.ElementsMatch(t, mockDevices, devs)

	// A single ID selects that device plus the control device.
	f, ok := NewSelectDevicesFromWithLogger(logger, d, "0", nil).(*devices)
	require.True(t, ok)
	devs, err = f.Devices()
	require.NoError(t, err)
	require.Len(t, devs, 2)
	require.ElementsMatch(t, devs, []discover.Device{device0, controlDevice})

	// A comma-separated list selects each listed device plus the control device.
	f, ok = NewSelectDevicesFromWithLogger(logger, d, "0,2", nil).(*devices)
	require.True(t, ok)
	devs, err = f.Devices()
	require.NoError(t, err)
	require.Len(t, devs, 3)
	require.ElementsMatch(t, devs, []discover.Device{device0, device2, controlDevice})
}

26
pkg/filter/filter.go Normal file
View File

@@ -0,0 +1,26 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
//go:generate moq -stub -out filter_mock.go . EnvLookup

// EnvLookup defines an interface that supports the LookupEnv function for getting
// environment variable values. Callers in this package treat the returned string as
// the value of the named variable and the returned bool as whether it is set.
// TODO: This belongs in a different package
type EnvLookup interface {
	LookupEnv(string) (string, bool)
}

77
pkg/filter/filter_mock.go Normal file
View File

@@ -0,0 +1,77 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package filter
import (
"sync"
)
// Ensure, that EnvLookupMock does implement EnvLookup.
// If this is not the case, regenerate this file with moq.
var _ EnvLookup = &EnvLookupMock{}

// EnvLookupMock is a mock implementation of EnvLookup.
//
// Every call to LookupEnv is recorded and can be inspected with LookupEnvCalls.
// When LookupEnvFunc is nil, LookupEnv returns zero values ("", false).
//
// 	func TestSomethingThatUsesEnvLookup(t *testing.T) {
//
// 		// make and configure a mocked EnvLookup
// 		mockedEnvLookup := &EnvLookupMock{
// 			LookupEnvFunc: func(s string) (string, bool) {
// 				panic("mock out the LookupEnv method")
// 			},
// 		}
//
// 		// use mockedEnvLookup in code that requires EnvLookup
// 		// and then make assertions.
//
// 	}
type EnvLookupMock struct {
	// LookupEnvFunc mocks the LookupEnv method.
	LookupEnvFunc func(s string) (string, bool)

	// calls tracks calls to the methods.
	calls struct {
		// LookupEnv holds details about calls to the LookupEnv method.
		LookupEnv []struct {
			// S is the s argument value.
			S string
		}
	}
	// lockLookupEnv guards calls.LookupEnv.
	lockLookupEnv sync.RWMutex
}

// LookupEnv calls LookupEnvFunc.
// The call is recorded before LookupEnvFunc is consulted, so calls made while
// LookupEnvFunc is nil are tracked too.
func (mock *EnvLookupMock) LookupEnv(s string) (string, bool) {
	callInfo := struct {
		S string
	}{
		S: s,
	}
	mock.lockLookupEnv.Lock()
	mock.calls.LookupEnv = append(mock.calls.LookupEnv, callInfo)
	mock.lockLookupEnv.Unlock()
	if mock.LookupEnvFunc == nil {
		// No implementation configured: return zero values.
		var (
			sOut string
			bOut bool
		)
		return sOut, bOut
	}
	return mock.LookupEnvFunc(s)
}

// LookupEnvCalls gets all the calls that were made to LookupEnv.
// Check the length with:
//     len(mockedEnvLookup.LookupEnvCalls())
// The returned slice is the recorded slice itself, not a copy.
func (mock *EnvLookupMock) LookupEnvCalls() []struct {
	S string
} {
	var calls []struct {
		S string
	}
	mock.lockLookupEnv.RLock()
	calls = mock.calls.LookupEnv
	mock.lockLookupEnv.RUnlock()
	return calls
}

View File

@@ -0,0 +1,107 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
log "github.com/sirupsen/logrus"
)
const (
	// devicesAll is the value that a MIG control environment variable must have
	// for the corresponding MIG control device to be selected.
	devicesAll = "all"
)
// controlDevices pairs a discover.Discover with a logger and a selector.
// NOTE(review): nothing in the visible part of this file constructs this type; the
// constructors below return a plain Selector instead — confirm whether it is still needed.
type controlDevices struct {
	discover.Discover
	logger   *log.Logger
	selector Selector
}

// Compile-time check that *controlDevices implements discover.Discover.
var _ discover.Discover = (*controlDevices)(nil)
// NewControlDevicesFrom creates a selector for the control devices that are requested
// through the supplied environment. It logs to the standard logrus logger; use
// NewControlDevicesFromWithLogger to supply a different one.
func NewControlDevicesFrom(d discover.Discover, env EnvLookup) Selector {
	logger := log.StandardLogger()
	return NewControlDevicesFromWithLogger(logger, d, env)
}
// NewControlDevicesFromWithLogger creates a selector as for NewControlDevicesFrom with the
// specified logger. The selector matches the control device IDs derived from env; the
// discoverer d is currently not used.
func NewControlDevicesFromWithLogger(logger *log.Logger, d discover.Discover, env EnvLookup) Selector {
	ids := getControlDeviceIDsFromEnvWithLogger(logger, env)
	return NewDeviceSelector(ids...)
}
// controlDevice is a stateless Selector that matches control devices.
type controlDevice struct{}

// ControlDevice returns a selector for control devices
func ControlDevice() Selector {
	return new(controlDevice)
}
// Selected returns true for control devices and false for standard devices. A control
// device has an empty index and PCI bus ID and a non-empty UUID.
func (s controlDevice) Selected(device discover.Device) bool {
	hasIndex := device.Index != ""
	hasBusID := device.PCIBusID != ""
	hasUUID := device.UUID != ""

	return hasUUID && !hasIndex && !hasBusID
}
// getControlDeviceIDsFromEnvWithLogger returns the IDs of the control devices to select:
// always discover.ControlDeviceUUID, followed by any MIG control devices that are
// enabled through the supplied environment.
func getControlDeviceIDsFromEnvWithLogger(logger *log.Logger, env EnvLookup) []string {
	ids := []string{discover.ControlDeviceUUID}
	ids = append(ids, getMIGControlDevicesFromEnvWithLogger(logger, env)...)
	return ids
}
// getMIGControlDevicesFromEnvWithLogger returns the IDs of the MIG control devices that
// are enabled through the supplied environment.
//
// A MIG control device is selected only if its environment variable is set to exactly
// "all"; unset variables and any other value are skipped (and logged at debug level).
// If env is nil an empty, non-nil slice is returned. Otherwise the result lists the
// config device before the monitor device, and is nil if neither is selected.
func getMIGControlDevicesFromEnvWithLogger(logger *log.Logger, env EnvLookup) []string {
	if env == nil {
		logger.Debugf("Environment not specified; no MIG Control devices selected")
		return []string{}
	}

	// An ordered table is used instead of a map so that the order of the returned
	// IDs (and of the log messages) is the same on every run.
	migControlDevices := []struct {
		uuid   string
		envvar string
	}{
		{uuid: discover.MIGConfigDeviceUUID, envvar: "NVIDIA_MIG_CONFIG_DEVICES"},
		{uuid: discover.MIGMonitorDeviceUUID, envvar: "NVIDIA_MIG_MONITOR_DEVICES"},
	}

	var selected []string
	for _, mig := range migControlDevices {
		config, exists := env.LookupEnv(mig.envvar)
		if !exists {
			logger.Debugf("Envvar %v not set", mig.envvar)
			continue
		}
		if config != devicesAll {
			logger.Debugf("Found %v=%v; Skipping MIG %v devices (%v != %v)", mig.envvar, config, mig.uuid, config, devicesAll)
			continue
		}
		logger.Infof("Found %v=%v; selecting MIG %v devices", mig.envvar, config, mig.uuid)
		selected = append(selected, mig.uuid)
	}
	return selected
}

View File

@@ -0,0 +1,123 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"fmt"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestControlDevice(t *testing.T) {
control := ControlDevice()
pcibusID := discover.PCIBusID("pcibusid")
device := discover.Device{
Index: "index",
UUID: "uuid",
PCIBusID: pcibusID,
}
require.False(t, control.Selected(device))
require.False(t, control.Selected(
discover.Device{UUID: "uuid", PCIBusID: pcibusID},
))
require.False(t, control.Selected(
discover.Device{Index: "index", PCIBusID: pcibusID},
))
require.False(t, control.Selected(
discover.Device{Index: "index", UUID: "uuid"},
))
require.False(t, control.Selected(discover.Device{}))
require.True(t, control.Selected(discover.Device{UUID: "uuid"}))
}
// TestGetControlDevicesFromEnv checks which control device IDs are returned for
// different settings of the MIG control environment variables. The order of the
// returned IDs is not checked.
func TestGetControlDevicesFromEnv(t *testing.T) {
	testCases := []struct {
		name        string
		env         map[string]string
		expectedIds []string
	}{
		{
			// The nil map is still wrapped in an EnvLookupMock below, so the lookup
			// itself is non-nil; every variable is simply reported as not set.
			name:        "nil environment",
			env:         nil,
			expectedIds: []string{"CONTROL"},
		},
		{
			name:        "empty environment",
			env:         map[string]string{},
			expectedIds: []string{"CONTROL"},
		},
		{
			name:        "MIG monitor blank",
			env:         map[string]string{"NVIDIA_MIG_MONITOR_DEVICES": ""},
			expectedIds: []string{"CONTROL"},
		},
		{
			name:        "MIG config blank",
			env:         map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": ""},
			expectedIds: []string{"CONTROL"},
		},
		{
			name:        "MIG monitor not all",
			env:         map[string]string{"NVIDIA_MIG_MONITOR_DEVICES": "not-all"},
			expectedIds: []string{"CONTROL"},
		},
		{
			name:        "MIG config not all",
			env:         map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": "not-all"},
			expectedIds: []string{"CONTROL"},
		},
		{
			name:        "MIG monitor all",
			env:         map[string]string{"NVIDIA_MIG_MONITOR_DEVICES": "all"},
			expectedIds: []string{"CONTROL", "MONITOR"},
		},
		{
			name:        "MIG config all",
			env:         map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": "all"},
			expectedIds: []string{"CONTROL", "CONFIG"},
		},
		{
			name:        "MIG config and monitor all",
			env:         map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": "all", "NVIDIA_MIG_MONITOR_DEVICES": "all"},
			expectedIds: []string{"CONTROL", "CONFIG", "MONITOR"},
		},
	}

	for i, tc := range testCases {
		logger, _ := testlog.NewNullLogger()

		t.Run(fmt.Sprintf("%d: %s", i, tc.name), func(t *testing.T) {
			// The mock answers lookups from the test case's map.
			deviceIDs := getControlDeviceIDsFromEnvWithLogger(logger, &EnvLookupMock{
				LookupEnvFunc: func(s string) (string, bool) {
					value, exists := tc.env[s]
					return value, exists
				},
			})

			require.ElementsMatch(t, tc.expectedIds, deviceIDs)
		})
	}
}

View File

@@ -0,0 +1,41 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import "github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
// standardDevice is a stateless Selector that matches regular (non-control) devices.
type standardDevice struct{}

// StandardDevice returns a selector for regular (non-control) devices
func StandardDevice() Selector {
	return new(standardDevice)
}
// Selected returns true for standard devices and false for control devices. A standard
// device is expected to have an index, a UUID, and a PCI bus ID.
func (s standardDevice) Selected(device discover.Device) bool {
	hasIndex := device.Index != ""
	hasBusID := device.PCIBusID != ""
	hasUUID := device.UUID != ""

	return hasIndex && hasBusID && hasUUID
}

View File

@@ -0,0 +1,52 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/stretchr/testify/require"
)
func TestStandardDevice(t *testing.T) {
standard := StandardDevice()
pcibusID := discover.PCIBusID("pcibusid")
device := discover.Device{
Index: "index",
UUID: "uuid",
PCIBusID: pcibusID,
}
require.True(t, standard.Selected(device))
require.False(t, standard.Selected(
discover.Device{UUID: "uuid", PCIBusID: pcibusID},
))
require.False(t, standard.Selected(
discover.Device{Index: "index", PCIBusID: pcibusID},
))
require.False(t, standard.Selected(
discover.Device{Index: "index", UUID: "uuid"},
))
require.False(t, standard.Selected(discover.Device{}))
}

27
pkg/filter/selector.go Normal file
View File

@@ -0,0 +1,27 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import "github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
//go:generate moq -stub -out selector_mock.go . Selector

// Selector defines an interface for determining whether a specified Device is selected
// by a particular configuration.
type Selector interface {
	// Selected reports whether the given device is selected.
	Selected(discover.Device) bool
}

View File

@@ -0,0 +1,77 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package filter
import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"sync"
)
// Ensure, that SelectorMock does implement Selector.
// If this is not the case, regenerate this file with moq.
var _ Selector = &SelectorMock{}

// SelectorMock is a mock implementation of Selector.
//
// Every call to Selected is recorded and can be inspected with SelectedCalls.
// When SelectedFunc is nil, Selected returns false.
//
// 	func TestSomethingThatUsesSelector(t *testing.T) {
//
// 		// make and configure a mocked Selector
// 		mockedSelector := &SelectorMock{
// 			SelectedFunc: func(device discover.Device) bool {
// 				panic("mock out the Selected method")
// 			},
// 		}
//
// 		// use mockedSelector in code that requires Selector
// 		// and then make assertions.
//
// 	}
type SelectorMock struct {
	// SelectedFunc mocks the Selected method.
	SelectedFunc func(device discover.Device) bool

	// calls tracks calls to the methods.
	calls struct {
		// Selected holds details about calls to the Selected method.
		Selected []struct {
			// Device is the device argument value.
			Device discover.Device
		}
	}
	// lockSelected guards calls.Selected.
	lockSelected sync.RWMutex
}

// Selected calls SelectedFunc.
// The call is recorded before SelectedFunc is consulted, so calls made while
// SelectedFunc is nil are tracked too.
func (mock *SelectorMock) Selected(device discover.Device) bool {
	callInfo := struct {
		Device discover.Device
	}{
		Device: device,
	}
	mock.lockSelected.Lock()
	mock.calls.Selected = append(mock.calls.Selected, callInfo)
	mock.lockSelected.Unlock()
	if mock.SelectedFunc == nil {
		// No implementation configured: return the zero value.
		var (
			bOut bool
		)
		return bOut
	}
	return mock.SelectedFunc(device)
}

// SelectedCalls gets all the calls that were made to Selected.
// Check the length with:
//     len(mockedSelector.SelectedCalls())
// The returned slice is the recorded slice itself, not a copy.
func (mock *SelectorMock) SelectedCalls() []struct {
	Device discover.Device
} {
	var calls []struct {
		Device discover.Device
	}
	mock.lockSelected.RLock()
	calls = mock.calls.Selected
	mock.lockSelected.RUnlock()
	return calls
}

118
pkg/modify/device.go Normal file
View File

@@ -0,0 +1,118 @@
package modify
import (
"fmt"
"os"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
"github.com/opencontainers/runtime-spec/specs-go"
log "github.com/sirupsen/logrus"
)
// Device wraps a discover.Device (by embedding it) so that a Modify method can be
// added. Modify injects the device's nodes and proc paths into an OCI specification.
type Device struct {
	// logger is passed on to the per-node and per-path modifiers.
	logger *log.Logger
	discover.Device
}
// ProcMount wraps a discover.ProcPath (by embedding it) so that a Modify method can be
// added for the proc paths associated with devices.
type ProcMount struct {
	logger *log.Logger
	discover.ProcPath
}

// Compile-time checks that Device and ProcMount implement Modifier.
var _ Modifier = (*Device)(nil)
var _ Modifier = (*ProcMount)(nil)
// Modify injects a Device into the specified OCI specification: first each of its
// device nodes, then each of its proc paths. It stops at, and returns, the first error.
func (d Device) Modify(spec oci.Spec) error {
	for _, node := range d.DeviceNodes {
		nodeModifier := deviceNode{logger: d.logger, DeviceNode: node}
		if err := nodeModifier.Modify(spec); err != nil {
			return fmt.Errorf("could not inject device node %v: %v", node, err)
		}
	}

	for _, path := range d.ProcPaths {
		procModifier := ProcMount{logger: d.logger, ProcPath: path}
		if err := procModifier.Modify(spec); err != nil {
			return fmt.Errorf("could not inject proc path %v: %v", path, err)
		}
	}

	return nil
}
// deviceNode wraps a discover.DeviceNode (by embedding it) so that it can be injected
// into an OCI specification through a Modify method.
type deviceNode struct {
	logger *log.Logger
	discover.DeviceNode
}
// Modify adds the device node to the specified OCI specification.
func (d deviceNode) Modify(spec oci.Spec) error {
	apply := d.specModifier
	return spec.Modify(apply)
}
// specModifier adds the device node to the spec as a character device owned by
// uid 0 / gid 0, and adds a device cgroup rule allowing "rwm" access to it.
// spec.Linux and spec.Linux.Resources are created if they are not yet set.
func (d deviceNode) specModifier(spec *specs.Spec) error {
	if spec.Linux == nil {
		d.logger.Debugf("Initializing spec.Linux")
		spec.Linux = &specs.Linux{}
	}
	if spec.Linux.Resources == nil {
		d.logger.Debugf("Initializing spec.LinuxResources")
		spec.Linux.Resources = &specs.LinuxResources{}
	}

	// TODO: These need to be configurable
	var (
		// 8630 decimal is 020666 octal: permission bits 0666 plus bit 020000.
		// NOTE(review): this looks like the Unix st_mode encoding of a character
		// device (S_IFCHR|0666) rather than a Go os.FileMode type bit — confirm
		// that consumers of the spec expect this encoding.
		mode  = os.FileMode(8630)
		uid   = uint32(0)
		gid   = uint32(0)
		major = int64(d.Major)
		minor = int64(d.Minor)
	)

	d.logger.Infof("Adding device %v", d.Path)

	spec.Linux.Devices = append(spec.Linux.Devices, specs.LinuxDevice{
		Path:     string(d.Path),
		Type:     "c",
		Major:    major,
		Minor:    minor,
		FileMode: &mode,
		UID:      &uid,
		GID:      &gid,
	})

	// TODO: We have to handle the case where we are updating the cgroups for multiple devices
	// leading to duplicates in the spec
	spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, specs.LinuxDeviceCgroup{
		Allow:  true,
		Type:   "c",
		Major:  &major,
		Minor:  &minor,
		Access: "rwm",
	})

	return nil
}
// Modify applies the modifications required for a ProcMount to the specified OCI
// specification.
func (m ProcMount) Modify(spec oci.Spec) error {
	apply := m.specModifier
	return spec.Modify(apply)
}
// specModifier adds the proc path to the list of read-only paths in the spec.
// spec.Linux is created if it is not yet set, so that a device with proc paths but no
// device nodes does not cause a nil pointer dereference.
func (m ProcMount) specModifier(spec *specs.Spec) error {
	if spec.Linux == nil {
		m.logger.Debugf("Initializing spec.Linux")
		spec.Linux = &specs.Linux{}
	}

	m.logger.Infof("Mounting read-only proc path %v", m.ProcPath)
	spec.Linux.ReadonlyPaths = append(spec.Linux.ReadonlyPaths, string(m.ProcPath))

	return nil
}

155
pkg/modify/discover.go Normal file
View File

@@ -0,0 +1,155 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package modify
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
log "github.com/sirupsen/logrus"
)
// discoverModifier is a Modifier that applies the modifications required for the
// devices, mounts and hooks reported by a discover.Discover.
type discoverModifier struct {
	logger *log.Logger
	// discover is the source of the devices, mounts and hooks to inject.
	discover discover.Discover
	// root is passed on to each mount modifier.
	root string
	// bundleDir is passed on to each hook modifier.
	bundleDir string
}

// Compile-time check that *discoverModifier implements Modifier.
var _ Modifier = (*discoverModifier)(nil)
// NewModifierFor creates a Modifier that applies to an OCI specification the
// modifications required by the specified Discover instance. It logs to the standard
// logrus logger; use NewModifierWithLoggerFor to supply a different one.
func NewModifierFor(discover discover.Discover, root string, bundleDir string) Modifier {
	logger := log.StandardLogger()
	return NewModifierWithLoggerFor(logger, discover, root, bundleDir)
}
// NewModifierWithLoggerFor creates a Modifier as for NewModifierFor with the specified
// logger. root is handed to the mount modifiers and bundleDir to the hook modifiers.
func NewModifierWithLoggerFor(logger *log.Logger, discover discover.Discover, root string, bundleDir string) Modifier {
	return &discoverModifier{
		logger:    logger,
		discover:  discover,
		root:      root,
		bundleDir: bundleDir,
	}
}
// Modify applies the modifications for the discovered devices, mounts and hooks to the
// specified OCI spec. All modifiers are constructed before any is applied; application
// stops at the first modifier that fails.
func (m discoverModifier) Modify(spec oci.Spec) error {
	m.logger.Infof("Determining required OCI spec modifications")

	pending, err := m.modifiers()
	if err != nil {
		return fmt.Errorf("error constructing modifiers: %v", err)
	}

	m.logger.Infof("Applying %v modifications", len(pending))
	for i := range pending {
		if err := pending[i].Modify(spec); err != nil {
			return fmt.Errorf("could not apply modifier %v: %v", pending[i], err)
		}
	}

	return nil
}
// modifiers collects the modifiers for all discovered entities: devices first, then
// mounts, then hooks. The first discovery error aborts the collection.
func (m discoverModifier) modifiers() ([]Modifier, error) {
	collectors := []func() ([]Modifier, error){
		m.deviceModifiers,
		m.mountModifiers,
		m.hookModifiers,
	}

	var all []Modifier
	for _, collect := range collectors {
		found, err := collect()
		if err != nil {
			return nil, err
		}
		all = append(all, found...)
	}

	return all, nil
}
// deviceModifiers returns one Device modifier for each discovered device.
func (m discoverModifier) deviceModifiers() ([]Modifier, error) {
	discovered, err := m.discover.Devices()
	if err != nil {
		return nil, fmt.Errorf("error discovering devices: %v", err)
	}

	var result []Modifier
	for _, device := range discovered {
		result = append(result, Device{logger: m.logger, Device: device})
	}
	return result, nil
}
// mountModifiers returns one Mount modifier for each discovered mount, each configured
// with the modifier's root.
func (m discoverModifier) mountModifiers() ([]Modifier, error) {
	discovered, err := m.discover.Mounts()
	if err != nil {
		return nil, fmt.Errorf("error discovering mounts: %v", err)
	}

	var result []Modifier
	for _, mount := range discovered {
		result = append(result, Mount{logger: m.logger, Mount: mount, root: m.root})
	}
	return result, nil
}
// hookModifiers returns one Hook modifier for each discovered hook, each configured
// with the modifier's bundle directory.
func (m discoverModifier) hookModifiers() ([]Modifier, error) {
	discovered, err := m.discover.Hooks()
	if err != nil {
		return nil, fmt.Errorf("error discovering hooks: %v", err)
	}

	var result []Modifier
	for _, hook := range discovered {
		result = append(result, Hook{logger: m.logger, Hook: hook, bundleDir: m.bundleDir})
	}
	return result, nil
}

81
pkg/modify/hooks.go Normal file
View File

@@ -0,0 +1,81 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package modify
import (
"fmt"
"path/filepath"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
"github.com/opencontainers/runtime-spec/specs-go"
log "github.com/sirupsen/logrus"
)
// Hook wraps a discover.Hook (by embedding it) so that a Modify method can be added.
type Hook struct {
	logger *log.Logger
	discover.Hook
	// bundleDir is used to resolve a relative root path from the spec.
	bundleDir string
}

// Compile-time check that *Hook implements Modifier.
var _ Modifier = (*Hook)(nil)
// Modify adds the Hook to the specified OCI specification.
func (h Hook) Modify(spec oci.Spec) error {
	apply := h.specModifier
	return spec.Modify(apply)
}
// specModifier appends the hook to the create-container hooks of the spec, creating
// spec.Hooks if it is not yet set.
//
// Every occurrence of "@Root.Path@" in the hook arguments is replaced by the root path
// of the spec; a relative root path is resolved against the bundle directory first.
//
// An error is returned if the hook name is not "create-container", or if an argument
// requires the root path but the spec does not define a root. The latter previously
// caused a nil pointer dereference, even for hooks that did not need the root path.
func (h Hook) specModifier(spec *specs.Spec) error {
	if spec.Hooks == nil {
		h.logger.Debugf("Initializing spec.Hooks")
		spec.Hooks = &specs.Hooks{}
	}

	// TODO: This is duplicated in the hook specification
	const rootPattern = "@Root.Path@"

	var rootPath string
	if spec.Root != nil {
		rootPath = spec.Root.Path
		if !filepath.IsAbs(rootPath) {
			rootPath = filepath.Join(h.bundleDir, rootPath)
		}
	}

	var args []string
	for _, a := range h.Args {
		if strings.Contains(a, rootPattern) {
			if spec.Root == nil {
				return fmt.Errorf("cannot expand %v in hook argument %q: spec does not define a root", rootPattern, a)
			}
			a = strings.ReplaceAll(a, rootPattern, rootPath)
		}
		args = append(args, a)
	}

	specHook := specs.Hook{
		Path: h.Path,
		Args: args,
	}

	h.logger.Infof("Adding %v hook %+v", h.HookName, specHook)
	switch h.HookName {
	case "create-container":
		spec.Hooks.CreateContainer = append(spec.Hooks.CreateContainer, specHook)
	default:
		return fmt.Errorf("unexpected hook name: %v", h.HookName)
	}

	return nil
}

28
pkg/modify/modify.go Normal file
View File

@@ -0,0 +1,28 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package modify
import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
)
//go:generate moq -stub -out modify_mock.go . Modifier

// Modifier defines an interface for modifying an OCI Specification.
type Modifier interface {
	// Modify applies the modification to the given spec.
	Modify(oci.Spec) error
}

77
pkg/modify/modify_mock.go Normal file
View File

@@ -0,0 +1,77 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package modify
import (
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
"sync"
)
// Ensure, that ModifierMock does implement Modifier.
// If this is not the case, regenerate this file with moq.
var _ Modifier = &ModifierMock{}

// ModifierMock is a mock implementation of Modifier.
//
// Every call to Modify is recorded and can be inspected with ModifyCalls.
// When ModifyFunc is nil, Modify returns a nil error.
//
// 	func TestSomethingThatUsesModifier(t *testing.T) {
//
// 		// make and configure a mocked Modifier
// 		mockedModifier := &ModifierMock{
// 			ModifyFunc: func(spec oci.Spec) error {
// 				panic("mock out the Modify method")
// 			},
// 		}
//
// 		// use mockedModifier in code that requires Modifier
// 		// and then make assertions.
//
// 	}
type ModifierMock struct {
	// ModifyFunc mocks the Modify method.
	ModifyFunc func(spec oci.Spec) error

	// calls tracks calls to the methods.
	calls struct {
		// Modify holds details about calls to the Modify method.
		Modify []struct {
			// Spec is the spec argument value.
			Spec oci.Spec
		}
	}
	// lockModify guards calls.Modify.
	lockModify sync.RWMutex
}

// Modify calls ModifyFunc.
// The call is recorded before ModifyFunc is consulted, so calls made while
// ModifyFunc is nil are tracked too.
func (mock *ModifierMock) Modify(spec oci.Spec) error {
	callInfo := struct {
		Spec oci.Spec
	}{
		Spec: spec,
	}
	mock.lockModify.Lock()
	mock.calls.Modify = append(mock.calls.Modify, callInfo)
	mock.lockModify.Unlock()
	if mock.ModifyFunc == nil {
		// No implementation configured: return the zero value.
		var (
			errOut error
		)
		return errOut
	}
	return mock.ModifyFunc(spec)
}

// ModifyCalls gets all the calls that were made to Modify.
// Check the length with:
//     len(mockedModifier.ModifyCalls())
// The returned slice is the recorded slice itself, not a copy.
func (mock *ModifierMock) ModifyCalls() []struct {
	Spec oci.Spec
} {
	var calls []struct {
		Spec oci.Spec
	}
	mock.lockModify.RLock()
	calls = mock.calls.Modify
	mock.lockModify.RUnlock()
	return calls
}

53
pkg/modify/mounts.go Normal file
View File

@@ -0,0 +1,53 @@
package modify
import (
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/discover"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
"github.com/opencontainers/runtime-spec/specs-go"
log "github.com/sirupsen/logrus"
)
// Mount wraps a discover.Mount (by embedding it) so that a Modify method can be added.
type Mount struct {
	logger *log.Logger
	discover.Mount
	// root is the prefix that is stripped from the mount path to obtain the mount
	// destination.
	root string
}

// Compile-time check that *Mount implements Modifier.
var _ Modifier = (*Mount)(nil)
// Modify adds the Mount to the specified OCI specification.
func (d Mount) Modify(spec oci.Spec) error {
	apply := d.specModifier
	return spec.Modify(apply)
}
// specModifier appends a bind mount (options "rbind" and "rprivate") for the discovered
// path to the spec. The mount source is the discovered path; the destination is that
// path with the configured root prefix removed.
//
// If removing the prefix leaves a non-empty path without a leading "/" (for example
// when root is "/" or ends in a slash), the slash is restored so that the destination
// stays absolute.
//
// TODO: We need to ensure that we are correctly mounting the proc paths.
// We will probably also need a new tmpfs mounted at /proc/driver/nvidia/ underneath
// which all of these other mounted directories get put. The MaskedPaths field of the
// runtime spec (`json:"maskedPaths,omitempty"`) may be an option for this; see:
//   https://github.com/opencontainers/runtime-spec/blob/master/specs-go/config.go#L175
//   https://github.com/opencontainers/runtime-spec/blob/master/config-linux.md#masked-paths
// TODO: We can try masking all of /proc/driver/nvidia and then mounting the paths read-only
func (d Mount) specModifier(spec *specs.Spec) error {
	source := d.Path
	destination := strings.TrimPrefix(d.Path, d.root)
	if destination != "" && !strings.HasPrefix(destination, "/") {
		// The root prefix included the path separator; put it back.
		destination = "/" + destination
	}
	d.logger.Infof("Mounting %v -> %v", source, destination)

	mount := specs.Mount{
		Destination: destination,
		Source:      source,
		Type:        "bind",
		Options: []string{
			"rbind",
			"rprivate",
		},
	}
	spec.Mounts = append(spec.Mounts, mount)

	return nil
}

135
pkg/oci/args.go Normal file
View File

@@ -0,0 +1,135 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"os"
"path/filepath"
"strings"
)
const (
	// specFileName is the name of the OCI specification file within a bundle directory.
	specFileName = "config.json"
)
// GetBundleDir returns the bundle directory specified in the supplied command line
// arguments, falling back to the default bundle directory if the arguments do not
// specify one (or specify an empty one).
func GetBundleDir(args []string) (string, error) {
	fromArgs, err := GetBundleDirFromArgs(args)
	switch {
	case err != nil:
		return "", fmt.Errorf("error getting bundle dir from args: %v", err)
	case fromArgs != "":
		return fromArgs, nil
	}

	fallback, err := GetDefaultBundleDir()
	if err != nil {
		return "", fmt.Errorf("error getting default bundle dir: %v", err)
	}
	return fallback, nil
}
// GetBundleDirFromArgs checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc.
// The following are supported:
//   --bundle{{SEP}}BUNDLE_PATH
//   -bundle{{SEP}}BUNDLE_PATH
//   -b{{SEP}}BUNDLE_PATH
// where {{SEP}} is either ' ' or '='
//
// If the flag occurs more than once the last occurrence wins. An empty string is
// returned if the flag is not present, and an error if the flag is the final argument
// and therefore has no value.
func GetBundleDirFromArgs(args []string) (string, error) {
	bundleDir := ""

	for idx := 0; idx < len(args); idx++ {
		// Split "flag=value" at the first '='; a bare flag has no inline value.
		flag, value, inline := args[idx], "", false
		if eq := strings.Index(flag, "="); eq >= 0 {
			flag, value, inline = flag[:eq], flag[eq+1:], true
		}

		if !IsBundleFlag(flag) {
			continue
		}

		switch {
		case inline:
			// The flag has the format --bundle=/path
			bundleDir = value
		case idx+1 < len(args):
			// The flag has the format --bundle /path; consume the path as well.
			idx++
			bundleDir = args[idx]
		default:
			// --bundle / -b was the last element of args
			return "", fmt.Errorf("bundle option requires an argument")
		}
	}

	return bundleDir, nil
}
// GetDefaultBundleDir returns the bundle directory that is to be used if no alternative is
// specified via the command line, for example.
func GetDefaultBundleDir() (string, error) {
workingDirectory, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("error getting working directory: %v", err)
}
return workingDirectory, nil
}
// GetSpecFilePath returns the expected location of the OCI specification file
// inside the given bundle directory, i.e. bundleDir joined with specFileName.
func GetSpecFilePath(bundleDir string) string {
	return filepath.Join(bundleDir, specFileName)
}
// IsBundleFlag is a helper function that checks whether the specified argument
// represents a bundle flag (e.g. --bundle or -b). The argument must start with
// at least one '-'; all leading dashes are stripped before the name is compared.
func IsBundleFlag(arg string) bool {
	name := strings.TrimLeft(arg, "-")
	if name == arg {
		// Nothing was stripped, so arg does not start with a dash.
		return false
	}
	switch name {
	case "b", "bundle":
		return true
	default:
		return false
	}
}
// HasCreateSubcommand reports whether the supplied arguments contain a
// 'create' subcommand. The argument that directly follows a bundle flag is
// treated as that flag's value and is never interpreted as the subcommand, so
// a bundle directory named `create` does not trigger a match.
func HasCreateSubcommand(args []string) bool {
	skipNext := false
	for _, arg := range args {
		if skipNext {
			// arg is the value of the preceding bundle flag.
			skipNext = false
			continue
		}
		if IsBundleFlag(arg) {
			skipNext = true
			continue
		}
		if arg == "create" {
			return true
		}
	}
	return false
}

198
pkg/oci/args_test.go Normal file
View File

@@ -0,0 +1,198 @@
package oci
import (
"os"
"testing"
"github.com/stretchr/testify/require"
)
// TestGetBundleDir checks the bundle directory returned by GetBundleDir for
// the supported flag spellings, falling back to the working directory when no
// bundle flag is present.
func TestGetBundleDir(t *testing.T) {
	cwd, err := os.Getwd()
	require.NoError(t, err)

	type expected struct {
		bundle  string
		isError bool
	}

	testCases := []struct {
		argv     []string
		expected expected
	}{
		// No bundle flag: the working directory is used.
		{argv: []string{}, expected: expected{bundle: cwd}},
		{argv: []string{"create"}, expected: expected{bundle: cwd}},
		// A bundle flag without a value is an error.
		{argv: []string{"--bundle"}, expected: expected{isError: true}},
		{argv: []string{"-b"}, expected: expected{isError: true}},
		{argv: []string{"--bundle", "/foo/bar"}, expected: expected{bundle: "/foo/bar"}},
		// Arguments that are not bundle flags are ignored.
		{argv: []string{"--not-bundle", "/foo/bar"}, expected: expected{bundle: cwd}},
		{argv: []string{"--"}, expected: expected{bundle: cwd}},
		{argv: []string{"-bundle", "/foo/bar"}, expected: expected{bundle: "/foo/bar"}},
		// '=' separated values, including values that contain '='.
		{argv: []string{"--bundle=/foo/bar"}, expected: expected{bundle: "/foo/bar"}},
		{argv: []string{"-b=/foo/bar"}, expected: expected{bundle: "/foo/bar"}},
		{argv: []string{"-b=/foo/=bar"}, expected: expected{bundle: "/foo/=bar"}},
		{argv: []string{"-b", "/foo/bar"}, expected: expected{bundle: "/foo/bar"}},
		{argv: []string{"create", "-b", "/foo/bar"}, expected: expected{bundle: "/foo/bar"}},
		// A bundle directory that is literally named "create".
		{argv: []string{"-b", "create", "create"}, expected: expected{bundle: "create"}},
		{argv: []string{"-b=create", "create"}, expected: expected{bundle: "create"}},
		{argv: []string{"-b", "create"}, expected: expected{bundle: "create"}},
	}

	for i, tc := range testCases {
		bundle, err := GetBundleDir(tc.argv)

		checkErr := require.NoErrorf
		if tc.expected.isError {
			checkErr = require.Errorf
		}
		checkErr(t, err, "%d: %v", i, tc)

		require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc)
	}
}
// TestGetDefaultBundleDir checks that the default bundle directory is the
// current working directory.
func TestGetDefaultBundleDir(t *testing.T) {
	expected, wdErr := os.Getwd()
	require.NoError(t, wdErr)

	actual, err := GetDefaultBundleDir()
	require.NoError(t, err)

	require.Equal(t, expected, actual)
}
func TestGetSpecFilePathAppendsFilename(t *testing.T) {
testCases := []struct {
bundleDir string
expected string
}{
{
bundleDir: "",
expected: "config.json",
},
{
bundleDir: "/not/empty/",
expected: "/not/empty/config.json",
},
{
bundleDir: "not/absolute",
expected: "not/absolute/config.json",
},
}
for i, tc := range testCases {
specPath := GetSpecFilePath(tc.bundleDir)
require.Equalf(t, tc.expected, specPath, "%d: %v", i, tc)
}
}
// TestHasCreateSubcommand checks that HasCreateSubcommand only reports a
// 'create' subcommand when "create" is not the value of a bundle flag.
func TestHasCreateSubcommand(t *testing.T) {
	testCases := []struct {
		args         []string
		shouldModify bool
	}{
		// No arguments at all.
		{shouldModify: false},
		{args: []string{"create"}, shouldModify: true},
		// "create" as the bundle directory must not be mistaken for the subcommand.
		{args: []string{"--bundle=create"}, shouldModify: false},
		{args: []string{"--bundle", "create"}, shouldModify: false},
		// A trailing bundle flag without a value.
		{args: []string{"-b"}, shouldModify: false},
		// A bundle flag followed by an actual create subcommand.
		{args: []string{"--bundle", "create", "create"}, shouldModify: true},
		{args: []string{"--bundle=create", "create"}, shouldModify: true},
		{args: []string{"-b", "/foo/bar", "create"}, shouldModify: true},
	}

	for i, tc := range testCases {
		require.Equalf(t, tc.shouldModify, HasCreateSubcommand(tc.args), "%d: %v", i, tc)
	}
}

25
pkg/oci/runtime.go Normal file
View File

@@ -0,0 +1,25 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
//go:generate moq -stub -out runtime_mock.go . Runtime
// Runtime is an interface for a runtime shim. The Exec method accepts a list
// of command line arguments, and returns an error / nil.
type Runtime interface {
// Exec runs the runtime with the supplied command line arguments and
// returns an error on failure.
Exec([]string) error
}

View File

@@ -0,0 +1,61 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable,
// logging through the logrus standard logger. The executable is chosen from the
// supplied candidates; the first one that can be found in the PATH is selected.
func NewLowLevelRuntime(candidates ...string) (Runtime, error) {
	standardLogger := log.StandardLogger()
	return NewLowLevelRuntimeWithLogger(standardLogger, candidates...)
}
// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger.
// The logger is used both while searching for the runtime binary and by the
// returned Runtime. A nil logger is replaced by the logrus standard logger.
func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) {
	if logger == nil {
		logger = log.StandardLogger()
	}

	runtimePath, err := findRuntimeWithLogger(logger, candidates)
	if err != nil {
		return nil, fmt.Errorf("error locating runtime: %v", err)
	}

	return NewRuntimeForPathWithLogger(logger, runtimePath)
}

// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH,
// logging progress through the logrus standard logger.
// The path reported by exec.LookPath for the first match is returned.
func findRuntime(candidates []string) (string, error) {
	return findRuntimeWithLogger(log.StandardLogger(), candidates)
}

// findRuntimeWithLogger is findRuntime with progress reported through the
// supplied logger, so that callers that were handed a specific logger do not
// write to the global one. An error is returned if candidates is empty or if
// none of the candidates can be found.
func findRuntimeWithLogger(logger *log.Logger, candidates []string) (string, error) {
	if len(candidates) == 0 {
		return "", fmt.Errorf("at least one runtime candidate must be specified")
	}

	for _, candidate := range candidates {
		logger.Infof("Looking for runtime binary '%v'", candidate)
		runcPath, err := exec.LookPath(candidate)
		if err == nil {
			logger.Infof("Found runtime binary '%v'", runcPath)
			return runcPath, nil
		}
		// A missing candidate is not fatal; try the next one.
		logger.Warnf("Runtime binary '%v' not found: %v", candidate, err)
	}

	return "", fmt.Errorf("no runtime binary found from candidate list: %v", candidates)
}

76
pkg/oci/runtime_mock.go Normal file
View File

@@ -0,0 +1,76 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package oci
import (
"sync"
)
// Ensure, that RuntimeMock does implement Runtime.
// If this is not the case, regenerate this file with moq.
var _ Runtime = &RuntimeMock{}
// RuntimeMock is a mock implementation of Runtime.
//
//     func TestSomethingThatUsesRuntime(t *testing.T) {
//
//         // make and configure a mocked Runtime
//         mockedRuntime := &RuntimeMock{
//             ExecFunc: func(strings []string) error {
//                 panic("mock out the Exec method")
//             },
//         }
//
//         // use mockedRuntime in code that requires Runtime
//         // and then make assertions.
//
//     }
type RuntimeMock struct {
// ExecFunc mocks the Exec method.
// When it is nil, Exec records the call and returns a nil error.
ExecFunc func(strings []string) error
// calls tracks calls to the methods.
calls struct {
// Exec holds details about calls to the Exec method.
Exec []struct {
// Strings is the strings argument value.
Strings []string
}
}
// lockExec guards access to calls.Exec.
lockExec sync.RWMutex
}
// Exec calls ExecFunc.
// The call is recorded before ExecFunc is invoked. If ExecFunc is nil, a nil
// error is returned.
func (mock *RuntimeMock) Exec(strings []string) error {
callInfo := struct {
Strings []string
}{
Strings: strings,
}
// Record the call while holding the write lock.
mock.lockExec.Lock()
mock.calls.Exec = append(mock.calls.Exec, callInfo)
mock.lockExec.Unlock()
if mock.ExecFunc == nil {
// No implementation configured: return the zero value of the result.
var (
errOut error
)
return errOut
}
return mock.ExecFunc(strings)
}
// ExecCalls gets all the calls that were made to Exec.
// Check the length with:
//     len(mockedRuntime.ExecCalls())
func (mock *RuntimeMock) ExecCalls() []struct {
Strings []string
} {
var calls []struct {
Strings []string
}
// Read the recorded calls under the read lock; the slice is returned
// as-is, without copying its elements.
mock.lockExec.RLock()
calls = mock.calls.Exec
mock.lockExec.RUnlock()
return calls
}

70
pkg/oci/runtime_path.go Normal file
View File

@@ -0,0 +1,70 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
)
// pathRuntime wraps the path to a binary and defines the semantics for how to exec into it.
// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the
// Runtime interface.
type pathRuntime struct {
// logger is the logger supplied when the pathRuntime was constructed.
logger *log.Logger
// path is the path of the binary; Exec passes it as the first argument.
path string
// execRuntime is the Runtime that Exec delegates to.
execRuntime Runtime
}
// Compile-time check that *pathRuntime implements Runtime.
var _ Runtime = (*pathRuntime)(nil)
// NewRuntimeForPath creates a Runtime for the binary at the specified path,
// using the logrus standard logger.
func NewRuntimeForPath(path string) (Runtime, error) {
	standardLogger := log.StandardLogger()
	return NewRuntimeForPathWithLogger(standardLogger, path)
}
// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path.
// An error is returned if the path cannot be stat-ed, refers to a directory,
// or has none of its execute permission bits set.
func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) {
	fileInfo, statErr := os.Stat(path)
	if statErr != nil {
		return nil, fmt.Errorf("invalid path '%v': %v", path, statErr)
	}

	// Execute permission for user, group, or other.
	const anyExecBit = 0111
	isExecutableFile := !fileInfo.IsDir() && fileInfo.Mode()&anyExecBit != 0
	if !isExecutableFile {
		return nil, fmt.Errorf("specified path '%v' is not an executable file", path)
	}

	return &pathRuntime{
		logger:      logger,
		path:        path,
		execRuntime: syscallExec{},
	}, nil
}
// Exec execs into the binary at the path from the pathRuntime struct. The
// first of the supplied arguments (if any) is replaced by that path, and the
// remaining arguments are forwarded unchanged.
func (s pathRuntime) Exec(args []string) error {
	runtimeArgs := []string{s.path}
	for idx, arg := range args {
		if idx == 0 {
			// args[0] is replaced by the path of the target binary.
			continue
		}
		runtimeArgs = append(runtimeArgs, arg)
	}

	return s.execRuntime.Exec(runtimeArgs)
}

View File

@@ -0,0 +1,97 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestPathRuntimeConstructor checks that NewRuntimeForPath rejects missing
// paths, directories, and non-executable files, and accepts an executable.
func TestPathRuntimeConstructor(t *testing.T) {
	invalidPaths := []string{
		"////an/invalid/path",
		"/tmp",
		"/dev/null",
	}
	for _, invalid := range invalidPaths {
		r, err := NewRuntimeForPath(invalid)
		require.Error(t, err)
		require.Nil(t, r)
	}

	const shell = "/bin/sh"
	r, err := NewRuntimeForPath(shell)
	require.NoError(t, err)

	concrete, isPathRuntime := r.(*pathRuntime)
	require.True(t, isPathRuntime)
	require.Equal(t, shell, concrete.path)
}
// TestPathRuntimeForwardsArgs checks that pathRuntime.Exec replaces the first
// argument with the runtime path, forwards the remaining arguments, and
// returns the error of the wrapped runtime.
func TestPathRuntimeForwardsArgs(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	testCases := []struct {
		execRuntimeError error
		args             []string
	}{
		{},
		{args: []string{"shouldBeReplaced"}},
		{args: []string{"shouldBeReplaced", "arg1"}},
		{execRuntimeError: fmt.Errorf("exec error")},
	}

	for _, tc := range testCases {
		execMock := &RuntimeMock{
			ExecFunc: func([]string) error {
				return tc.execRuntimeError
			},
		}
		r := pathRuntime{logger: logger, path: "runtime", execRuntime: execMock}

		err := r.Exec(tc.args)
		require.ErrorIs(t, err, tc.execRuntimeError)

		calls := execMock.ExecCalls()
		require.Len(t, calls, 1)
		forwarded := calls[0].Strings

		// The runtime path is always present, even when no args were supplied.
		expectedLen := len(tc.args)
		if expectedLen == 0 {
			expectedLen = 1
		}
		require.Len(t, forwarded, expectedLen)
		require.Equal(t, "runtime", forwarded[0])

		if expectedLen > 1 {
			require.EqualValues(t, tc.args[1:], forwarded[1:])
		}
	}
}

View File

@@ -0,0 +1,38 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"os"
"syscall"
)
// syscallExec is a Runtime that uses syscall.Exec to exec into the requested
// binary, passing on the environment of the current process.
type syscallExec struct{}

// Compile-time check that *syscallExec implements Runtime.
var _ Runtime = (*syscallExec)(nil)

// Exec execs into the binary at args[0], passing args as its argument vector
// and os.Environ() as its environment. An error is returned if args is empty,
// if the exec fails, or if syscall.Exec unexpectedly returns without an error.
func (r syscallExec) Exec(args []string) error {
	// Guard against indexing into an empty argument list below.
	if len(args) == 0 {
		return fmt.Errorf("could not exec: no executable specified")
	}

	err := syscall.Exec(args[0], args, os.Environ())
	if err != nil {
		return fmt.Errorf("could not exec '%v': %v", args[0], err)
	}

	// syscall.Exec is not expected to return. This is an error state regardless of whether
	// err is nil or not.
	return fmt.Errorf("unexpected return from exec '%v'", args[0])
}

35
pkg/oci/spec.go Normal file
View File

@@ -0,0 +1,35 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
oci "github.com/opencontainers/runtime-spec/specs-go"
)
// SpecModifier is a function that accepts a pointer to an OCI Spec and returns an
// error. The intention is that the function would modify the spec in-place.
type SpecModifier func(*oci.Spec) error
//go:generate moq -stub -out spec_mock.go . Spec
// Spec defines the operations to be performed on an OCI specification
type Spec interface {
// Load reads the specification so that it can be modified or queried.
Load() error
// Flush writes the specification back out.
Flush() error
// Modify applies the supplied SpecModifier to the specification.
Modify(SpecModifier) error
// LookupEnv returns the value of the named environment variable in the
// specification and whether it is present.
LookupEnv(string) (string, bool)
}

153
pkg/oci/spec_file.go Normal file
View File

@@ -0,0 +1,153 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"encoding/json"
"fmt"
"io"
"os"
"strings"
oci "github.com/opencontainers/runtime-spec/specs-go"
)
// fileSpec is a Spec implementation that is backed by a file.
type fileSpec struct {
// Spec is the embedded OCI specification. It is left nil by
// NewSpecFromFile and is set by loadFrom.
*oci.Spec
// path is the file that Load reads from and Flush writes to.
path string
}
// Compile-time check that *fileSpec implements Spec.
var _ Spec = (*fileSpec)(nil)
// NewSpecFromArgs creates a file-backed Spec for the bundle directory
// determined from the command line arguments passed to the application. The
// bundle directory is returned alongside the Spec. The specification file is
// not read here.
func NewSpecFromArgs(args []string) (Spec, string, error) {
	bundleDir, bundleErr := GetBundleDir(args)
	if bundleErr != nil {
		return nil, "", fmt.Errorf("error getting bundle directory: %v", bundleErr)
	}

	return NewSpecFromFile(GetSpecFilePath(bundleDir)), bundleDir, nil
}
// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec.
// The returned Spec can be used to read from the file, modify the spec, and
// write it back to the same file. No specification is loaded at construction.
func NewSpecFromFile(filepath string) Spec {
	return &fileSpec{path: filepath}
}
// Load reads the OCI spec from the file at s.path and stores it for later
// use. The file is opened read-only and closed before returning.
func (s *fileSpec) Load() error {
	source, openErr := os.Open(s.path)
	if openErr != nil {
		return fmt.Errorf("error opening OCI specification file: %v", openErr)
	}
	defer source.Close()

	return s.loadFrom(source)
}
// loadFrom decodes an OCI spec from the JSON provided by the specified
// io.Reader. The stored spec is only replaced if decoding succeeds.
func (s *fileSpec) loadFrom(reader io.Reader) error {
	decoded := new(oci.Spec)
	if err := json.NewDecoder(reader).Decode(decoded); err != nil {
		return fmt.Errorf("error reading OCI specification: %v", err)
	}

	s.Spec = decoded
	return nil
}
// Modify applies the specified SpecModifier to the stored OCI specification
// and returns its result. An error is returned if no specification is stored.
func (s *fileSpec) Modify(f SpecModifier) error {
	if loaded := s.Spec; loaded != nil {
		return f(loaded)
	}
	return fmt.Errorf("no spec loaded for modification")
}
// Flush writes the stored OCI specification to the filepath specified by the path member.
// The file is truncated upon opening, overwriting any existing contents.
// An error is returned if no specification is stored, if the file cannot be
// created, if writing the specification fails, or if closing the file fails.
func (s fileSpec) Flush() (err error) {
	if s.Spec == nil {
		return fmt.Errorf("no OCI specification loaded")
	}

	specFile, err := os.Create(s.path)
	if err != nil {
		return fmt.Errorf("error opening OCI specification file: %v", err)
	}
	defer func() {
		// A failure to close a file that was just written must not go
		// unnoticed. An earlier write error takes precedence.
		closeErr := specFile.Close()
		if err == nil && closeErr != nil {
			err = fmt.Errorf("error closing OCI specification file: %v", closeErr)
		}
	}()

	return s.flushTo(specFile)
}
// flushTo encodes the stored OCI specification as JSON to the specified
// io.Writer. Nothing is written when no specification is stored.
func (s fileSpec) flushTo(writer io.Writer) error {
	if s.Spec == nil {
		return nil
	}

	if err := json.NewEncoder(writer).Encode(s.Spec); err != nil {
		return fmt.Errorf("error writing OCI specification: %v", err)
	}
	return nil
}
// LookupEnv mirrors os.LookupEnv for the OCI specification. It looks up the
// environment variable named by the key in the process environment of the
// spec. If the variable is present, its value (which may be empty) is returned
// and the boolean is true. Otherwise the returned value is empty and the
// boolean is false. An entry without an '=' counts as present with an empty
// value.
func (s fileSpec) LookupEnv(key string) (string, bool) {
	if s.Spec == nil || s.Spec.Process == nil {
		return "", false
	}

	for _, entry := range s.Spec.Process.Env {
		// Split on the first '=' only; the value may itself contain '='.
		name, value := entry, ""
		if sep := strings.Index(entry, "="); sep >= 0 {
			name, value = entry[:sep], entry[sep+1:]
		}
		if name == key {
			return value, true
		}
	}

	return "", false
}

252
pkg/oci/spec_file_test.go Normal file
View File

@@ -0,0 +1,252 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"bytes"
"fmt"
"testing"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require"
)
// TestLookupEnv checks LookupEnv against missing specs and processes, names
// that merely share a prefix or suffix with the key, and the different ways a
// variable can be set.
func TestLookupEnv(t *testing.T) {
	const envName = "TEST_ENV"

	// specWithEnv builds a spec whose process has the given environment.
	specWithEnv := func(env []string) *specs.Spec {
		return &specs.Spec{Process: &specs.Process{Env: env}}
	}

	testCases := []struct {
		spec           *specs.Spec
		expectedValue  string
		expectedExists bool
	}{
		// nil spec
		{spec: nil, expectedValue: "", expectedExists: false},
		// nil process
		{spec: &specs.Spec{}, expectedValue: "", expectedExists: false},
		// nil env
		{spec: specWithEnv(nil), expectedValue: "", expectedExists: false},
		// empty env
		{spec: specWithEnv([]string{}), expectedValue: "", expectedExists: false},
		// different env set
		{spec: specWithEnv([]string{"SOMETHING_ELSE=foo"}), expectedValue: "", expectedExists: false},
		// same prefix
		{spec: specWithEnv([]string{"TEST_ENV_BUT_NOT=foo"}), expectedValue: "", expectedExists: false},
		// same suffix
		{spec: specWithEnv([]string{"NOT_TEST_ENV=foo"}), expectedValue: "", expectedExists: false},
		// set blank
		{spec: specWithEnv([]string{"TEST_ENV="}), expectedValue: "", expectedExists: true},
		// set no-equals
		{spec: specWithEnv([]string{"TEST_ENV"}), expectedValue: "", expectedExists: true},
		// set value
		{spec: specWithEnv([]string{"TEST_ENV=something"}), expectedValue: "something", expectedExists: true},
		// set with equals
		{spec: specWithEnv([]string{"TEST_ENV=something=somethingelse"}), expectedValue: "something=somethingelse", expectedExists: true},
	}

	for i, tc := range testCases {
		underTest := fileSpec{Spec: tc.spec}

		value, exists := underTest.LookupEnv(envName)

		require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc)
		require.Equal(t, tc.expectedExists, exists, "%d: %v", i, tc)
	}
}
// TestLoadFrom checks that loadFrom rejects empty input and decodes an empty
// JSON object into an empty spec.
func TestLoadFrom(t *testing.T) {
	testCases := []struct {
		contents []byte
		isError  bool
		spec     *specs.Spec
	}{
		{contents: []byte{}, isError: true},
		{contents: []byte("{}"), isError: false, spec: &specs.Spec{}},
	}

	for i, tc := range testCases {
		loaded := fileSpec{}
		err := loaded.loadFrom(bytes.NewReader(tc.contents))

		switch {
		case tc.isError:
			require.Error(t, err, "%d: %v", i, tc)
		default:
			require.NoError(t, err, "%d: %v", i, tc)
		}

		switch tc.spec {
		case nil:
			require.Nil(t, loaded.Spec, "%d: %v", i, tc)
		default:
			require.EqualValues(t, tc.spec, loaded.Spec, "%d: %v", i, tc)
		}
	}
}
// TestFlushTo checks the bytes written by flushTo for a nil and an empty spec,
// and that an error from the writer is reported.
func TestFlushTo(t *testing.T) {
	testCases := []struct {
		isError  bool
		spec     *specs.Spec
		contents string
	}{
		{spec: nil},
		{spec: &specs.Spec{}, contents: "{\"ociVersion\":\"\"}\n"},
	}

	for i, tc := range testCases {
		var buffer bytes.Buffer
		flushErr := fileSpec{Spec: tc.spec}.flushTo(&buffer)

		if !tc.isError {
			require.NoError(t, flushErr, "%d: %v", i, tc)
		} else {
			require.Error(t, flushErr, "%d: %v", i, tc)
		}
		require.EqualValues(t, tc.contents, buffer.String(), "%d: %v", i, tc)
	}

	// Add a simple test for a writer that returns an error when writing
	failing := fileSpec{Spec: &specs.Spec{}}
	require.Error(t, failing.flushTo(errorWriter{}))
}
// TestModify checks that Modify fails without a loaded spec, returns the
// error of the modifier, and otherwise applies the modification in place.
func TestModify(t *testing.T) {
	testCases := []struct {
		spec          *specs.Spec
		modifierError error
	}{
		{spec: nil},
		{spec: &specs.Spec{}},
		{spec: &specs.Spec{}, modifierError: fmt.Errorf("error in modifier")},
	}

	for i, tc := range testCases {
		underTest := fileSpec{Spec: tc.spec}

		// The modifier only touches the spec when it is not going to fail.
		err := underTest.Modify(func(target *specs.Spec) error {
			if tc.modifierError != nil {
				return tc.modifierError
			}
			target.Version = "updated"
			return tc.modifierError
		})

		switch {
		case tc.spec == nil:
			require.Error(t, err, "%d: %v", i, tc)
		case tc.modifierError != nil:
			require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
			require.EqualValues(t, &specs.Spec{}, underTest.Spec, "%d: %v", i, tc)
		default:
			require.NoError(t, err, "%d: %v", i, tc)
			require.Equal(t, "updated", underTest.Spec.Version, "%d: %v", i, tc)
		}
	}
}
// errorWriter is an io.Writer whose Write never accepts any bytes and always
// fails. It is used to exercise the error paths of code that writes.
type errorWriter struct{}

// Write reports zero bytes written together with a fixed error.
func (errorWriter) Write(_ []byte) (written int, err error) {
	err = fmt.Errorf("error writing")
	return written, err
}

201
pkg/oci/spec_mock.go Normal file
View File

@@ -0,0 +1,201 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package oci
import (
"sync"
)
// Ensure, that SpecMock does implement Spec.
// If this is not the case, regenerate this file with moq.
var _ Spec = &SpecMock{}
// SpecMock is a mock implementation of Spec.
//
//     func TestSomethingThatUsesSpec(t *testing.T) {
//
//         // make and configure a mocked Spec
//         mockedSpec := &SpecMock{
//             FlushFunc: func() error {
//                 panic("mock out the Flush method")
//             },
//             LoadFunc: func() error {
//                 panic("mock out the Load method")
//             },
//             LookupEnvFunc: func(s string) (string, bool) {
//                 panic("mock out the LookupEnv method")
//             },
//             ModifyFunc: func(specModifier SpecModifier) error {
//                 panic("mock out the Modify method")
//             },
//         }
//
//         // use mockedSpec in code that requires Spec
//         // and then make assertions.
//
//     }
//
// A method whose XxxFunc field is nil records the call and returns the zero
// values of its results.
type SpecMock struct {
// FlushFunc mocks the Flush method.
FlushFunc func() error
// LoadFunc mocks the Load method.
LoadFunc func() error
// LookupEnvFunc mocks the LookupEnv method.
LookupEnvFunc func(s string) (string, bool)
// ModifyFunc mocks the Modify method.
ModifyFunc func(specModifier SpecModifier) error
// calls tracks calls to the methods.
calls struct {
// Flush holds details about calls to the Flush method.
Flush []struct {
}
// Load holds details about calls to the Load method.
Load []struct {
}
// LookupEnv holds details about calls to the LookupEnv method.
LookupEnv []struct {
// S is the s argument value.
S string
}
// Modify holds details about calls to the Modify method.
Modify []struct {
// SpecModifier is the specModifier argument value.
SpecModifier SpecModifier
}
}
// Each lock guards the corresponding slice in calls.
lockFlush sync.RWMutex
lockLoad sync.RWMutex
lockLookupEnv sync.RWMutex
lockModify sync.RWMutex
}
// Flush calls FlushFunc.
// The call is recorded before FlushFunc is invoked. If FlushFunc is nil, a nil
// error is returned.
func (mock *SpecMock) Flush() error {
callInfo := struct {
}{}
// Record the call while holding the write lock.
mock.lockFlush.Lock()
mock.calls.Flush = append(mock.calls.Flush, callInfo)
mock.lockFlush.Unlock()
if mock.FlushFunc == nil {
// No implementation configured: return the zero value of the result.
var (
errOut error
)
return errOut
}
return mock.FlushFunc()
}
// FlushCalls gets all the calls that were made to Flush.
// Check the length with:
//     len(mockedSpec.FlushCalls())
func (mock *SpecMock) FlushCalls() []struct {
} {
var calls []struct {
}
// Read the recorded calls under the read lock.
mock.lockFlush.RLock()
calls = mock.calls.Flush
mock.lockFlush.RUnlock()
return calls
}
// Load calls LoadFunc.
// The call is recorded before LoadFunc is invoked. If LoadFunc is nil, a nil
// error is returned.
func (mock *SpecMock) Load() error {
callInfo := struct {
}{}
// Record the call while holding the write lock.
mock.lockLoad.Lock()
mock.calls.Load = append(mock.calls.Load, callInfo)
mock.lockLoad.Unlock()
if mock.LoadFunc == nil {
// No implementation configured: return the zero value of the result.
var (
errOut error
)
return errOut
}
return mock.LoadFunc()
}
// LoadCalls gets all the calls that were made to Load.
// Check the length with:
//     len(mockedSpec.LoadCalls())
func (mock *SpecMock) LoadCalls() []struct {
} {
var calls []struct {
}
// Read the recorded calls under the read lock.
mock.lockLoad.RLock()
calls = mock.calls.Load
mock.lockLoad.RUnlock()
return calls
}
// LookupEnv calls LookupEnvFunc.
// The call and its argument are recorded before LookupEnvFunc is invoked. If
// LookupEnvFunc is nil, an empty string and false are returned.
func (mock *SpecMock) LookupEnv(s string) (string, bool) {
callInfo := struct {
S string
}{
S: s,
}
// Record the call while holding the write lock.
mock.lockLookupEnv.Lock()
mock.calls.LookupEnv = append(mock.calls.LookupEnv, callInfo)
mock.lockLookupEnv.Unlock()
if mock.LookupEnvFunc == nil {
// No implementation configured: return the zero values of the results.
var (
sOut string
bOut bool
)
return sOut, bOut
}
return mock.LookupEnvFunc(s)
}
// LookupEnvCalls gets all the calls that were made to LookupEnv.
// Check the length with:
//     len(mockedSpec.LookupEnvCalls())
func (mock *SpecMock) LookupEnvCalls() []struct {
S string
} {
var calls []struct {
S string
}
// Read the recorded calls under the read lock.
mock.lockLookupEnv.RLock()
calls = mock.calls.LookupEnv
mock.lockLookupEnv.RUnlock()
return calls
}
// Modify calls ModifyFunc.
// The call and its argument are recorded before ModifyFunc is invoked. If
// ModifyFunc is nil, the supplied modifier is not invoked and a nil error is
// returned.
func (mock *SpecMock) Modify(specModifier SpecModifier) error {
callInfo := struct {
SpecModifier SpecModifier
}{
SpecModifier: specModifier,
}
// Record the call while holding the write lock.
mock.lockModify.Lock()
mock.calls.Modify = append(mock.calls.Modify, callInfo)
mock.lockModify.Unlock()
if mock.ModifyFunc == nil {
// No implementation configured: return the zero value of the result.
var (
errOut error
)
return errOut
}
return mock.ModifyFunc(specModifier)
}
// ModifyCalls gets all the calls that were made to Modify.
// Check the length with:
//     len(mockedSpec.ModifyCalls())
func (mock *SpecMock) ModifyCalls() []struct {
SpecModifier SpecModifier
} {
var calls []struct {
SpecModifier SpecModifier
}
// Read the recorded calls under the read lock.
mock.lockModify.RLock()
calls = mock.calls.Modify
mock.lockModify.RUnlock()
return calls
}

View File

@@ -0,0 +1,82 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/modify"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
log "github.com/sirupsen/logrus"
)
// modifyingRuntimeWrapper is an oci.Runtime that may modify the OCI
// specification before forwarding the command to the wrapped runtime.
type modifyingRuntimeWrapper struct {
// logger receives the informational messages emitted by Exec.
logger *log.Logger
// runtime is the wrapped runtime that Exec forwards the arguments to.
runtime oci.Runtime
// ociSpec is the specification that is loaded, modified, and flushed.
ociSpec oci.Spec
// modifier is applied to ociSpec between loading and flushing.
modifier modify.Modifier
}
// Compile-time check that *modifyingRuntimeWrapper implements oci.Runtime.
var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil)
// NewModifyingRuntimeWrapperWithLogger creates a runtime wrapper that applies
// the specified modifier to the given OCI specification before invoking the
// wrapped runtime. The supplied logger is used for the wrapper's messages.
func NewModifyingRuntimeWrapperWithLogger(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier modify.Modifier) oci.Runtime {
	return &modifyingRuntimeWrapper{
		logger:   logger,
		runtime:  runtime,
		ociSpec:  spec,
		modifier: modifier,
	}
}
// Exec modifies the OCI specification when the arguments contain a 'create'
// subcommand and then forwards the unchanged arguments to the wrapped runtime.
// If the modification fails, the error is returned and the wrapped runtime is
// not invoked.
func (r *modifyingRuntimeWrapper) Exec(args []string) error {
	if !oci.HasCreateSubcommand(args) {
		r.logger.Infof("No modification of OCI specification required")
	} else {
		if err := r.modify(); err != nil {
			return fmt.Errorf("could not apply required modification to OCI specification: %v", err)
		}
		r.logger.Infof("Applied required modification to OCI specification")
	}

	r.logger.Infof("Forwarding command to runtime")
	return r.runtime.Exec(args)
}
// modify loads the OCI specification, applies the configured Modifier to it,
// and flushes the result. It stops at, and returns, the first error.
func (r *modifyingRuntimeWrapper) modify() error {
	if err := r.ociSpec.Load(); err != nil {
		return fmt.Errorf("error loading OCI specification for modification: %v", err)
	}

	if err := r.modifier.Modify(r.ociSpec); err != nil {
		return fmt.Errorf("error modifying OCI spec: %v", err)
	}

	if err := r.ociSpec.Flush(); err != nil {
		return fmt.Errorf("error writing modified OCI specification: %v", err)
	}

	return nil
}

View File

@@ -0,0 +1,160 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"fmt"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/modify"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/oci"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestRuntimeModifier checks that the spec is loaded, modified and flushed
// exactly once when "create" is passed and not at all for empty arguments,
// and that the wrapped runtime is invoked once in both cases.
func TestRuntimeModifier(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()

	type scenario struct {
		args         []string
		shouldModify bool
	}
	scenarios := []scenario{
		{},
		{args: []string{"create"}, shouldModify: true},
	}

	for _, sc := range scenarios {
		// Fresh mocks per scenario so call counts never carry over.
		wrappedRuntime := &oci.RuntimeMock{}
		spec := &oci.SpecMock{}
		modifier := &modify.ModifierMock{}

		wrapper := NewModifyingRuntimeWrapperWithLogger(nullLogger, wrappedRuntime, spec, modifier)

		require.NoError(t, wrapper.Exec(sc.args))

		var wantCalls int
		if sc.shouldModify {
			wantCalls = 1
		}

		require.Len(t, spec.LoadCalls(), wantCalls)
		require.Len(t, modifier.ModifyCalls(), wantCalls)
		require.Len(t, spec.FlushCalls(), wantCalls)
		require.Len(t, wrappedRuntime.ExecCalls(), 1)
	}
}
// TestRuntimeModiferWithLoadError checks that a failure while loading the
// spec aborts Exec: Load is attempted once, and neither Modify, Flush nor the
// wrapped runtime is called.
func TestRuntimeModiferWithLoadError(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()

	wrappedRuntime := &oci.RuntimeMock{}
	failingSpec := &oci.SpecMock{LoadFunc: specErrorFunc}
	modifier := &modify.ModifierMock{}

	wrapper := NewModifyingRuntimeWrapperWithLogger(nullLogger, wrappedRuntime, failingSpec, modifier)

	execErr := wrapper.Exec([]string{"create"})
	require.Error(t, execErr)

	require.Len(t, failingSpec.LoadCalls(), 1)
	require.Len(t, modifier.ModifyCalls(), 0)
	require.Len(t, failingSpec.FlushCalls(), 0)
	require.Len(t, wrappedRuntime.ExecCalls(), 0)
}
// TestRuntimeModiferWithFlushError checks that a failure while flushing the
// spec aborts Exec: Load, Modify and Flush are each called once, and the
// wrapped runtime is never invoked.
func TestRuntimeModiferWithFlushError(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()

	wrappedRuntime := &oci.RuntimeMock{}
	failingSpec := &oci.SpecMock{FlushFunc: specErrorFunc}
	modifier := &modify.ModifierMock{}

	wrapper := NewModifyingRuntimeWrapperWithLogger(nullLogger, wrappedRuntime, failingSpec, modifier)

	execErr := wrapper.Exec([]string{"create"})
	require.Error(t, execErr)

	require.Len(t, failingSpec.LoadCalls(), 1)
	require.Len(t, modifier.ModifyCalls(), 1)
	require.Len(t, failingSpec.FlushCalls(), 1)
	require.Len(t, wrappedRuntime.ExecCalls(), 0)
}
// TestRuntimeModiferWithModifyError checks that a failing modifier aborts
// Exec: Load and Modify are each called once, while Flush and the wrapped
// runtime are never reached.
func TestRuntimeModiferWithModifyError(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()

	wrappedRuntime := &oci.RuntimeMock{}
	spec := &oci.SpecMock{}
	failingModifier := &modify.ModifierMock{ModifyFunc: modifierErrorFunc}

	wrapper := NewModifyingRuntimeWrapperWithLogger(nullLogger, wrappedRuntime, spec, failingModifier)

	execErr := wrapper.Exec([]string{"create"})
	require.Error(t, execErr)

	require.Len(t, spec.LoadCalls(), 1)
	require.Len(t, failingModifier.ModifyCalls(), 1)
	require.Len(t, spec.FlushCalls(), 0)
	require.Len(t, wrappedRuntime.ExecCalls(), 0)
}
// specErrorFunc always returns a non-nil error with the message "error". The
// tests in this file install it as the LoadFunc or FlushFunc of oci.SpecMock.
func specErrorFunc() error {
	alwaysFails := fmt.Errorf("error")
	return alwaysFails
}
// modifierErrorFunc ignores the spec it is given and always returns a non-nil
// error with the message "error". The tests in this file install it as the
// ModifyFunc of modify.ModifierMock.
func modifierErrorFunc(_ oci.Spec) error {
	alwaysFails := fmt.Errorf("error")
	return alwaysFails
}

202
vendor/github.com/NVIDIA/go-nvml/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

75
vendor/github.com/NVIDIA/go-nvml/pkg/dl/dl.go generated vendored Normal file
View File

@@ -0,0 +1,75 @@
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dl
import (
	"fmt"
	"unsafe"
)

// #cgo LDFLAGS: -ldl
// #include <dlfcn.h>
// #include <stdlib.h>
import "C"
// Re-exports of the RTLD_* macros from <dlfcn.h>, one Go constant per macro of
// the same name.
// NOTE(review): presumably meant to be combined into the flags value passed to
// New — confirm against callers.
const (
RTLD_LAZY = C.RTLD_LAZY
RTLD_NOW = C.RTLD_NOW
RTLD_GLOBAL = C.RTLD_GLOBAL
RTLD_LOCAL = C.RTLD_LOCAL
RTLD_NODELETE = C.RTLD_NODELETE
RTLD_NOLOAD = C.RTLD_NOLOAD
RTLD_DEEPBIND = C.RTLD_DEEPBIND
)
// DynamicLibrary describes a shared library that Open loads with dlopen, Lookup
// queries with dlsym, and Close releases with dlclose.
type DynamicLibrary struct{
// Name is the string handed to dlopen by Open.
Name string
// Flags is the flag value handed to dlopen by Open.
Flags int
// handle holds the value returned by dlopen; New leaves it nil and Open sets
// it on success.
handle unsafe.Pointer
}
// New returns a DynamicLibrary that records name and flags. Nothing is loaded
// here: the handle stays nil until Open is called.
func New(name string, flags int) *DynamicLibrary {
	lib := new(DynamicLibrary)
	lib.Name = name
	lib.Flags = flags
	// lib.handle keeps its zero value, nil.
	return lib
}
func (dl *DynamicLibrary) Open() error {
handle := C.dlopen(C.CString(dl.Name), C.int(dl.Flags))
if handle == C.NULL {
return fmt.Errorf("%s", C.GoString(C.dlerror()))
}
dl.handle = handle
return nil
}
func (dl *DynamicLibrary) Close() error {
err := C.dlclose(dl.handle)
if err != 0 {
return fmt.Errorf("%s", C.GoString(C.dlerror()))
}
return nil
}
func (dl *DynamicLibrary) Lookup(symbol string) error {
C.dlerror() // Clear out any previous errors
C.dlsym(dl.handle, C.CString(symbol))
err := C.dlerror()
if unsafe.Pointer(err) == C.NULL {
return nil
}
return fmt.Errorf("%s", C.GoString(err))
}

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nvml
import (
"unsafe"
)
import "C"
// cgoAllocsUnknown is a shared pointer to an empty struct; unpackPCharString
// returns it as its second result.
var cgoAllocsUnknown = new(struct{})
// stringHeader is the view that packPCharString and unpackPCharString overlay
// on a Go string through unsafe.Pointer.
// NOTE(review): this assumes a Go string is laid out as a data pointer followed
// by an int length — confirm for the toolchains in use.
type stringHeader struct {
// Data points at the first byte of the string.
Data unsafe.Pointer
// Len is the length of the string in bytes.
Len int
}
// clen returns the index of the first zero byte in n, or len(n) when n
// contains no zero byte (which includes a nil or empty slice).
func clen(n []byte) int {
	for idx, b := range n {
		if b == 0 {
			return idx
		}
	}
	return len(n)
}
// uint32SliceToIntSlice returns a newly allocated []int of the same length as
// s in which every element is the int conversion of the corresponding uint32.
// A nil or empty s yields an empty, non-nil slice.
func uint32SliceToIntSlice(s []uint32) []int {
	converted := make([]int, len(s))
	for idx, value := range s {
		converted[idx] = int(value)
	}
	return converted
}
// packPCharString creates a Go string backed by *C.char and avoids copying.
//
// A nil p, or a p whose first byte is NUL, yields the empty string. Otherwise
// the string's data pointer is set to p itself and its length to the number of
// bytes before the first NUL.
// NOTE(review): because the bytes are not copied, the result presumably stays
// valid only while the memory behind p does — confirm at call sites.
func packPCharString(p *C.char) (raw string) {
if p != nil && *p != 0 {
// Overlay a stringHeader on the named result so its fields can be written
// directly.
h := (*stringHeader)(unsafe.Pointer(&raw))
h.Data = unsafe.Pointer(p)
// Advance p one byte at a time until it points at the terminating NUL.
for *p != 0 {
p = (*C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + 1)) // p++
}
// p now addresses the NUL, so its distance from the start is the length.
h.Len = int(uintptr(unsafe.Pointer(p)) - uintptr(h.Data))
}
return
}
// unpackPCharString represents the data from Go string as *C.char and avoids copying.
func unpackPCharString(str string) (*C.char, *struct{}) {
h := (*stringHeader)(unsafe.Pointer(&str))
return (*C.char)(h.Data), cgoAllocsUnknown
}

View File

@@ -0,0 +1,23 @@
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// WARNING: This file has automatically been generated on Wed, 03 Feb 2021 13:08:24 UTC.
// Code generated by https://git.io/c-for-go. DO NOT EDIT.
#include "nvml.h"
#include <stdlib.h>
#pragma once
#define __CGOGEN 1

1010
vendor/github.com/NVIDIA/go-nvml/pkg/nvml/const.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

1962
vendor/github.com/NVIDIA/go-nvml/pkg/nvml/device.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More