Copy files from nvidia-container-toolkit

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar
2021-06-30 13:54:16 +02:00
parent 22fcd022f3
commit d3997eceb2
89 changed files with 8351 additions and 0 deletions

View File

@@ -0,0 +1,52 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
log "github.com/sirupsen/logrus"
)
// list represents a list of config updaters that are applied in order, with later
// configs having preference. It implements configUpdater itself, so a list can
// be used anywhere a single updater is expected.
type list struct {
// logger receives a debug message for every updater that is applied.
logger *log.Logger
// configs holds the updaters in the order in which they are applied.
configs []configUpdater
}
// Compile-time check that *list implements configUpdater.
var _ configUpdater = (*list)(nil)
// newListWithLogger builds a configUpdater that applies the supplied updaters
// in the order given, reporting each application to logger at debug level.
func newListWithLogger(logger *log.Logger, ci ...configUpdater) configUpdater {
	return &list{logger: logger, configs: ci}
}
// Update applies every updater in the list to cfg, in order. Processing stops
// at the first updater that fails; cfg then keeps whatever changes the earlier
// updaters already made. The returned error reports the index of the failing
// updater and wraps the underlying error so that callers can inspect it with
// errors.Is / errors.As.
func (c list) Update(cfg *Config) error {
	for i, u := range c.configs {
		c.logger.Debugf("Applying config %v: %v", i, u)
		if err := u.Update(cfg); err != nil {
			return fmt.Errorf("error applying config %v: %w", i, err)
		}
	}
	return nil
}

View File

@@ -0,0 +1,100 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestComposite exercises the list updater with zero, one and several
// updaters. It checks that updaters run in order, that processing stops at the
// first error, and that changes made before the error are kept.
func TestComposite(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()

	// An updater that always succeeds and leaves a recognisable value behind.
	okUpdater := &configUpdaterMock{
		UpdateFunc: func(cfg *Config) error {
			cfg.DebugFilePath = "updated"
			return nil
		},
	}
	// An updater that always fails without touching the config.
	failingUpdater := &configUpdaterMock{
		UpdateFunc: func(cfg *Config) error {
			return fmt.Errorf("mock error")
		},
	}
	// resetCalls forgets the calls recorded so far on both mocks.
	resetCalls := func() {
		okUpdater.calls = struct{ Update []struct{ Config *Config } }{}
		failingUpdater.calls = struct{ Update []struct{ Config *Config } }{}
	}

	// Empty list: nothing is applied.
	emptyCfg := &Config{}
	require.NoError(t, newListWithLogger(nullLogger).Update(emptyCfg))
	require.EqualValues(t, &Config{}, emptyCfg)

	// A single succeeding updater.
	singleCfg := &Config{}
	require.NoError(t, newListWithLogger(nullLogger, okUpdater).Update(singleCfg))
	require.EqualValues(t, &Config{DebugFilePath: "updated"}, singleCfg)
	require.Len(t, okUpdater.UpdateCalls(), 1)

	resetCalls()

	// Failing updater first: the succeeding one must never run.
	errFirstCfg := &Config{}
	require.Error(t, newListWithLogger(nullLogger, failingUpdater, okUpdater).Update(errFirstCfg))
	require.EqualValues(t, &Config{}, errFirstCfg)
	require.Len(t, failingUpdater.UpdateCalls(), 1)
	require.Len(t, okUpdater.UpdateCalls(), 0)

	resetCalls()

	// Failing updater last: the earlier update is kept, the error is reported.
	errLastCfg := &Config{}
	require.Error(t, newListWithLogger(nullLogger, okUpdater, failingUpdater).Update(errLastCfg))
	require.EqualValues(t, &Config{DebugFilePath: "updated"}, errLastCfg)
	require.Len(t, failingUpdater.UpdateCalls(), 1)
	require.Len(t, okUpdater.UpdateCalls(), 1)

	resetCalls()

	// The same updater listed twice is called twice.
	twiceCfg := &Config{}
	require.NoError(t, newListWithLogger(nullLogger, okUpdater, okUpdater).Update(twiceCfg))
	require.EqualValues(t, &Config{DebugFilePath: "updated"}, twiceCfg)
	require.Len(t, okUpdater.UpdateCalls(), 2)
}

View File

@@ -0,0 +1,63 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
log "github.com/sirupsen/logrus"
)
// Config defines the configuration options for the NVIDIA Container Runtime.
// See GetConfig for how the values are resolved.
type Config struct {
// Root defines the root of the file system to be used for locating mounts.
// It can be overridden with the NVIDIA_CONTAINER_RUNTIME_ROOT environment variable.
Root string
// DebugFilePath defines a log file to print debug output to.
// It can be overridden with the NVIDIA_CONTAINER_RUNTIME_DEBUG environment variable.
DebugFilePath string
// RuntimePath defines the path to an OCI compliant runtime.
// It can be overridden with the NVIDIA_CONTAINER_RUNTIME_PATH environment variable.
RuntimePath string
// LogLevel defines the logging level for the application.
// It can be overridden with the NVIDIA_CONTAINER_RUNTIME_LOG_LEVEL environment variable.
LogLevel string
}
//go:generate moq -stub -out config_mock.go . configUpdater
// configUpdater represents an interface for applying updates to a config.
// Update modifies the supplied Config in place and returns a non-nil error
// if the update could not be applied.
type configUpdater interface {
Update(*Config) error
}
// GetConfig returns a config struct with the values resolved. The values are defined in order of
// priority (highest first):
//  1. From the associated environment variables
//  2. From the loaded config file
//  3. From the default values applied by the updater returned by `newDefaultConfig`
//
// The updaters are applied lowest priority first so that later ones overwrite
// the values set by earlier ones. logger is passed on to the updaters that
// log. On failure a nil Config is returned together with an error that wraps
// the underlying cause.
func GetConfig(logger *log.Logger) (*Config, error) {
	cfg := &Config{}

	configs := newListWithLogger(logger,
		newDefaultConfig(),
		newDefaultConfigFileWithLogger(logger),
		newConfigFromEnvironment(),
	)

	if err := configs.Update(cfg); err != nil {
		return nil, fmt.Errorf("error getting config: %w", err)
	}

	return cfg, nil
}

View File

@@ -0,0 +1,76 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package config
import (
"sync"
)
// Ensure that configUpdaterMock implements configUpdater.
// If this is not the case, regenerate this file with moq.
var _ configUpdater = &configUpdaterMock{}
// configUpdaterMock is a mock implementation of configUpdater.
//
//	func TestSomethingThatUsesconfigUpdater(t *testing.T) {
//
//		// make and configure a mocked configUpdater
//		mockedconfigUpdater := &configUpdaterMock{
//			UpdateFunc: func(config *Config) error {
//				panic("mock out the Update method")
//			},
//		}
//
//		// use mockedconfigUpdater in code that requires configUpdater
//		// and then make assertions.
//
//	}
type configUpdaterMock struct {
// UpdateFunc mocks the Update method.
UpdateFunc func(config *Config) error
// calls tracks calls to the methods.
calls struct {
// Update holds details about calls to the Update method.
Update []struct {
// Config is the config argument value.
Config *Config
}
}
// lockUpdate guards calls.Update.
lockUpdate sync.RWMutex
}
// Update records the call and then calls UpdateFunc.
// If UpdateFunc is nil, the zero error value (nil) is returned instead.
func (mock *configUpdaterMock) Update(config *Config) error {
callInfo := struct {
Config *Config
}{
Config: config,
}
// Only the bookkeeping is done under the lock; UpdateFunc runs outside it.
mock.lockUpdate.Lock()
mock.calls.Update = append(mock.calls.Update, callInfo)
mock.lockUpdate.Unlock()
if mock.UpdateFunc == nil {
var (
errOut error
)
return errOut
}
return mock.UpdateFunc(config)
}
// UpdateCalls gets all the calls that were made to Update.
// Check the length with:
//
//	len(mockedconfigUpdater.UpdateCalls())
//
// The returned slice is the recorded slice itself, not a copy.
func (mock *configUpdaterMock) UpdateCalls() []struct {
Config *Config
} {
var calls []struct {
Config *Config
}
mock.lockUpdate.RLock()
calls = mock.calls.Update
mock.lockUpdate.RUnlock()
return calls
}

View File

@@ -0,0 +1,58 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"io/ioutil"
"os"
"path"
"path/filepath"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestGetConfigWithCustomConfig(t *testing.T) {
wd, err := os.Getwd()
require.NoError(t, err)
// By default debug is disabled
contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"")
testDir := path.Join(wd, "test")
filename := path.Join(testDir, configFileRelativePath)
previousConfig, present := os.LookupEnv(configOverride)
os.Setenv(configOverride, testDir)
defer func() {
if present {
os.Setenv(configOverride, previousConfig)
} else {
os.Unsetenv(configOverride)
}
}()
require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766))
require.NoError(t, ioutil.WriteFile(filename, contents, 0766))
defer func() { require.NoError(t, os.RemoveAll(testDir)) }()
logger, _ := testlog.NewNullLogger()
cfg, err := GetConfig(logger)
require.NoError(t, err)
require.Equal(t, "/nvidia-container-toolkit.log", cfg.DebugFilePath)
}

View File

@@ -0,0 +1,47 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
log "github.com/sirupsen/logrus"
)
// defaultConfig is a configUpdater that sets a Config to the built-in
// default values.
type defaultConfig struct{}
// Compile-time check that *defaultConfig implements configUpdater.
var _ configUpdater = (*defaultConfig)(nil)
// newDefaultConfig returns the updater that applies the default config values.
func newDefaultConfig() configUpdater {
	return &defaultConfig{}
}
// Update overwrites cfg with the default values for the config options.
// Fields that have no default are reset to their zero value. It always
// returns nil.
func (c defaultConfig) Update(cfg *Config) error {
	defaults := Config{
		LogLevel:      log.InfoLevel.String(),
		DebugFilePath: "/dev/null",
	}
	*cfg = defaults
	return nil
}
// getDefaultConfig returns a freshly allocated Config populated with the
// default values.
func getDefaultConfig() *Config {
	var defaults Config
	// The default updater always returns nil, so the result is discarded.
	_ = defaultConfig{}.Update(&defaults)
	return &defaults
}

View File

@@ -0,0 +1,37 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestDefaultConfigUpdate checks that applying the default updater to an
// empty config sets the debug file path and leaves the root empty.
func TestDefaultConfigUpdate(t *testing.T) {
	cfg := new(Config)
	require.Empty(t, cfg.DebugFilePath)

	updater := defaultConfig{}
	require.NoError(t, updater.Update(cfg))

	require.Equal(t, "/dev/null", cfg.DebugFilePath)
	require.Equal(t, "", cfg.Root)
}

View File

@@ -0,0 +1,61 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"os"
"strings"
)
// Names of the environment variables that override individual Config fields.
const (
// debugFilePathEnvvarName overrides Config.DebugFilePath.
debugFilePathEnvvarName = "NVIDIA_CONTAINER_RUNTIME_DEBUG"
// runtimePathEnvvarName overrides Config.RuntimePath.
runtimePathEnvvarName = "NVIDIA_CONTAINER_RUNTIME_PATH"
// rootEnvvarName overrides Config.Root.
rootEnvvarName = "NVIDIA_CONTAINER_RUNTIME_ROOT"
// logLevelEnvvarName overrides Config.LogLevel.
logLevelEnvvarName = "NVIDIA_CONTAINER_RUNTIME_LOG_LEVEL"
)
// envConfig is a configUpdater that reads overrides from environment variables.
type envConfig struct{}

// newConfigFromEnvironment returns the updater that applies environment
// variable overrides to a config.
func newConfigFromEnvironment() configUpdater {
	return &envConfig{}
}
// Update copies the value of each supported environment variable into the
// matching field of cfg. A variable that is unset, empty or contains only
// whitespace leaves the field untouched; otherwise the value is stored exactly
// as found (it is not trimmed). It always returns nil.
func (c envConfig) Update(cfg *Config) error {
	overrides := []struct {
		envvar string
		field  *string
	}{
		{debugFilePathEnvvarName, &cfg.DebugFilePath},
		{runtimePathEnvvarName, &cfg.RuntimePath},
		{rootEnvvarName, &cfg.Root},
		{logLevelEnvvarName, &cfg.LogLevel},
	}

	for _, o := range overrides {
		value, isSet := os.LookupEnv(o.envvar)
		if !isSet || strings.TrimSpace(value) == "" {
			continue
		}
		*o.field = value
	}

	return nil
}

View File

@@ -0,0 +1,29 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
// noop is a configUpdater whose Update leaves the config unchanged.
type noop struct{}

// newNoopConfigUpdater returns an updater that does nothing.
func newNoopConfigUpdater() configUpdater {
	return &noop{}
}
// Update leaves the supplied config untouched and always returns nil.
func (noop) Update(*Config) error {
	return nil
}

View File

@@ -0,0 +1,36 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestNoopDoesNotModifyConfig checks that the no-op updater succeeds and
// leaves an existing value in the config as it was.
func TestNoopDoesNotModifyConfig(t *testing.T) {
	const debugPath = "test-path"
	cfg := &Config{DebugFilePath: debugPath}

	require.NoError(t, newNoopConfigUpdater().Update(cfg))

	require.Equal(t, debugPath, cfg.DebugFilePath)
}

View File

@@ -0,0 +1,158 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/pelletier/go-toml"
log "github.com/sirupsen/logrus"
)
const (
// configFileRelativePath is the location of the config file relative to the
// config root directory.
configFileRelativePath = "nvidia-container-runtime/config.toml"
// configOverride names the environment variable that, when non-empty,
// replaces defaultConfigRoot as the config root directory.
configOverride = "XDG_CONFIG_HOME"
// defaultConfigRoot is the config root directory used when configOverride
// is not set.
defaultConfigRoot = "/etc"
// Names of the TOML sections from which config values are read.
nvidiaContainerCliSection = "nvidia-container-cli"
nvidiaContainerRuntimeConfigSection = "nvidia-container-runtime"
nvidiaContainerRuntimeExperimentalConfigSection = "nvidia-container-runtime.experimental"
)
// tomlConfig is a configUpdater that reads config values from a TOML file.
type tomlConfig struct {
logger *log.Logger
// path is the location of the TOML file that is opened by Update.
path string
// sections lists the TOML sections to read, in order; a value found in a
// later section overwrites the one found in an earlier section.
sections []tomlSection
}
// tomlSection identifies a section of a TOML document and, optionally, the
// subset of keys that may be read from it.
type tomlSection struct {
// section is the dotted name of the section; if empty, keys are looked up
// at the top level of the document.
section string
// keys is the set of keys that may be read from the section. A nil map
// means that every key is allowed.
keys map[string]struct{}
}
// newDefaultConfigFileWithLogger returns an updater for the config file at its
// default location: configFileRelativePath below the directory named by the
// configOverride environment variable or, if that is empty or unset, below
// defaultConfigRoot.
func newDefaultConfigFileWithLogger(logger *log.Logger) configUpdater {
	root := os.Getenv(configOverride)
	if root == "" {
		root = defaultConfigRoot
	}

	return newConfigFromFileWithLogger(logger, filepath.Join(root, configFileRelativePath))
}
// newConfigFromFileWithLogger returns an updater that reads config values from
// the TOML file at configFilePath. If the file does not exist, a warning is
// logged and a no-op updater is returned instead. Any other problem with the
// file is not detected here; it surfaces when the returned updater's Update
// opens the file.
//
// The sections are read in the order runtime, experimental runtime, CLI, so
// values from later sections overwrite earlier ones. Only the "root" key is
// read from the CLI section.
func newConfigFromFileWithLogger(logger *log.Logger, configFilePath string) configUpdater {
	// The parameter is deliberately not called "filepath": that would shadow
	// the imported path/filepath package inside this function.
	if _, err := os.Stat(configFilePath); os.IsNotExist(err) {
		logger.Warnf("The config file '%v' does not exist", configFilePath)
		return newNoopConfigUpdater()
	}

	return &tomlConfig{
		logger: logger,
		path:   configFilePath,
		sections: []tomlSection{
			{
				section: nvidiaContainerRuntimeConfigSection,
			},
			{
				section: nvidiaContainerRuntimeExperimentalConfigSection,
			},
			{
				section: nvidiaContainerCliSection,
				keys: map[string]struct{}{
					"root": {},
				},
			},
		},
	}
}
// Update opens the TOML file at c.path and applies the values it contains to
// cfg. An error is returned if the file cannot be opened or parsed; the error
// for a failed open wraps the underlying cause.
func (c tomlConfig) Update(cfg *Config) error {
	configFile, err := os.Open(c.path)
	if err != nil {
		return fmt.Errorf("error opening config file %v: %w", c.path, err)
	}
	defer configFile.Close()

	return c.updateFromReader(cfg, configFile)
}
// updateFromReader parses TOML from reader and copies the recognised string
// values ("debug", "runtime-path", "root" and "log-level") from each of the
// configured sections into cfg. Sections are processed in order, so a value in
// a later section overwrites one from an earlier section. Keys that are
// missing, not strings, or not allowed for a section leave cfg untouched.
// If the contents cannot be read or parsed, cfg is not modified and an error
// wrapping the cause is returned.
func (c tomlConfig) updateFromReader(cfg *Config, reader io.Reader) error {
	// Named tree rather than toml to avoid shadowing the toml package.
	tree, err := toml.LoadReader(reader)
	if err != nil {
		return fmt.Errorf("error reading TOML contents: %w", err)
	}

	for _, section := range c.sections {
		if v, ok := section.GetStringFrom(tree, "debug"); ok {
			cfg.DebugFilePath = v
		}
		if v, ok := section.GetStringFrom(tree, "runtime-path"); ok {
			cfg.RuntimePath = v
		}
		if v, ok := section.GetStringFrom(tree, "root"); ok {
			cfg.Root = v
		}
		// The log level belongs in LogLevel; it was previously written to
		// Root, clobbering the root and never setting the level.
		if v, ok := section.GetStringFrom(tree, "log-level"); ok {
			cfg.LogLevel = v
		}
	}

	return nil
}
// GetStringFrom looks up key in this section of the TOML tree and returns its
// value if it is a string. The boolean is false when the key is absent, not
// allowed for this section, or holds a value of another type.
func (c tomlSection) GetStringFrom(toml *toml.Tree, key string) (string, bool) {
	// A type assertion on a nil interface value fails, which also covers the
	// case where the key was not found.
	str, isString := c.GetFrom(toml, key).(string)
	if !isString {
		return "", false
	}
	return str, true
}
// GetFrom returns the raw value stored under key in this section of the TOML
// tree. It returns nil if the key is not allowed for this section; otherwise
// it returns whatever the tree's Get yields for the section-qualified key.
func (c tomlSection) GetFrom(toml *toml.Tree, key string) interface{} {
	if c.validKey(key) {
		return toml.Get(c.configKey(key))
	}
	return nil
}
// validKey reports whether key may be read from this section. A section
// without an explicit key set (nil keys) accepts every key.
func (c tomlSection) validKey(key string) bool {
	if c.keys != nil {
		_, allowed := c.keys[key]
		return allowed
	}
	return true
}
// configKey returns the key qualified with the section name as
// "<section>.<key>", or the bare key when the section name is empty.
func (c tomlSection) configKey(key string) string {
	switch c.section {
	case "":
		return key
	default:
		return c.section + "." + key
	}
}

View File

@@ -0,0 +1,249 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package config
import (
"fmt"
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
"testing"
"testing/iotest"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestUpdateFromReader checks how values from the nvidia-container-runtime
// section of a TOML document are applied on top of the default config,
// including the behaviour when the reader fails.
func TestUpdateFromReader(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()

	// withDebug returns the expected config for a given debug file path; the
	// log level always stays at the default "info" in these cases.
	withDebug := func(debugFilePath string) *Config {
		return &Config{
			DebugFilePath: debugFilePath,
			LogLevel:      "info",
		}
	}

	testCases := []struct {
		description   string
		readerError   bool
		lines         []string
		expected      *Config
		expectedError bool
	}{
		{
			description:   "reader error returns error",
			readerError:   true,
			expectedError: true,
			expected:      getDefaultConfig(),
		},
		{
			description: "empty config returns defaults",
			lines:       []string{},
			expected:    withDebug("/dev/null"),
		},
		{
			description: "debug output is set",
			lines:       []string{"nvidia-container-runtime.debug=\"nvidia-container-toolkit.log\""},
			expected:    withDebug("nvidia-container-toolkit.log"),
		},
		{
			description: "debug output is set in section",
			lines:       []string{"[nvidia-container-runtime]", "debug=\"nvidia-container-toolkit.log\""},
			expected:    withDebug("nvidia-container-toolkit.log"),
		},
		{
			description: "blank debug is set",
			lines:       []string{"nvidia-container-runtime.debug=\"\""},
			expected:    withDebug(""),
		},
		{
			description: "non-string debug is ignored",
			lines:       []string{"nvidia-container-runtime.debug=2"},
			expected:    withDebug("/dev/null"),
		},
	}

	for i, tc := range testCases {
		cfg := getDefaultConfig()

		updater := tomlConfig{
			logger: nullLogger,
			sections: []tomlSection{
				{section: nvidiaContainerRuntimeConfigSection},
			},
		}

		// Feed either the joined lines or a reader that always fails.
		var reader io.Reader = strings.NewReader(strings.Join(tc.lines, "\n"))
		if tc.readerError {
			reader = iotest.ErrReader(fmt.Errorf("error"))
		}

		err := updater.updateFromReader(cfg, reader)
		if tc.expectedError {
			require.Error(t, err, "%d: %v", i, tc.description)
		} else {
			require.NoError(t, err, "%d: %v", i, tc.description)
		}

		require.EqualValues(t, tc.expected, cfg, "%d: %v", i, tc.description)
	}
}
// TestUpdateFromReaderExperimental checks that an updater restricted to the
// experimental runtime section only picks up values from that section and
// ignores values from the plain nvidia-container-runtime section.
func TestUpdateFromReaderExperimental(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()

	// withDebug returns the expected config for a given debug file path; the
	// log level always stays at the default "info" in these cases.
	withDebug := func(debugFilePath string) *Config {
		return &Config{
			DebugFilePath: debugFilePath,
			LogLevel:      "info",
		}
	}

	testCases := []struct {
		readerError   bool
		lines         []string
		expected      *Config
		expectedError bool
	}{
		{
			lines:    []string{"nvidia-container-runtime.debug=\"nvidia-container-toolkit.log\""},
			expected: getDefaultConfig(),
		},
		{
			lines:    []string{"[nvidia-container-runtime]", "debug=\"nvidia-container-toolkit.log\""},
			expected: getDefaultConfig(),
		},
		{
			lines:    []string{"nvidia-container-runtime.experimental.debug=\"\""},
			expected: withDebug(""),
		},
		{
			lines:    []string{"nvidia-container-runtime.experimental.debug=2"},
			expected: getDefaultConfig(),
		},
		{
			lines:    []string{"nvidia-container-runtime.experimental.debug=\"nvidia-container-toolkit.log\""},
			expected: withDebug("nvidia-container-toolkit.log"),
		},
		{
			lines:    []string{"[nvidia-container-runtime.experimental]", "debug=\"nvidia-container-toolkit.log\""},
			expected: withDebug("nvidia-container-toolkit.log"),
		},
		{
			lines: []string{
				"nvidia-container-runtime.debug=\"nvidia-container-toolkit.log\"",
				"nvidia-container-runtime.experimental.debug=\"nvidia-container-exp-toolkit.log\"",
			},
			expected: withDebug("nvidia-container-exp-toolkit.log"),
		},
	}

	for i, tc := range testCases {
		cfg := getDefaultConfig()

		updater := tomlConfig{
			logger: nullLogger,
			sections: []tomlSection{
				{section: nvidiaContainerRuntimeExperimentalConfigSection},
			},
		}

		// Feed either the joined lines or a reader that always fails.
		var reader io.Reader = strings.NewReader(strings.Join(tc.lines, "\n"))
		if tc.readerError {
			reader = iotest.ErrReader(fmt.Errorf("error"))
		}

		err := updater.updateFromReader(cfg, reader)
		if tc.expectedError {
			require.Error(t, err, "%d: %v", i, tc)
		} else {
			require.NoError(t, err, "%d: %v", i, tc)
		}

		require.EqualValues(t, tc.expected, cfg, "%d: %v", i, tc)
	}
}
func TestConfigFromFile(t *testing.T) {
wd, err := os.Getwd()
require.NoError(t, err)
// By default debug is disabled
lines := []string{
"[nvidia-container-cli]",
"root = \"/run/nvidia/driver\"",
"[nvidia-container-runtime]",
"#debug = \"/nvidia-container-toolkit.log\"",
"",
"[nvidia-container-runtime.experimental]",
"debug = \"/nvidia-container-toolkit.experimental.log\"",
}
contents := []byte(strings.Join(lines, "\n"))
testDir := path.Join(wd, "test")
filename := path.Join(testDir, configFileRelativePath)
require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766))
require.NoError(t, ioutil.WriteFile(filename, contents, 0766))
defer func() { require.NoError(t, os.RemoveAll(testDir)) }()
logger, _ := testlog.NewNullLogger()
c := newConfigFromFileWithLogger(logger, filename)
cfg := getDefaultConfig()
err = c.Update(cfg)
require.NoError(t, err)
require.Equal(t, "/nvidia-container-toolkit.experimental.log", cfg.DebugFilePath)
require.Equal(t, "/run/nvidia/driver", cfg.Root)
}
// TestConfigFromNonexistentFileReturnsNoop checks that asking for an updater
// for a missing config file yields the no-op updater.
func TestConfigFromNonexistentFileReturnsNoop(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()

	updater := newConfigFromFileWithLogger(nullLogger, "/does/not/exist")

	asNoop, isNoop := updater.(*noop)
	require.True(t, isNoop)
	require.NotNil(t, asNoop)
}
func TestGetConfigKey(t *testing.T) {
require.Equal(t, "key", tomlSection{}.configKey("key"))
require.Equal(t, "section.key", tomlSection{section: "section"}.configKey("key"))
}

View File

@@ -0,0 +1,144 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package main
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/ensure"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/filter"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/modify"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/runtime"
"gitlab.com/nvidia/cloud-native/container-toolkit/cmd/nvidia-container-runtime.experimental/config"
)
const (
// visibleDevicesEnvvar is the environment variable in the OCI spec that
// holds the requested devices.
visibleDevicesEnvvar = "NVIDIA_VISIBLE_DEVICES"
// visibleDevicesVoid is the value of visibleDevicesEnvvar for which the
// command is forwarded to the low-level runtime without modification.
visibleDevicesVoid = "void"
)
// logger is the package-level logger; its output and level are configured in main.
var logger = log.New()
// main loads the runtime configuration, configures the package-level logger
// (output file and level) from it and then hands over to run. Any failure is
// logged and the process exits with status 1.
func main() {
	cfg, err := config.GetConfig(logger)
	if err != nil {
		logger.Errorf("Error loading config: %v", err)
		os.Exit(1)
	}

	// Only redirect the log output when a real debug file was requested. An
	// empty path or /dev/null (the default debug file path) leaves the
	// logger's output unchanged.
	if cfg.DebugFilePath != "" && cfg.DebugFilePath != "/dev/null" {
		logFile, err := os.OpenFile(cfg.DebugFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
		if err != nil {
			logger.Errorf("Error opening debug log file: %v", err)
			os.Exit(1)
		}
		// NOTE: deferred calls do not run on the os.Exit path below; the file
		// is then closed by the operating system when the process exits.
		defer logFile.Close()

		logger.SetOutput(logFile)
	}

	// An unparsable log level is reported but not fatal: the logger keeps its
	// current level.
	logLevel, err := log.ParseLevel(cfg.LogLevel)
	if err == nil {
		logger.SetLevel(logLevel)
	} else {
		logger.Warnf("Invalid log-level '%v'; using '%v'", cfg.LogLevel, logger.Level.String())
	}

	logger.Infof("Starting nvidia-container-runtime: %v", os.Args)
	logger.Debugf("Using config=%+v", cfg)

	if err := run(cfg, os.Args); err != nil {
		logger.Errorf("Error running runtime: %v", err)
		os.Exit(1)
	}
}
// run executes the runtime for the given command line arguments.
//
// The command is forwarded unmodified to the low-level runtime when it has no
// create subcommand, or when the NVIDIA_VISIBLE_DEVICES environment variable
// in the OCI spec is missing, empty or "void". Otherwise the low-level runtime
// is wrapped so that the OCI spec is modified for the requested devices.
//
// args is used for every invocation of the runtime, so the function no longer
// depends on the global os.Args. Returned errors wrap the underlying cause.
func run(cfg *config.Config, args []string) error {
	logger.Debugf("running with args=%v", args)

	// We create a low-level runtime
	lowLevelRuntime, err := createLowLevelRuntime(logger, cfg)
	if err != nil {
		return fmt.Errorf("error constructing low-level runtime: %w", err)
	}

	if !oci.HasCreateSubcommand(args) {
		logger.Infof("No modification of OCI specification required")
		logger.Infof("Forwarding command to runtime")
		return lowLevelRuntime.Exec(args)
	}

	// We create the OCI spec that is to be modified
	ociSpec, bundleDir, err := oci.NewSpecFromArgs(args)
	if err != nil {
		return fmt.Errorf("error constructing OCI spec: %w", err)
	}

	if err := ociSpec.Load(); err != nil {
		return fmt.Errorf("error loading OCI specification: %w", err)
	}

	visibleDevices, exists := ociSpec.LookupEnv(visibleDevicesEnvvar)
	if !exists || visibleDevices == "" || visibleDevices == visibleDevicesVoid {
		logger.Infof("Using low-level runtime: %v=%v (exists=%v)", visibleDevicesEnvvar, visibleDevices, exists)
		return lowLevelRuntime.Exec(args)
	}

	// We create the modifier that will be applied by the Modifying Runtime Wrapper
	modifier, err := createModifier(cfg.Root, bundleDir, visibleDevices, ociSpec)
	if err != nil {
		return fmt.Errorf("error constructing modifer: %w", err)
	}

	// We construct the Modifying runtime
	r := runtime.NewModifyingRuntimeWrapperWithLogger(logger, lowLevelRuntime, ociSpec, modifier)

	return r.Exec(args)
}
// createLowLevelRuntime constructs the OCI runtime that commands are forwarded
// to. If the config specifies a runtime path, that runtime is used; otherwise
// the runtime is constructed from the candidates "docker-runc" and "runc".
func createLowLevelRuntime(logger *log.Logger, cfg *config.Config) (oci.Runtime, error) {
	if runtimePath := cfg.RuntimePath; runtimePath != "" {
		logger.Infof("Creating runtime with path %v", runtimePath)
		return oci.NewRuntimeForPathWithLogger(logger, runtimePath)
	}

	return oci.NewLowLevelRuntimeWithLogger(logger, "docker-runc", "runc")
}
// createModifier builds the OCI spec modifier by chaining the stages
// discover -> select (filter by visibleDevices) -> ensure -> modify.
// root is passed to the discovery, ensure and modify stages; bundleDir is
// passed to the modify stage; env is passed to the select stage. An error
// wrapping the cause is returned if the discovery stage cannot be constructed.
func createModifier(root string, bundleDir string, visibleDevices string, env filter.EnvLookup) (modify.Modifier, error) {
	// We set up the modifier using discovery
	discovered, err := discover.NewNVMLServerWithLogger(logger, root)
	if err != nil {
		return nil, fmt.Errorf("error discovering devices: %w", err)
	}

	// We apply a filter to the discovered devices
	selected := filter.NewSelectDevicesFromWithLogger(logger, discovered, visibleDevices, env)

	// We ensure that the selected devices are available
	available := ensure.NewEnsureDevicesWithLogger(logger, selected, root)

	// We construct the modifier for the OCI spec
	modifier := modify.NewModifierWithLoggerFor(logger, available, root, bundleDir)

	return modifier, nil
}

18
docker/Dockerfile.devel Normal file
View File

@@ -0,0 +1,18 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# GOLANG_VERSION selects the tag of the golang base image. The default "x.x.x"
# is a placeholder; presumably the real version is supplied with --build-arg
# at build time (TODO confirm against the build scripts).
ARG GOLANG_VERSION=x.x.x
FROM golang:${GOLANG_VERSION}
# Development tools: golint for linting and moq for generating interface mocks.
RUN go get -u golang.org/x/lint/golint
RUN go get -u github.com/matryer/moq

52
examples/discover/main.go Normal file
View File

@@ -0,0 +1,52 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package main
import (
"encoding/json"
"os"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// main discovers the devices and mounts using the NVML-based discoverer and
// prints them to stdout as indented JSON. Any error is logged and ends the
// program early.
func main() {
	log.Infof("Starting device discovery with NVML")

	d, err := discover.NewNVMLServer("")
	if err != nil {
		// Without this check the error was overwritten by the next call and
		// d was used without knowing whether construction succeeded.
		log.Errorf("Error creating NVML discoverer: %v", err)
		return
	}

	devices, err := d.Devices()
	if err != nil {
		log.Errorf("Error discovering devices: %v", err)
		return
	}

	mounts, err := d.Mounts()
	if err != nil {
		log.Errorf("Error discovering mounts: %v", err)
		return
	}

	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", " ")

	log.Infof("Discovered devices:")
	if err := enc.Encode(devices); err != nil {
		log.Errorf("Error encoding devices: %v", err)
		return
	}

	log.Infof("Discovered libraries:")
	if err := enc.Encode(mounts); err != nil {
		log.Errorf("Error encoding mounts: %v", err)
		return
	}
}

65
examples/filter/main.go Normal file
View File

@@ -0,0 +1,65 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package main
import (
"encoding/json"
"os"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/filter"
)
// main builds an NVML-backed discoverer, wraps it in a filter that selects
// "all" devices, and prints the resulting devices, mounts and hooks to stdout
// as indented JSON. Any failure is logged and terminates the example early.
func main() {
	d, err := discover.NewNVMLServer("")
	if err != nil {
		log.Errorf("Error discovering devices: %v", err)
		// Stop here: d is not usable when construction failed.
		return
	}

	selected := filter.NewSelectDevicesFrom(d, "all", nil)

	devices, err := selected.Devices()
	if err != nil {
		log.Errorf("Error discovering devices: %v", err)
		return
	}

	mounts, err := selected.Mounts()
	if err != nil {
		log.Errorf("Error discovering mounts: %v", err)
		return
	}

	hooks, err := selected.Hooks()
	if err != nil {
		log.Errorf("Error discovering hooks: %v", err)
		return
	}

	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", " ")

	log.Infof("Discovered devices:")
	if err := enc.Encode(devices); err != nil {
		log.Errorf("Error encoding devices: %v", err)
		return
	}

	log.Infof("Discovered libraries:")
	if err := enc.Encode(mounts); err != nil {
		log.Errorf("Error encoding mounts: %v", err)
		return
	}

	log.Infof("Discovered hook:")
	if err := enc.Encode(hooks); err != nil {
		log.Errorf("Error encoding hooks: %v", err)
	}
}

View File

@@ -0,0 +1,26 @@
package main
import (
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/ldcache"
)
// logger is the logrus standard logger shared by this example; main raises it
// to debug level and passes it to the ldcache package.
var logger = log.StandardLogger()
// main opens the ld.so.cache found under the driver root, looks up every
// entry whose name starts with "lib", and logs the 32-bit and 64-bit results.
func main() {
	logger.SetLevel(log.DebugLevel)
	logger.Infof("Starting device discovery with NVML")

	const driverRoot = "/run/nvidia/driver"

	cache, err := ldcache.NewLDCacheWithLogger(logger, driverRoot)
	if err != nil {
		logger.Errorf("Error loading ldcache: %v", err)
		return
	}
	// Release the mapping once the lookup results have been logged.
	defer cache.Close()

	found32, found64 := cache.Lookup("lib")
	logger.Infof("32-bit: %v", found32)
	logger.Infof("64-bit: %v", found64)
}

239
internal/ldcache/ldcache.go Normal file
View File

@@ -0,0 +1,239 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
// Adapted from https://github.com/rai-project/ldcache
package ldcache
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
)
// ldcachePath is the path of the dynamic linker cache; openWithRoot joins it
// onto the root it is given.
const ldcachePath = "/etc/ld.so.cache"
// Magic values that parse uses to recognise the cache layout.
const (
// magicString1 is the prefix identifying the old cache format, whose
// header and entries are skipped.
magicString1 = "ld.so-1.7.0"
// magicString2 and magicVersion must equal Header2.Magic and
// Header2.Version for the cache to be accepted.
magicString2 = "glibc-ld.so.cache"
magicVersion = "1.1"
)
// Masks and values applied to Entry2.Flags by Lookup.
const (
// flagTypeMask selects the entry-type bits; Lookup skips entries that do
// not have flagTypeELF set.
flagTypeMask = 0x00ff
flagTypeELF = 0x0001
// flagArchMask selects the architecture bits. Lookup reports X8664 and
// Ppc64le entries as 64-bit, I386 and X32 entries as 32-bit, and ignores
// every other architecture.
flagArchMask = 0xff00
flagArchI386 = 0x0000
flagArchX8664 = 0x0300
flagArchX32 = 0x0800
flagArchPpc64le = 0x0500
)
// ErrInvalidCache is returned by parse when the file is too short or its
// magic string or version does not match the expected values.
var ErrInvalidCache = errors.New("invalid ld.so.cache file")
// Header1 is the header of the old cache format (< glibc-2.2). It is decoded
// with encoding/binary, so field order and sizes define the on-disk layout.
type Header1 struct {
Magic [len(magicString1) + 1]byte // include null delimiter
// NLibs is the number of Entry1 records that follow the header.
NLibs uint32
}
// Entry1 is one record of the old cache format. parse only uses its size to
// skip over the old entries; the fields themselves are never read.
type Entry1 struct {
Flags int32
Key, Value uint32
}
// Header2 is the header of the new cache format. It is decoded with
// encoding/binary, so field order and sizes define the on-disk layout.
type Header2 struct {
// Magic and Version are compared against magicString2 and magicVersion.
Magic [len(magicString2)]byte
Version [len(magicVersion)]byte
// NLibs is the number of Entry2 records that follow the header.
NLibs uint32
TableSize uint32
_ [3]uint32 // unused
_ uint64 // force 8 byte alignment
}
// Entry2 is one record of the new cache format.
type Entry2 struct {
// Flags carries the entry type and architecture bits (see the flag*
// constants).
Flags int32
// Key and Value are byte offsets into LDCache.libs: Key locates the
// library name that Lookup prefix-matches, Value the NUL-terminated path.
Key, Value uint32
OSVersion uint32
HWCap uint64
}
// LDCache is a parsed, memory-mapped view of an ld.so.cache file.
type LDCache struct {
// Reader reads from data; parse consumes it to decode the header and
// the entries.
*bytes.Reader
// data is the mmap'd file content, released by Close. libs is the tail
// of data against which Entry2.Key and Entry2.Value are resolved.
data, libs []byte
header Header2
entries []Entry2
// root is prepended to every library path that Lookup resolves.
root string
logger *log.Logger
}
// NewLDCacheWithLogger opens and parses the ld.so.cache located under root,
// logging through logger. The returned cache should be released with Close.
func NewLDCacheWithLogger(logger *log.Logger, root string) (*LDCache, error) {
return openWithRoot(logger, root)
}
// Open opens and parses the ld.so.cache with an empty root, logging through
// the logrus standard logger.
func Open() (*LDCache, error) {
return openWithRoot(log.StandardLogger(), "")
}
// openWithRoot memory-maps the ld.so.cache found under root and parses it.
//
// On success the returned cache owns the mapping and must be released with
// Close. On failure no cache is returned and nothing is left mapped.
func openWithRoot(logger *log.Logger, root string) (*LDCache, error) {
	path := filepath.Join(root, ldcachePath)

	logger.Debugf("Opening ld.conf at %v", path)
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// The mapping stays valid after the descriptor is closed.
	defer f.Close()

	fi, err := f.Stat()
	if err != nil {
		return nil, err
	}

	d, err := syscall.Mmap(int(f.Fd()), 0, int(fi.Size()),
		syscall.PROT_READ, syscall.MAP_PRIVATE)
	if err != nil {
		return nil, err
	}

	cache := &LDCache{
		data:   d,
		Reader: bytes.NewReader(d),
		root:   root,
		logger: logger,
	}

	if err := cache.parse(); err != nil {
		// Callers drop the cache when an error is returned, so the mapping
		// has to be released here or it would leak. The parse error is the
		// one worth reporting; an unmap failure is deliberately ignored.
		_ = syscall.Munmap(d)
		return nil, err
	}
	return cache, nil
}
// Close unmaps the memory-mapped cache data. The cache must not be used
// afterwards because libs and the embedded Reader alias the unmapped region.
func (c *LDCache) Close() error {
return syscall.Munmap(c.data)
}
// Magic returns the magic string read from the new-format header.
func (c *LDCache) Magic() string {
return string(c.header.Magic[:])
}
// Version returns the version string read from the new-format header.
func (c *LDCache) Version() string {
return string(c.header.Version[:])
}
// strn converts the first n bytes of b to a string. It panics when n exceeds
// the capacity of b, so callers must validate n beforehand.
func strn(b []byte, n int) string {
	head := b[:n]
	return string(head)
}
// parse decodes the mapped cache into c.header and c.entries and records in
// c.libs the region that entry offsets refer to. It returns ErrInvalidCache
// when the data is too short or the new-format magic/version do not match,
// and otherwise any error produced while reading or seeking.
func (c *LDCache) parse() error {
var header Header1
// Check for the old format (< glibc-2.2)
if c.Len() <= int(unsafe.Sizeof(header)) {
return ErrInvalidCache
}
// When the file starts with the old-format magic, skip the old header and
// its entries so that the reader ends up at the new-format header.
if strn(c.data, len(magicString1)) == magicString1 {
if err := binary.Read(c, binary.LittleEndian, &header); err != nil {
return err
}
n := int64(header.NLibs) * int64(unsafe.Sizeof(Entry1{}))
offset, err := c.Seek(n, 1) // skip old entries
if err != nil {
return err
}
// Distance from offset to the next multiple of Header2's alignment.
n = (-offset) & int64(unsafe.Alignof(c.header)-1)
_, err = c.Seek(n, 1) // skip padding
if err != nil {
return err
}
}
// Size()-Len() is the number of bytes consumed so far, i.e. the current
// read position, which is where the new-format header begins.
c.libs = c.data[c.Size()-int64(c.Len()):] // kv offsets start here
if err := binary.Read(c, binary.LittleEndian, &c.header); err != nil {
return err
}
if c.Magic() != magicString2 || c.Version() != magicVersion {
return ErrInvalidCache
}
// The entry table follows the header directly and holds NLibs records.
c.entries = make([]Entry2, c.header.NLibs)
if err := binary.Read(c, binary.LittleEndian, &c.entries); err != nil {
return err
}
return nil
}
// Lookup returns the resolved paths of all cached ELF libraries whose name
// starts with one of the given prefixes, split into 32-bit and 64-bit results.
//
// Each path is joined onto the cache root and has its symlinks resolved;
// entries that cannot be resolved are skipped, and a resolved path is reported
// at most once across both result lists.
func (c *LDCache) Lookup(libs ...string) (paths32, paths64 []string) {
c.logger.Debugf("Looking up %v in cache", libs)
type void struct{}
// paths points at whichever result slice matches the current entry's
// architecture.
var paths *[]string
// set holds the resolved paths that were already emitted.
set := make(map[string]void)
prefix := make([][]byte, len(libs))
for i := range libs {
prefix[i] = []byte(libs[i])
}
for _, e := range c.entries {
// Only ELF entries are of interest.
if ((e.Flags & flagTypeMask) & flagTypeELF) == 0 {
continue
}
switch e.Flags & flagArchMask {
case flagArchX8664:
fallthrough
case flagArchPpc64le:
paths = &paths64
case flagArchX32:
fallthrough
case flagArchI386:
paths = &paths32
default:
continue
}
// Ignore entries whose offsets point outside the string region.
if e.Key > uint32(len(c.libs)) || e.Value > uint32(len(c.libs)) {
continue
}
lib := c.libs[e.Key:]
value := c.libs[e.Value:]
// Only the first matching prefix is considered for an entry: every path
// through the matching branch ends in a break.
for _, p := range prefix {
if bytes.HasPrefix(lib, p) {
// The path is NUL-terminated; skip the entry when no terminator exists.
n := bytes.IndexByte(value, 0)
if n < 0 {
break
}
name := filepath.Join(c.root, strn(value, n))
c.logger.Debugf("checking %v", string(name))
path, err := filepath.EvalSymlinks(name)
if err != nil {
c.logger.Debugf("could not resolve symlink for %v", name)
break
}
if _, ok := set[path]; ok {
break
}
set[path] = void{}
*paths = append(*paths, path)
break
}
}
}
return
}

78
internal/lookup/file.go Normal file
View File

@@ -0,0 +1,78 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package lookup
import (
"fmt"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
)
// file is a Locator that looks for a file name under a list of prefixes.
type file struct {
logger *log.Logger
// prefixes are the directories that Locate joins the file name onto.
prefixes []string
// filter returns nil for candidates that should be reported and an error
// describing why a candidate is rejected otherwise.
filter func(string) error
}
// NewFileLocator creates a Locator that looks for regular files under root
// and logs through the logrus standard logger.
func NewFileLocator(root string) Locator {
return NewFileLocatorWithLogger(log.StandardLogger(), root)
}
// NewFileLocatorWithLogger creates a Locator that uses root as its only search
// prefix, accepts any candidate that exists and is not a directory, and logs
// through the specified logger.
func NewFileLocatorWithLogger(logger *log.Logger, root string) Locator {
	return &file{
		logger:   logger,
		prefixes: []string{root},
		filter:   assertFile,
	}
}
// Compile-time check that *file implements Locator.
var _ Locator = (*file)(nil)
// Locate joins filename onto every configured prefix and returns the
// candidates accepted by the locator's filter, in prefix order.
//
// An error is returned when no candidate is accepted.
func (p file) Locate(filename string) ([]string, error) {
	var filenames []string
	for _, prefix := range p.prefixes {
		candidate := filepath.Join(prefix, filename)
		p.logger.Debugf("Checking candidate '%v'", candidate)
		if err := p.filter(candidate); err != nil {
			p.logger.Debugf("Candidate '%v' does not meet requirements: %v", candidate, err)
			continue
		}
		filenames = append(filenames, candidate)
	}

	// The check must be on the collected results, not on the requested name;
	// otherwise a missing file is reported as an empty result without an error.
	if len(filenames) == 0 {
		return nil, fmt.Errorf("file %v not found", filename)
	}
	return filenames, nil
}
func assertFile(filename string) error {
info, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("error getting info for %v: %v", filename, err)
}
if info.IsDir() {
return fmt.Errorf("specified path '%v' is a directory", filename)
}
return nil
}

View File

@@ -0,0 +1,65 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package lookup
import (
"fmt"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/ldcache"
)
// library is a Locator that resolves library names through an ld.so.cache.
type library struct {
logger *log.Logger
// cache is the parsed ldcache queried by Locate.
cache *ldcache.LDCache
}
// Compile-time check that *library implements Locator.
var _ Locator = (*library)(nil)
// NewLibraryLocator creates a library locator for the ldcache under root using
// the standard logger.
func NewLibraryLocator(root string) (Locator, error) {
return NewLibraryLocatorWithLogger(log.StandardLogger(), root)
}
// NewLibraryLocatorWithLogger creates a library locator backed by the ldcache
// found under root, logging through the specified logger. An error is returned
// when the ldcache cannot be loaded.
func NewLibraryLocatorWithLogger(logger *log.Logger, root string) (Locator, error) {
	logger.Infof("Reading ldcache at %v", root)

	cache, err := ldcache.NewLDCacheWithLogger(logger, root)
	if err != nil {
		return nil, fmt.Errorf("error loading ldcache: %v", err)
	}

	return &library{logger: logger, cache: cache}, nil
}
// Locate returns the 64-bit libraries in the ldcache whose name starts with
// libname. 32-bit matches are logged as a warning and dropped; an error is
// returned when there is no 64-bit match.
func (l library) Locate(libname string) ([]string, error) {
	ignored, found := l.cache.Lookup(libname)

	if len(ignored) != 0 {
		l.logger.Warnf("Ignoring 32-bit libraries for %v: %v", libname, ignored)
	}

	if len(found) > 0 {
		return found, nil
	}
	return nil, fmt.Errorf("64-bit library %v not found", libname)
}

24
internal/lookup/lookup.go Normal file
View File

@@ -0,0 +1,24 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package lookup
//go:generate moq -stub -out lookup_mock.go . Locator
// Locator defines the interface for locating files on a system.
type Locator interface {
// Locate returns the paths found for the given name, or an error.
Locate(string) ([]string, error)
}

View File

@@ -0,0 +1,77 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package lookup
import (
"sync"
)
// Ensure, that LocatorMock does implement Locator.
// If this is not the case, regenerate this file with moq.
var _ Locator = &LocatorMock{}
// LocatorMock is a mock implementation of Locator.
//
// func TestSomethingThatUsesLocator(t *testing.T) {
//
// // make and configure a mocked Locator
// mockedLocator := &LocatorMock{
// LocateFunc: func(s string) ([]string, error) {
// panic("mock out the Locate method")
// },
// }
//
// // use mockedLocator in code that requires Locator
// // and then make assertions.
//
// }
type LocatorMock struct {
// LocateFunc mocks the Locate method.
LocateFunc func(s string) ([]string, error)
// calls tracks calls to the methods.
calls struct {
// Locate holds details about calls to the Locate method.
Locate []struct {
// S is the s argument value.
S string
}
}
// lockLocate guards calls.Locate.
lockLocate sync.RWMutex
}
// Locate calls LocateFunc.
// The call is recorded first; when LocateFunc is nil the zero values
// (nil slice, nil error) are returned instead.
func (mock *LocatorMock) Locate(s string) ([]string, error) {
callInfo := struct {
S string
}{
S: s,
}
mock.lockLocate.Lock()
mock.calls.Locate = append(mock.calls.Locate, callInfo)
mock.lockLocate.Unlock()
if mock.LocateFunc == nil {
var (
stringsOut []string
errOut error
)
return stringsOut, errOut
}
return mock.LocateFunc(s)
}
// LocateCalls gets all the calls that were made to Locate.
// Check the length with:
// len(mockedLocator.LocateCalls())
// The recorded slice is read under the read lock and returned without copying.
func (mock *LocatorMock) LocateCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockLocate.RLock()
calls = mock.calls.Locate
mock.lockLocate.RUnlock()
return calls
}

94
internal/lookup/path.go Normal file
View File

@@ -0,0 +1,94 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package lookup
import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
const (
// envPath is the environment variable from which the search directories
// of the path locator are read.
envPath = "PATH"
)
// defaultPaths are appended to the directories from PATH when a path locator
// is created with a non-empty root.
var defaultPaths = []string{"/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"}
// path is a Locator for executables. It embeds file, using the search
// directories as prefixes and assertExecutable as the candidate filter.
type path struct {
file
}
// NewPathLocator creates a Locator for executables under root that logs
// through the logrus standard logger.
func NewPathLocator(root string) Locator {
return NewPathLocatorWithLogger(log.StandardLogger(), root)
}
// NewPathLocatorWithLogger creates a Locator for executables that logs through
// the specified logger.
//
// The search directories are the entries of the PATH environment variable,
// followed by defaultPaths when root is non-empty; each directory is joined
// onto root.
func NewPathLocatorWithLogger(logger *log.Logger, root string) Locator {
	searchDirs := filepath.SplitList(os.Getenv(envPath))
	if root != "" {
		searchDirs = append(searchDirs, defaultPaths...)
	}

	var prefixes []string
	for i := range searchDirs {
		prefixes = append(prefixes, filepath.Join(root, searchDirs[i]))
	}

	return &path{
		file: file{
			logger:   logger,
			prefixes: prefixes,
			filter:   assertExecutable,
		},
	}
}
// Compile-time check that *path implements Locator.
var _ Locator = (*path)(nil)
// Locate finds the executable called filename.
//
// A name without a "/" is searched for in the configured directories. Any
// name containing a "/" is used as given and only checked for being an
// executable file.
func (p path) Locate(filename string) ([]string, error) {
	if !strings.Contains(filename, "/") {
		return p.file.Locate(filename)
	}

	// For absolute paths we ensure that it is executable
	if err := assertExecutable(filename); err != nil {
		return nil, fmt.Errorf("absolute path %v is not an executable file: %v", filename, err)
	}
	return []string{filename}, nil
}
// assertExecutable returns nil when filename passes assertFile and has at
// least one execute permission bit set, and an error otherwise.
func assertExecutable(filename string) error {
	if err := assertFile(filename); err != nil {
		return err
	}

	info, err := os.Stat(filename)
	if err != nil {
		return err
	}

	// 0111 covers the execute bit for owner, group and others.
	if execBits := info.Mode() & 0111; execBits == 0 {
		return fmt.Errorf("specified file '%v' is not executable", filename)
	}
	return nil
}

141
internal/nvcaps/nvcaps.go Normal file
View File

@@ -0,0 +1,141 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package nvcaps
import (
"bufio"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strconv"
"strings"
)
// Paths used to map MIG capabilities to proc files and device nodes.
const (
nvidiaProcDriverPath = "/proc/driver/nvidia"
// nvidiaCapabilitiesPath is the base directory of MigCap.ProcPath results.
nvidiaCapabilitiesPath = nvidiaProcDriverPath + "/capabilities"
nvcapsProcDriverPath = "/proc/driver/nvidia-caps"
// nvcapsMigMinorsPath is the file read by LoadMigMinors.
nvcapsMigMinorsPath = nvcapsProcDriverPath + "/mig-minors"
// nvcapsDevicePath is the base directory of MigMinor.DevicePath results.
nvcapsDevicePath = "/dev/nvidia-caps"
)
// MigMinor represents the minor number of a MIG device, as read from the
// second field of a line in the MIG minors file.
type MigMinor int
// MigCap represents the path to a MIG cap file, as read from the first field
// of a line in the MIG minors file (for example "config" or
// "gpu0/gi0/ci0/access").
type MigCap string
// LoadMigMinors loads the MIG minors file and returns its contents as a map
// from capability to minor number.
//
// A missing file is not an error: it means the machine is not MIG capable, and
// a nil map is returned. The format of the file is discussed in:
// https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#unique_1576522674
func LoadMigMinors() (map[MigCap]MigMinor, error) {
	minorsFile, err := os.Open(nvcapsMigMinorsPath)
	switch {
	case os.IsNotExist(err):
		// Not a MIG capable machine, so there is nothing to do.
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("error opening MIG minors file: %v", err)
	}
	defer minorsFile.Close()

	migCaps := processMinorsFile(minorsFile)
	return migCaps, nil
}
// processMinorsFile reads the MIG minors data line by line and returns a
// mapping from capability to the device minor of that capability.
//
// Lines that cannot be parsed are logged and skipped. A read error is logged
// as well; the entries collected up to that point are still returned.
func processMinorsFile(minorsFile io.Reader) map[MigCap]MigMinor {
	migCaps := make(map[MigCap]MigMinor)

	scanner := bufio.NewScanner(minorsFile)
	for scanner.Scan() {
		migCap, minor, err := processMigMinorsLine(scanner.Text())
		if err != nil {
			log.Printf("Skipping line in MIG minors file: %v", err)
			continue
		}
		migCaps[migCap] = minor
	}
	// Scan also stops on a read failure; without this check a truncated
	// result would be indistinguishable from a complete one.
	if err := scanner.Err(); err != nil {
		log.Printf("Error reading MIG minors file: %v", err)
	}

	return migCaps
}
// processMigMinorsLine parses a single "<capability> <minor>" line of the MIG
// minors file. It returns an error when the line does not consist of exactly
// two space-separated fields, when the capability is not valid, or when the
// minor is not an integer.
func processMigMinorsLine(line string) (MigCap, MigMinor, error) {
	fields := strings.Split(line, " ")
	if len(fields) != 2 {
		return "", 0, fmt.Errorf("error processing line: %v", line)
	}

	name, number := MigCap(fields[0]), fields[1]
	if !name.isValid() {
		return "", 0, fmt.Errorf("invalid MIG minors line: '%v'", line)
	}

	value, err := strconv.Atoi(number)
	if err != nil {
		return "", 0, fmt.Errorf("error reading MIG minor from '%v': %v", line, err)
	}

	return name, MigMinor(value), nil
}
// isValid reports whether m names a known MIG capability: "config",
// "monitor", a GPU instance access file ("gpu<N>/gi<N>/access") or a compute
// instance access file ("gpu<N>/gi<N>/ci<N>/access").
//
// The whole string has to match. Sscanf stops at the first mismatch and
// ignores trailing input, so the number of scanned values alone does not prove
// a match; the scanned values are therefore formatted back and compared with
// the input.
func (m MigCap) isValid() bool {
	name := string(m)
	switch name {
	case "config", "monitor":
		return true
	}

	var gpu, gi, ci int

	// Look for a CI access file
	const ciFormat = "gpu%d/gi%d/ci%d/access"
	if n, err := fmt.Sscanf(name, ciFormat, &gpu, &gi, &ci); err == nil && n == 3 {
		if gpu >= 0 && gi >= 0 && ci >= 0 && name == fmt.Sprintf(ciFormat, gpu, gi, ci) {
			return true
		}
	}

	// Look for a GI access file
	const giFormat = "gpu%d/gi%d/access"
	if n, err := fmt.Sscanf(name, giFormat, &gpu, &gi); err == nil && n == 2 {
		if gpu >= 0 && gi >= 0 && name == fmt.Sprintf(giFormat, gpu, gi) {
			return true
		}
	}

	return false
}
// ProcPath returns the proc path associated with the MIG capability.
//
// A capability of the form "<gpu>/<rest>" maps to
// "<capabilities>/<gpu>/mig/<rest>". A capability without a "/" (such as
// "config" or "monitor") maps to "<capabilities>/mig/<capability>"; this also
// keeps the method from panicking on values that were not validated.
func (m MigCap) ProcPath() string {
	id := string(m)

	parts := strings.SplitN(id, "/", 2)
	if len(parts) < 2 {
		// No GPU component to split off.
		return filepath.Join(nvidiaCapabilitiesPath, "mig", id)
	}
	return filepath.Join(nvidiaCapabilitiesPath, parts[0], "mig", parts[1])
}
// DevicePath returns the path of the nvidia-caps device node that has the
// minor number m.
func (m MigMinor) DevicePath() string {
	return fmt.Sprintf("%s/nvidia-cap%d", nvcapsDevicePath, m)
}

View File

@@ -0,0 +1,94 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package nvcaps
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestProcessMinorsFile checks that valid lines end up in the resulting map
// and that invalid lines are skipped.
func TestProcessMinorsFile(t *testing.T) {
	testCases := []struct {
		lines    []string
		expected map[MigCap]MigMinor
	}{
		{lines: []string{}, expected: map[MigCap]MigMinor{}},
		{lines: []string{"invalidLine"}, expected: map[MigCap]MigMinor{}},
		{lines: []string{"config 1"}, expected: map[MigCap]MigMinor{"config": 1}},
		{lines: []string{"gpu0/gi0/ci0/access 4"}, expected: map[MigCap]MigMinor{"gpu0/gi0/ci0/access": 4}},
		{lines: []string{"config 1", "invalidLine"}, expected: map[MigCap]MigMinor{"config": 1}},
		{lines: []string{"config 1", "gpu0/gi0/ci0/access 4"}, expected: map[MigCap]MigMinor{"config": 1, "gpu0/gi0/ci0/access": 4}},
	}

	for _, tc := range testCases {
		input := strings.Join(tc.lines, "\n")
		actual := processMinorsFile(strings.NewReader(input))
		require.Equalf(t, tc.expected, actual, "testCase: %v", tc)
	}
}
// TestProcessMigMinorsLine checks the capability, minor and error returned for
// valid and invalid lines of the MIG minors file.
func TestProcessMigMinorsLine(t *testing.T) {
	testCases := []struct {
		line  string
		cap   MigCap
		minor MigMinor
		err   bool
	}{
		{line: "config 1", cap: "config", minor: 1, err: false},
		{line: "monitor 2", cap: "monitor", minor: 2, err: false},
		{line: "gpu0/gi0/access 3", cap: "gpu0/gi0/access", minor: 3, err: false},
		{line: "gpu0/gi0/ci0/access 4", cap: "gpu0/gi0/ci0/access", minor: 4, err: false},
		{line: "notconfig 99", cap: "", minor: 0, err: true},
		{line: "config notanint", cap: "", minor: 0, err: true},
		{line: "", cap: "", minor: 0, err: true},
	}

	for _, tc := range testCases {
		gotCap, gotMinor, gotErr := processMigMinorsLine(tc.line)

		require.Equalf(t, tc.cap, gotCap, "testCase: %v", tc)
		require.Equalf(t, tc.minor, gotMinor, "testCase: %v", tc)
		if !tc.err {
			require.NoErrorf(t, gotErr, "testCase: %v", tc)
			continue
		}
		require.Errorf(t, gotErr, "testCase: %v", tc)
	}
}
// TestMigCapProcPaths checks the proc path produced for each kind of MIG
// capability.
func TestMigCapProcPaths(t *testing.T) {
	testCases := []struct {
		input    string
		expected string
	}{
		{input: "config", expected: "/proc/driver/nvidia/capabilities/mig/config"},
		{input: "monitor", expected: "/proc/driver/nvidia/capabilities/mig/monitor"},
		{input: "gpu0/gi0/access", expected: "/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"},
		{input: "gpu0/gi0/ci0/access", expected: "/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"},
	}

	for i := range testCases {
		actual := MigCap(testCases[i].input).ProcPath()
		require.Equal(t, testCases[i].expected, actual)
	}
}
func TestMigMinorDevicePath(t *testing.T) {
m := MigMinor(0)
require.Equal(t, "/dev/nvidia-caps/nvidia-cap0", m.DevicePath())
}

79
internal/nvml/consts.go Normal file
View File

@@ -0,0 +1,79 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// Return codes re-exported from go-nvml so that users of this package do not
// need to import go-nvml directly.
const (
SUCCESS = nvml.SUCCESS
ERROR_UNINITIALIZED = nvml.ERROR_UNINITIALIZED
ERROR_INVALID_ARGUMENT = nvml.ERROR_INVALID_ARGUMENT
ERROR_NOT_SUPPORTED = nvml.ERROR_NOT_SUPPORTED
ERROR_NO_PERMISSION = nvml.ERROR_NO_PERMISSION
ERROR_ALREADY_INITIALIZED = nvml.ERROR_ALREADY_INITIALIZED
ERROR_NOT_FOUND = nvml.ERROR_NOT_FOUND
ERROR_INSUFFICIENT_SIZE = nvml.ERROR_INSUFFICIENT_SIZE
ERROR_INSUFFICIENT_POWER = nvml.ERROR_INSUFFICIENT_POWER
ERROR_DRIVER_NOT_LOADED = nvml.ERROR_DRIVER_NOT_LOADED
ERROR_TIMEOUT = nvml.ERROR_TIMEOUT
ERROR_IRQ_ISSUE = nvml.ERROR_IRQ_ISSUE
ERROR_LIBRARY_NOT_FOUND = nvml.ERROR_LIBRARY_NOT_FOUND
ERROR_FUNCTION_NOT_FOUND = nvml.ERROR_FUNCTION_NOT_FOUND
ERROR_CORRUPTED_INFOROM = nvml.ERROR_CORRUPTED_INFOROM
ERROR_GPU_IS_LOST = nvml.ERROR_GPU_IS_LOST
ERROR_RESET_REQUIRED = nvml.ERROR_RESET_REQUIRED
ERROR_OPERATING_SYSTEM = nvml.ERROR_OPERATING_SYSTEM
ERROR_LIB_RM_VERSION_MISMATCH = nvml.ERROR_LIB_RM_VERSION_MISMATCH
ERROR_IN_USE = nvml.ERROR_IN_USE
ERROR_MEMORY = nvml.ERROR_MEMORY
ERROR_NO_DATA = nvml.ERROR_NO_DATA
ERROR_VGPU_ECC_NOT_SUPPORTED = nvml.ERROR_VGPU_ECC_NOT_SUPPORTED
ERROR_INSUFFICIENT_RESOURCES = nvml.ERROR_INSUFFICIENT_RESOURCES
ERROR_UNKNOWN = nvml.ERROR_UNKNOWN
)
// MIG mode values re-exported from go-nvml.
const (
DEVICE_MIG_ENABLE = nvml.DEVICE_MIG_ENABLE
DEVICE_MIG_DISABLE = nvml.DEVICE_MIG_DISABLE
)
// GPU instance profile identifiers re-exported from go-nvml.
const (
GPU_INSTANCE_PROFILE_1_SLICE = nvml.GPU_INSTANCE_PROFILE_1_SLICE
GPU_INSTANCE_PROFILE_2_SLICE = nvml.GPU_INSTANCE_PROFILE_2_SLICE
GPU_INSTANCE_PROFILE_3_SLICE = nvml.GPU_INSTANCE_PROFILE_3_SLICE
GPU_INSTANCE_PROFILE_4_SLICE = nvml.GPU_INSTANCE_PROFILE_4_SLICE
GPU_INSTANCE_PROFILE_7_SLICE = nvml.GPU_INSTANCE_PROFILE_7_SLICE
GPU_INSTANCE_PROFILE_8_SLICE = nvml.GPU_INSTANCE_PROFILE_8_SLICE
GPU_INSTANCE_PROFILE_COUNT = nvml.GPU_INSTANCE_PROFILE_COUNT
)
// Compute instance profile identifiers re-exported from go-nvml.
const (
COMPUTE_INSTANCE_PROFILE_1_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE
COMPUTE_INSTANCE_PROFILE_2_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_2_SLICE
COMPUTE_INSTANCE_PROFILE_3_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_3_SLICE
COMPUTE_INSTANCE_PROFILE_4_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_4_SLICE
COMPUTE_INSTANCE_PROFILE_7_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_7_SLICE
COMPUTE_INSTANCE_PROFILE_8_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_8_SLICE
COMPUTE_INSTANCE_PROFILE_COUNT = nvml.COMPUTE_INSTANCE_PROFILE_COUNT
)
// Compute instance engine profile identifiers re-exported from go-nvml.
const (
COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED
COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT
)

594
internal/nvml/mock.go Normal file
View File

@@ -0,0 +1,594 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import "fmt"
// MockServer is an in-memory stand-in for an NVML server that serves a fixed
// list of devices.
type MockServer struct {
// Devices are returned by DeviceGetHandleByIndex, indexed by position.
Devices []Device
}
// MockLunaServer is a MockServer variant; it embeds MockServer and adds no
// fields of its own.
type MockLunaServer struct {
MockServer
}
// MockA100Device is a mock Device that keeps its MIG state in memory.
type MockA100Device struct {
Index int
MinorNumber int
// MigMode is the value stored by SetMigMode and reported by GetMigMode.
MigMode int
// GpuInstances is the set of GPU instances created on this device.
GpuInstances map[*MockA100GpuInstance]struct{}
// GpuInstanceCounter supplies the Id of the next created GPU instance.
GpuInstanceCounter uint32
}
// MockA100GpuInstance is a mock GpuInstance that tracks its compute instances
// in memory.
type MockA100GpuInstance struct {
Info GpuInstanceInfo
ComputeInstances map[*MockA100ComputeInstance]struct{}
ComputeInstanceCounter uint32
}
// MockA100ComputeInstance is a mock ComputeInstance that only carries its
// info.
type MockA100ComputeInstance struct {
Info ComputeInstanceInfo
}
// Compile-time checks that the mock types implement the package interfaces.
var _ Interface = (*MockLunaServer)(nil)
var _ Device = (*MockA100Device)(nil)
var _ GpuInstance = (*MockA100GpuInstance)(nil)
var _ ComputeInstance = (*MockA100ComputeInstance)(nil)
// MockA100MIGProfiles is the profile table used by the A100 mocks.
//
// GpuInstanceProfiles is keyed by GPU instance profile id and is what
// MockA100Device.GetGpuInstanceProfileInfo looks profiles up in.
// ComputeInstanceProfiles is keyed first by GPU instance profile id and then
// by compute instance profile id.
var MockA100MIGProfiles = struct {
GpuInstanceProfiles map[int]GpuInstanceProfileInfo
ComputeInstanceProfiles map[int]map[int]ComputeInstanceProfileInfo
}{
GpuInstanceProfiles: map[int]GpuInstanceProfileInfo{
GPU_INSTANCE_PROFILE_1_SLICE: {
Id: GPU_INSTANCE_PROFILE_1_SLICE,
IsP2pSupported: 0,
SliceCount: 1,
InstanceCount: 7,
MultiprocessorCount: 1,
CopyEngineCount: 1,
DecoderCount: 0,
EncoderCount: 0,
JpegCount: 0,
OfaCount: 0,
MemorySizeMB: 5120,
},
GPU_INSTANCE_PROFILE_2_SLICE: {
Id: GPU_INSTANCE_PROFILE_2_SLICE,
IsP2pSupported: 0,
SliceCount: 2,
InstanceCount: 3,
MultiprocessorCount: 2,
CopyEngineCount: 2,
DecoderCount: 1,
EncoderCount: 1,
JpegCount: 0,
OfaCount: 0,
MemorySizeMB: 10240,
},
GPU_INSTANCE_PROFILE_3_SLICE: {
Id: GPU_INSTANCE_PROFILE_3_SLICE,
IsP2pSupported: 0,
SliceCount: 3,
InstanceCount: 2,
MultiprocessorCount: 3,
CopyEngineCount: 4,
DecoderCount: 2,
EncoderCount: 2,
JpegCount: 0,
OfaCount: 0,
MemorySizeMB: 20480,
},
GPU_INSTANCE_PROFILE_4_SLICE: {
Id: GPU_INSTANCE_PROFILE_4_SLICE,
IsP2pSupported: 0,
SliceCount: 4,
InstanceCount: 1,
MultiprocessorCount: 4,
CopyEngineCount: 4,
DecoderCount: 2,
EncoderCount: 2,
JpegCount: 0,
OfaCount: 0,
MemorySizeMB: 20480,
},
GPU_INSTANCE_PROFILE_7_SLICE: {
Id: GPU_INSTANCE_PROFILE_7_SLICE,
IsP2pSupported: 0,
SliceCount: 7,
InstanceCount: 1,
MultiprocessorCount: 7,
CopyEngineCount: 8,
DecoderCount: 5,
EncoderCount: 5,
JpegCount: 1,
OfaCount: 1,
MemorySizeMB: 40960,
},
},
ComputeInstanceProfiles: map[int]map[int]ComputeInstanceProfileInfo{
GPU_INSTANCE_PROFILE_1_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 1,
MultiprocessorCount: 1,
SharedCopyEngineCount: 1,
SharedDecoderCount: 0,
SharedEncoderCount: 0,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
},
GPU_INSTANCE_PROFILE_2_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 2,
MultiprocessorCount: 1,
SharedCopyEngineCount: 2,
SharedDecoderCount: 1,
SharedEncoderCount: 1,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_2_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_2_SLICE,
SliceCount: 2,
InstanceCount: 1,
MultiprocessorCount: 2,
SharedCopyEngineCount: 2,
SharedDecoderCount: 1,
SharedEncoderCount: 1,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
},
GPU_INSTANCE_PROFILE_3_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 3,
MultiprocessorCount: 1,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 1,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_2_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_2_SLICE,
SliceCount: 2,
InstanceCount: 1,
MultiprocessorCount: 2,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 2,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_3_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_3_SLICE,
SliceCount: 3,
InstanceCount: 1,
MultiprocessorCount: 3,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 0,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
},
GPU_INSTANCE_PROFILE_4_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 4,
MultiprocessorCount: 1,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 2,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_2_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_2_SLICE,
SliceCount: 2,
InstanceCount: 2,
MultiprocessorCount: 2,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 2,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
COMPUTE_INSTANCE_PROFILE_4_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_4_SLICE,
SliceCount: 4,
InstanceCount: 1,
MultiprocessorCount: 4,
SharedCopyEngineCount: 4,
SharedDecoderCount: 2,
SharedEncoderCount: 2,
SharedJpegCount: 0,
SharedOfaCount: 0,
},
},
GPU_INSTANCE_PROFILE_7_SLICE: {
COMPUTE_INSTANCE_PROFILE_1_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_1_SLICE,
SliceCount: 1,
InstanceCount: 7,
MultiprocessorCount: 1,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
COMPUTE_INSTANCE_PROFILE_2_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_2_SLICE,
SliceCount: 2,
InstanceCount: 3,
MultiprocessorCount: 2,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
COMPUTE_INSTANCE_PROFILE_3_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_3_SLICE,
SliceCount: 3,
InstanceCount: 2,
MultiprocessorCount: 3,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
COMPUTE_INSTANCE_PROFILE_4_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_4_SLICE,
SliceCount: 4,
InstanceCount: 1,
MultiprocessorCount: 4,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
COMPUTE_INSTANCE_PROFILE_7_SLICE: {
Id: COMPUTE_INSTANCE_PROFILE_7_SLICE,
SliceCount: 7,
InstanceCount: 1,
MultiprocessorCount: 7,
SharedCopyEngineCount: 8,
SharedDecoderCount: 5,
SharedEncoderCount: 5,
SharedJpegCount: 1,
SharedOfaCount: 1,
},
},
},
}
// NewMockNVMLServer creates a mock NVML server that serves the given devices
// in the order they are passed.
func NewMockNVMLServer(devices ...Device) Interface {
	server := MockServer{Devices: devices}
	return &server
}
// NewMockNVMLOnLunaServer creates a mock NVML server populated with eight mock
// A100 devices that have the indices 0 through 7.
func NewMockNVMLOnLunaServer() Interface {
	const deviceCount = 8

	devices := make([]Device, 0, deviceCount)
	for index := 0; index < deviceCount; index++ {
		devices = append(devices, NewMockA100Device(index))
	}
	return NewMockNVMLServer(devices...)
}
// NewMockA100Device creates a mock A100 device with the given index, no GPU
// instances, and all remaining fields at their zero value.
func NewMockA100Device(index int) Device {
	device := &MockA100Device{
		Index:        index,
		GpuInstances: map[*MockA100GpuInstance]struct{}{},
	}
	return device
}
// NewMockA100GpuInstance creates a mock GPU instance described by info that
// has no compute instances yet.
func NewMockA100GpuInstance(info GpuInstanceInfo) GpuInstance {
	instance := &MockA100GpuInstance{
		Info:             info,
		ComputeInstances: map[*MockA100ComputeInstance]struct{}{},
	}
	return instance
}
// NewMockA100ComputeInstance creates a mock compute instance described by
// info.
func NewMockA100ComputeInstance(info ComputeInstanceInfo) ComputeInstance {
	instance := MockA100ComputeInstance{Info: info}
	return &instance
}
// Init does nothing for the mock and always reports SUCCESS.
func (n *MockServer) Init() Return {
return MockReturn(SUCCESS)
}
// Shutdown does nothing for the mock and always reports SUCCESS.
func (n *MockServer) Shutdown() Return {
return MockReturn(SUCCESS)
}
// DeviceGetCount returns the number of devices held by the mock server.
func (n *MockServer) DeviceGetCount() (int, Return) {
return len(n.Devices), MockReturn(SUCCESS)
}
// DeviceGetHandleByIndex returns the device stored at position index, or
// ERROR_INVALID_ARGUMENT when index is outside the device list.
func (n *MockServer) DeviceGetHandleByIndex(index int) (Device, Return) {
	if inRange := index >= 0 && index < len(n.Devices); inRange {
		return n.Devices[index], MockReturn(SUCCESS)
	}
	return nil, MockReturn(ERROR_INVALID_ARGUMENT)
}
// SystemGetDriverVersion returns the fixed placeholder driver version
// "999.99".
func (n *MockServer) SystemGetDriverVersion() (string, Return) {
return "999.99", MockReturn(SUCCESS)
}
// GetIndex returns the index the mock device was created with.
func (d *MockA100Device) GetIndex() (int, Return) {
return d.Index, MockReturn(SUCCESS)
}
// GetPciInfo returns fixed PCI information for the mock device: the bus id
// "0000FFFF:FF:FF.F" (zero padded to 32 bytes) and the PCI device id
// 0x20B010DE. The values are identical for every mock device.
func (d *MockA100Device) GetPciInfo() (PciInfo, Return) {
	const mockBusID = "0000FFFF:FF:FF.F"

	var info PciInfo
	for i := 0; i < len(mockBusID); i++ {
		info.BusId[i] = int8(mockBusID[i])
	}
	info.PciDeviceId = 0x20B010DE

	return info, MockReturn(SUCCESS)
}
// GetUUID returns a synthetic UUID of the form "GPU-<index>".
func (d *MockA100Device) GetUUID() (string, Return) {
return fmt.Sprintf("GPU-%d", d.Index), MockReturn(SUCCESS)
}
// GetMinorNumber returns the MinorNumber field of the mock device.
func (d *MockA100Device) GetMinorNumber() (int, Return) {
return d.MinorNumber, MockReturn(SUCCESS)
}
// SetMigMode stores mode as the device's MIG mode. The change takes effect
// immediately in the mock and both returned values are SUCCESS.
func (d *MockA100Device) SetMigMode(mode int) (Return, Return) {
d.MigMode = mode
return MockReturn(SUCCESS), MockReturn(SUCCESS)
}
// GetMigMode returns the stored MIG mode for both integer results.
// NOTE(review): the two values presumably mirror NVML's current and pending
// modes; the mock does not distinguish between them.
func (d *MockA100Device) GetMigMode() (int, int, Return) {
return d.MigMode, d.MigMode, MockReturn(SUCCESS)
}
// GetGpuInstanceProfileInfo looks up the GPU instance profile with the given
// ID in MockA100MIGProfiles.
//
// It returns ERROR_INVALID_ARGUMENT when giProfileId lies outside
// [0, GPU_INSTANCE_PROFILE_COUNT) and ERROR_NOT_SUPPORTED when the ID is in
// range but has no entry in the mock profile table.
func (d *MockA100Device) GetGpuInstanceProfileInfo(giProfileId int) (GpuInstanceProfileInfo, Return) {
	outOfRange := giProfileId < 0 || giProfileId >= GPU_INSTANCE_PROFILE_COUNT
	if outOfRange {
		return GpuInstanceProfileInfo{}, MockReturn(ERROR_INVALID_ARGUMENT)
	}

	profile, known := MockA100MIGProfiles.GpuInstanceProfiles[giProfileId]
	if !known {
		return GpuInstanceProfileInfo{}, MockReturn(ERROR_NOT_SUPPORTED)
	}
	return profile, MockReturn(SUCCESS)
}
// CreateGpuInstance creates a mock GPU instance for the given profile and
// registers it with the device. The instance ID is taken from the device's
// GpuInstanceCounter, which is then incremented; IDs are therefore never
// reused, even after an instance is destroyed.
func (d *MockA100Device) CreateGpuInstance(info *GpuInstanceProfileInfo) (GpuInstance, Return) {
	created := NewMockA100GpuInstance(GpuInstanceInfo{
		Device:    d,
		Id:        d.GpuInstanceCounter,
		ProfileId: info.Id,
	})
	d.GpuInstanceCounter++

	// The map acts as a set keyed by the concrete mock instance pointer.
	d.GpuInstances[created.(*MockA100GpuInstance)] = struct{}{}
	return created, MockReturn(SUCCESS)
}
// GetGpuInstances returns the registered GPU instances whose profile ID
// matches info.Id. The result is nil when nothing matches. Its order follows
// map iteration and is therefore unspecified.
func (d *MockA100Device) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance, Return) {
	var matching []GpuInstance
	for candidate := range d.GpuInstances {
		if candidate.Info.ProfileId != info.Id {
			continue
		}
		matching = append(matching, candidate)
	}
	return matching, MockReturn(SUCCESS)
}
// GetMaxMigDeviceCount returns the sum of the ComputeInstanceCounter values of
// all GPU instances registered with the device.
//
// NOTE(review): the counter is incremented on creation and not decremented on
// Destroy, so the result is an upper bound on the number of live compute
// instances rather than an exact count.
func (d *MockA100Device) GetMaxMigDeviceCount() (int, Return) {
	total := 0
	for instance := range d.GpuInstances {
		total += int(instance.ComputeInstanceCounter)
	}
	return total, MockReturn(SUCCESS)
}
// GetMigDeviceHandleByIndex returns the Index-th compute instance (MIG device)
// across all GPU instances of the device.
//
// A negative Index yields ERROR_INVALID_ARGUMENT; an Index at or beyond the
// number of live compute instances yields ERROR_NOT_FOUND.
//
// GPU instances are sized by the number of compute instances they currently
// hold, not by ComputeInstanceCounter: the counter is an ID generator that is
// never decremented on Destroy, so using it would skip over live instances
// once any compute instance has been destroyed.
//
// Both levels are stored in maps, so the index-to-handle mapping follows map
// iteration order and is not guaranteed to be stable between calls.
func (d *MockA100Device) GetMigDeviceHandleByIndex(Index int) (Device, Return) {
	if Index < 0 {
		return nil, MockReturn(ERROR_INVALID_ARGUMENT)
	}

	remaining := Index
	for gi := range d.GpuInstances {
		live := len(gi.ComputeInstances)
		if remaining >= live {
			// The requested index lies beyond this GPU instance.
			remaining -= live
			continue
		}
		for ci := range gi.ComputeInstances {
			if remaining == 0 {
				return ci, MockReturn(SUCCESS)
			}
			remaining--
		}
	}
	return nil, MockReturn(ERROR_NOT_FOUND)
}
// GetDeviceHandleFromMigDeviceHandle always returns ERROR_NOT_SUPPORTED,
// consistent with IsMigDeviceHandle reporting false for a MockA100Device.
func (d *MockA100Device) GetDeviceHandleFromMigDeviceHandle() (Device, Return) {
return nil, MockReturn(ERROR_NOT_SUPPORTED)
}
// IsMigDeviceHandle always reports false for a MockA100Device.
func (d *MockA100Device) IsMigDeviceHandle() (bool, Return) {
return false, MockReturn(SUCCESS)
}
// GetComputeInstanceId is not implemented for MockA100Device and panics when called.
func (d *MockA100Device) GetComputeInstanceId() (int, Return) {
panic("Not implemented: GetComputeInstanceId")
}
// GetGPUInstanceId is not implemented for MockA100Device and panics when called.
func (d *MockA100Device) GetGPUInstanceId() (int, Return) {
panic("Not implemented: GetGPUInstanceId")
}
// GetInfo returns the GpuInstanceInfo the mock GPU instance was created with.
func (gi *MockA100GpuInstance) GetInfo() (GpuInstanceInfo, Return) {
return gi.Info, MockReturn(SUCCESS)
}
// GetComputeInstanceProfileInfo looks up the compute instance profile
// ciProfileId for the GPU instance's own profile in MockA100MIGProfiles.
//
// It returns ERROR_INVALID_ARGUMENT when ciProfileId lies outside
// [0, COMPUTE_INSTANCE_PROFILE_COUNT). It returns ERROR_NOT_SUPPORTED when
// ciEngProfileId is not COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED, or when the
// mock table has no entry for the GPU instance profile or for the compute
// instance profile within it.
func (gi *MockA100GpuInstance) GetComputeInstanceProfileInfo(ciProfileId int, ciEngProfileId int) (ComputeInstanceProfileInfo, Return) {
	var none ComputeInstanceProfileInfo

	switch {
	case ciProfileId < 0 || ciProfileId >= COMPUTE_INSTANCE_PROFILE_COUNT:
		return none, MockReturn(ERROR_INVALID_ARGUMENT)
	case ciEngProfileId != COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED:
		return none, MockReturn(ERROR_NOT_SUPPORTED)
	}

	perGpuInstance, known := MockA100MIGProfiles.ComputeInstanceProfiles[int(gi.Info.ProfileId)]
	if !known {
		return none, MockReturn(ERROR_NOT_SUPPORTED)
	}
	profile, known := perGpuInstance[ciProfileId]
	if !known {
		return none, MockReturn(ERROR_NOT_SUPPORTED)
	}
	return profile, MockReturn(SUCCESS)
}
// CreateComputeInstance creates a mock compute instance for the given profile
// and registers it with the GPU instance. The instance ID is taken from
// ComputeInstanceCounter, which is then incremented; IDs are never reused.
func (gi *MockA100GpuInstance) CreateComputeInstance(info *ComputeInstanceProfileInfo) (ComputeInstance, Return) {
	created := NewMockA100ComputeInstance(ComputeInstanceInfo{
		Device:      gi.Info.Device,
		GpuInstance: gi,
		Id:          gi.ComputeInstanceCounter,
		ProfileId:   info.Id,
	})
	gi.ComputeInstanceCounter++

	// The map acts as a set keyed by the concrete mock instance pointer.
	gi.ComputeInstances[created.(*MockA100ComputeInstance)] = struct{}{}
	return created, MockReturn(SUCCESS)
}
// GetComputeInstances returns the registered compute instances whose profile
// ID matches info.Id. The result is nil when nothing matches. Its order
// follows map iteration and is therefore unspecified.
func (gi *MockA100GpuInstance) GetComputeInstances(info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) {
	var matching []ComputeInstance
	for candidate := range gi.ComputeInstances {
		if candidate.Info.ProfileId != info.Id {
			continue
		}
		matching = append(matching, candidate)
	}
	return matching, MockReturn(SUCCESS)
}
// Destroy removes the GPU instance from its parent device's set of GPU
// instances. The parent must be a *MockA100Device; the unchecked type
// assertion panics otherwise. The device's GpuInstanceCounter is not changed.
func (gi *MockA100GpuInstance) Destroy() Return {
delete(gi.Info.Device.(*MockA100Device).GpuInstances, gi)
return MockReturn(SUCCESS)
}
// GetInfo returns the ComputeInstanceInfo the mock compute instance was created with.
func (ci *MockA100ComputeInstance) GetInfo() (ComputeInstanceInfo, Return) {
return ci.Info, MockReturn(SUCCESS)
}
// Destroy removes the compute instance from its parent GPU instance's set of
// compute instances. The parent must be a *MockA100GpuInstance; the unchecked
// type assertion panics otherwise. The parent's ComputeInstanceCounter is not
// changed.
func (ci *MockA100ComputeInstance) Destroy() Return {
delete(ci.Info.GpuInstance.(*MockA100GpuInstance).ComputeInstances, ci)
return MockReturn(SUCCESS)
}
// Since a compute instance can be used as a MIG device handle, it must also
// implement the Device interface
var _ Device = (*MockA100ComputeInstance)(nil)
// GetIndex returns the compute instance ID as the device index.
func (c *MockA100ComputeInstance) GetIndex() (int, Return) {
return int(c.Info.Id), MockReturn(SUCCESS)
}
// GetPciInfo is not implemented for mock MIG devices and panics when called.
func (c *MockA100ComputeInstance) GetPciInfo() (PciInfo, Return) {
// TODO: How does this behave on an actual MIG system?
panic("Not implemented: GetPciInfo")
}
// GetUUID returns a synthetic UUID of the form "MIG-<compute instance ID>".
// The ID is only unique within a GPU instance, so the UUID is not globally
// unique across the mock system.
func (c *MockA100ComputeInstance) GetUUID() (string, Return) {
return fmt.Sprintf("MIG-%d", c.Info.Id), MockReturn(SUCCESS)
}
// GetMinorNumber is not implemented for mock MIG devices and panics when called.
func (c *MockA100ComputeInstance) GetMinorNumber() (int, Return) {
// TODO: This depends on the content of the mig-minors file and the (gpu, gi, ci) tuple
panic("Not implemented: GetMinorNumber")
}
// SetMigMode is not implemented for mock MIG devices and panics when called.
func (c *MockA100ComputeInstance) SetMigMode(Mode int) (Return, Return) {
panic("Not implemented: SetMigMode")
}
// GetMigMode is not implemented for mock MIG devices and panics when called.
func (c *MockA100ComputeInstance) GetMigMode() (int, int, Return) {
panic("Not implemented: GetMigMode")
}
// GetGpuInstanceProfileInfo is not implemented for mock MIG devices and panics when called.
func (c *MockA100ComputeInstance) GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return) {
panic("Not implemented: GetGpuInstanceProfileInfo")
}
// CreateGpuInstance is not implemented for mock MIG devices and panics when called.
func (c *MockA100ComputeInstance) CreateGpuInstance(Info *GpuInstanceProfileInfo) (GpuInstance, Return) {
panic("Not implemented: CreateGpuInstance")
}
// GetGpuInstances is not implemented for mock MIG devices and panics when called.
func (c *MockA100ComputeInstance) GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return) {
panic("Not implemented: GetGpuInstances")
}
// GetMaxMigDeviceCount is not implemented for mock MIG devices and panics when called.
func (c *MockA100ComputeInstance) GetMaxMigDeviceCount() (int, Return) {
panic("Not implemented: GetMaxMigDeviceCount")
}
// GetMigDeviceHandleByIndex is not implemented for mock MIG devices and panics when called.
func (c *MockA100ComputeInstance) GetMigDeviceHandleByIndex(Index int) (Device, Return) {
panic("Not implemented: GetMigDeviceHandleByIndex")
}
// GetDeviceHandleFromMigDeviceHandle returns the parent device recorded in the
// compute instance's info.
func (c *MockA100ComputeInstance) GetDeviceHandleFromMigDeviceHandle() (Device, Return) {
return c.Info.Device, MockReturn(SUCCESS)
}
// IsMigDeviceHandle always reports true: a mock compute instance doubles as a
// MIG device handle.
func (c *MockA100ComputeInstance) IsMigDeviceHandle() (bool, Return) {
return true, MockReturn(SUCCESS)
}
// GetComputeInstanceId returns the ID recorded in the compute instance's info.
func (c *MockA100ComputeInstance) GetComputeInstanceId() (int, Return) {
return int(c.Info.Id), MockReturn(SUCCESS)
}
// GetGPUInstanceId returns the ID of the GPU instance that owns this compute
// instance, as reported by the parent's GetInfo. If GetInfo does not succeed,
// its return code is passed on (re-wrapped as a MockReturn) with an ID of 0.
func (c *MockA100ComputeInstance) GetGPUInstanceId() (int, Return) {
	parentInfo, ret := c.Info.GpuInstance.GetInfo()
	if code := ret.Value(); code != SUCCESS {
		return 0, MockReturn(code)
	}
	return int(parentInfo.Id), MockReturn(SUCCESS)
}

188
internal/nvml/nvml.go Normal file
View File

@@ -0,0 +1,188 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// nvmlLib implements Interface by forwarding to the package-level functions of
// the go-nvml bindings.
type nvmlLib struct{}
// nvmlDevice adapts a go-nvml Device to the local Device interface.
type nvmlDevice nvml.Device
// nvmlGpuInstance adapts a go-nvml GpuInstance to the local GpuInstance interface.
type nvmlGpuInstance nvml.GpuInstance
// nvmlComputeInstance adapts a go-nvml ComputeInstance to the local ComputeInstance interface.
type nvmlComputeInstance nvml.ComputeInstance
// Compile-time checks that the wrapper types satisfy the local interfaces.
var _ Interface = (*nvmlLib)(nil)
var _ Device = (*nvmlDevice)(nil)
var _ GpuInstance = (*nvmlGpuInstance)(nil)
var _ ComputeInstance = (*nvmlComputeInstance)(nil)
// New returns an Interface backed by the go-nvml bindings.
func New() Interface {
return &nvmlLib{}
}
// Init calls nvml.Init and wraps its return code.
func (n *nvmlLib) Init() Return {
return nvmlReturn(nvml.Init())
}
// Shutdown calls nvml.Shutdown and wraps its return code.
func (n *nvmlLib) Shutdown() Return {
return nvmlReturn(nvml.Shutdown())
}
// DeviceGetCount forwards to nvml.DeviceGetCount and wraps the return code.
func (n *nvmlLib) DeviceGetCount() (int, Return) {
c, r := nvml.DeviceGetCount()
return c, nvmlReturn(r)
}
// DeviceGetHandleByIndex forwards to nvml.DeviceGetHandleByIndex and wraps both
// the device and the return code. The device is wrapped even when the call
// fails, so callers must check the Return before using it.
func (n *nvmlLib) DeviceGetHandleByIndex(index int) (Device, Return) {
d, r := nvml.DeviceGetHandleByIndex(index)
return nvmlDevice(d), nvmlReturn(r)
}
// SystemGetDriverVersion forwards to nvml.SystemGetDriverVersion and wraps the return code.
func (n *nvmlLib) SystemGetDriverVersion() (string, Return) {
v, r := nvml.SystemGetDriverVersion()
return v, nvmlReturn(r)
}
// GetIndex forwards to the underlying nvml.Device and wraps the return code.
func (d nvmlDevice) GetIndex() (int, Return) {
i, r := nvml.Device(d).GetIndex()
return i, nvmlReturn(r)
}
// GetPciInfo forwards to the underlying nvml.Device, converting the result to
// the local PciInfo type and wrapping the return code.
func (d nvmlDevice) GetPciInfo() (PciInfo, Return) {
p, r := nvml.Device(d).GetPciInfo()
return PciInfo(p), nvmlReturn(r)
}
// GetUUID forwards to the underlying nvml.Device and wraps the return code.
func (d nvmlDevice) GetUUID() (string, Return) {
u, r := nvml.Device(d).GetUUID()
return u, nvmlReturn(r)
}
// GetMinorNumber forwards to the underlying nvml.Device and wraps the return code.
func (d nvmlDevice) GetMinorNumber() (int, Return) {
m, r := nvml.Device(d).GetMinorNumber()
return m, nvmlReturn(r)
}
// IsMigDeviceHandle forwards to the underlying nvml.Device and wraps the return code.
func (d nvmlDevice) IsMigDeviceHandle() (bool, Return) {
b, r := nvml.Device(d).IsMigDeviceHandle()
return b, nvmlReturn(r)
}
// GetDeviceHandleFromMigDeviceHandle forwards to the underlying nvml.Device and
// wraps both the returned device and the return code.
func (d nvmlDevice) GetDeviceHandleFromMigDeviceHandle() (Device, Return) {
p, r := nvml.Device(d).GetDeviceHandleFromMigDeviceHandle()
return nvmlDevice(p), nvmlReturn(r)
}
// GetGPUInstanceId forwards to nvml.Device.GetGpuInstanceId (note the different
// capitalisation in the go-nvml method name) and wraps the return code.
func (d nvmlDevice) GetGPUInstanceId() (int, Return) {
gi, r := nvml.Device(d).GetGpuInstanceId()
return gi, nvmlReturn(r)
}
// GetComputeInstanceId forwards to the underlying nvml.Device and wraps the return code.
func (d nvmlDevice) GetComputeInstanceId() (int, Return) {
ci, r := nvml.Device(d).GetComputeInstanceId()
return ci, nvmlReturn(r)
}
// SetMigMode forwards to the underlying nvml.Device and wraps both return
// codes, preserving their order.
func (d nvmlDevice) SetMigMode(mode int) (Return, Return) {
r1, r2 := nvml.Device(d).SetMigMode(mode)
return nvmlReturn(r1), nvmlReturn(r2)
}
// GetMigMode forwards to the underlying nvml.Device, returning its two mode
// values unchanged and wrapping the return code.
func (d nvmlDevice) GetMigMode() (int, int, Return) {
s1, s2, r := nvml.Device(d).GetMigMode()
return s1, s2, nvmlReturn(r)
}
// GetGpuInstanceProfileInfo forwards to the underlying nvml.Device, converting
// the result to the local GpuInstanceProfileInfo type and wrapping the return code.
func (d nvmlDevice) GetGpuInstanceProfileInfo(profile int) (GpuInstanceProfileInfo, Return) {
p, r := nvml.Device(d).GetGpuInstanceProfileInfo(profile)
return GpuInstanceProfileInfo(p), nvmlReturn(r)
}
// CreateGpuInstance forwards to the underlying nvml.Device, converting info to
// the go-nvml profile type and wrapping the created instance and return code.
func (d nvmlDevice) CreateGpuInstance(info *GpuInstanceProfileInfo) (GpuInstance, Return) {
gi, r := nvml.Device(d).CreateGpuInstance((*nvml.GpuInstanceProfileInfo)(info))
return nvmlGpuInstance(gi), nvmlReturn(r)
}
// GetGpuInstances forwards to the underlying nvml.Device and wraps every
// returned GPU instance, as well as the return code. The slice is nil when
// go-nvml returns no instances.
func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance, Return) {
	raw, ret := nvml.Device(d).GetGpuInstances((*nvml.GpuInstanceProfileInfo)(info))

	var wrapped []GpuInstance
	for idx := range raw {
		wrapped = append(wrapped, nvmlGpuInstance(raw[idx]))
	}
	return wrapped, nvmlReturn(ret)
}
// GetMaxMigDeviceCount forwards to the underlying nvml.Device and wraps the return code.
func (d nvmlDevice) GetMaxMigDeviceCount() (int, Return) {
m, r := nvml.Device(d).GetMaxMigDeviceCount()
return m, nvmlReturn(r)
}
// GetMigDeviceHandleByIndex forwards to the underlying nvml.Device and wraps
// both the returned MIG device handle and the return code.
func (d nvmlDevice) GetMigDeviceHandleByIndex(Index int) (Device, Return) {
h, r := nvml.Device(d).GetMigDeviceHandleByIndex(Index)
return nvmlDevice(h), nvmlReturn(r)
}
// GetInfo forwards to the underlying nvml.GpuInstance and copies the result
// into the local GpuInstanceInfo type, wrapping the contained device. The
// struct is populated regardless of the return code, so callers must check the
// Return before relying on it.
func (gi nvmlGpuInstance) GetInfo() (GpuInstanceInfo, Return) {
i, r := nvml.GpuInstance(gi).GetInfo()
info := GpuInstanceInfo{
Device: nvmlDevice(i.Device),
Id: i.Id,
ProfileId: i.ProfileId,
Placement: i.Placement,
}
return info, nvmlReturn(r)
}
// GetComputeInstanceProfileInfo forwards to the underlying nvml.GpuInstance,
// converting the result to the local ComputeInstanceProfileInfo type and
// wrapping the return code.
func (gi nvmlGpuInstance) GetComputeInstanceProfileInfo(profile int, engProfile int) (ComputeInstanceProfileInfo, Return) {
p, r := nvml.GpuInstance(gi).GetComputeInstanceProfileInfo(profile, engProfile)
return ComputeInstanceProfileInfo(p), nvmlReturn(r)
}
// CreateComputeInstance forwards to the underlying nvml.GpuInstance, converting
// info to the go-nvml profile type and wrapping the created instance and
// return code.
func (gi nvmlGpuInstance) CreateComputeInstance(info *ComputeInstanceProfileInfo) (ComputeInstance, Return) {
ci, r := nvml.GpuInstance(gi).CreateComputeInstance((*nvml.ComputeInstanceProfileInfo)(info))
return nvmlComputeInstance(ci), nvmlReturn(r)
}
// GetComputeInstances forwards to the underlying nvml.GpuInstance and wraps
// every returned compute instance, as well as the return code. The slice is
// nil when go-nvml returns no instances.
func (gi nvmlGpuInstance) GetComputeInstances(info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) {
	raw, ret := nvml.GpuInstance(gi).GetComputeInstances((*nvml.ComputeInstanceProfileInfo)(info))

	var wrapped []ComputeInstance
	for idx := range raw {
		wrapped = append(wrapped, nvmlComputeInstance(raw[idx]))
	}
	return wrapped, nvmlReturn(ret)
}
// Destroy forwards to the underlying nvml.GpuInstance and wraps the return code.
func (gi nvmlGpuInstance) Destroy() Return {
r := nvml.GpuInstance(gi).Destroy()
return nvmlReturn(r)
}
// GetInfo forwards to the underlying nvml.ComputeInstance and copies the result
// into the local ComputeInstanceInfo type, wrapping the contained device and
// GPU instance. The struct is populated regardless of the return code, so
// callers must check the Return before relying on it.
func (ci nvmlComputeInstance) GetInfo() (ComputeInstanceInfo, Return) {
i, r := nvml.ComputeInstance(ci).GetInfo()
info := ComputeInstanceInfo{
Device: nvmlDevice(i.Device),
GpuInstance: nvmlGpuInstance(i.GpuInstance),
Id: i.Id,
ProfileId: i.ProfileId,
}
return info, nvmlReturn(r)
}
// Destroy forwards to the underlying nvml.ComputeInstance and wraps the return code.
func (ci nvmlComputeInstance) Destroy() Return {
r := nvml.ComputeInstance(ci).Destroy()
return nvmlReturn(r)
}

110
internal/nvml/return.go Normal file
View File

@@ -0,0 +1,110 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// Return abstracts an NVML return code so that real and mock implementations
// can be used interchangeably.
type Return interface {
// Value returns the underlying go-nvml return code.
Value() nvml.Return
// String returns a textual form of the return code.
String() string
// Error returns a textual description of the return code.
Error() string
}
// nvmlReturn is the Return implementation whose text comes from nvml.ErrorString.
type nvmlReturn nvml.Return
// MockReturn is the Return implementation used by the mocks; its text is
// produced locally without calling into go-nvml.
type MockReturn nvml.Return
// Compile-time checks that both types satisfy Return.
var _ Return = (*nvmlReturn)(nil)
var _ Return = (*MockReturn)(nil)
// Value returns r converted back to the go-nvml return type.
func (r nvmlReturn) Value() nvml.Return {
return nvml.Return(r)
}
// String returns the same text as Error.
func (r nvmlReturn) String() string {
return r.Error()
}
// Error returns the description nvml.ErrorString gives for the return code.
func (r nvmlReturn) Error() string {
return nvml.ErrorString(nvml.Return(r))
}
// Value returns r converted back to the go-nvml return type.
func (r MockReturn) Value() nvml.Return {
return nvml.Return(r)
}
// String returns the same text as Error.
func (r MockReturn) String() string {
return r.Error()
}
func (r MockReturn) Error() string {
switch nvml.Return(r) {
case SUCCESS:
return "SUCCESS"
case ERROR_UNINITIALIZED:
return "ERROR_UNINITIALIZED"
case ERROR_INVALID_ARGUMENT:
return "ERROR_INVALID_ARGUMENT"
case ERROR_NOT_SUPPORTED:
return "ERROR_NOT_SUPPORTED"
case ERROR_NO_PERMISSION:
return "ERROR_NO_PERMISSION"
case ERROR_ALREADY_INITIALIZED:
return "ERROR_ALREADY_INITIALIZED"
case ERROR_NOT_FOUND:
return "ERROR_NOT_FOUND"
case ERROR_INSUFFICIENT_SIZE:
return "ERROR_INSUFFICIENT_SIZE"
case ERROR_INSUFFICIENT_POWER:
return "ERROR_INSUFFICIENT_POWER"
case ERROR_DRIVER_NOT_LOADED:
return "ERROR_DRIVER_NOT_LOADED"
case ERROR_TIMEOUT:
return "ERROR_TIMEOUT"
case ERROR_IRQ_ISSUE:
return "ERROR_IRQ_ISSUE"
case ERROR_LIBRARY_NOT_FOUND:
return "ERROR_LIBRARY_NOT_FOUND"
case ERROR_FUNCTION_NOT_FOUND:
return "ERROR_FUNCTION_NOT_FOUND"
case ERROR_CORRUPTED_INFOROM:
return "ERROR_CORRUPTED_INFOROM"
case ERROR_GPU_IS_LOST:
return "ERROR_GPU_IS_LOST"
case ERROR_RESET_REQUIRED:
return "ERROR_RESET_REQUIRED"
case ERROR_OPERATING_SYSTEM:
return "ERROR_OPERATING_SYSTEM"
case ERROR_LIB_RM_VERSION_MISMATCH:
return "ERROR_LIB_RM_VERSION_MISMATCH"
case ERROR_IN_USE:
return "ERROR_IN_USE"
case ERROR_MEMORY:
return "ERROR_MEMORY"
case ERROR_NO_DATA:
return "ERROR_NO_DATA"
case ERROR_VGPU_ECC_NOT_SUPPORTED:
return "ERROR_VGPU_ECC_NOT_SUPPORTED"
case ERROR_INSUFFICIENT_RESOURCES:
return "ERROR_INSUFFICIENT_RESOURCES"
case ERROR_UNKNOWN:
return "ERROR_UNKNOWN"
default:
return "Unknown return value"
}
}

78
internal/nvml/types.go Normal file
View File

@@ -0,0 +1,78 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// Interface is the subset of library-level NVML functionality used by this
// project. It is implemented by the go-nvml backed wrapper and by the mocks.
type Interface interface {
Init() Return
Shutdown() Return
DeviceGetCount() (int, Return)
DeviceGetHandleByIndex(Index int) (Device, Return)
SystemGetDriverVersion() (string, Return)
}
// Device is the subset of NVML device functionality used by this project. It
// covers both full GPUs and MIG device handles; IsMigDeviceHandle tells the two
// apart.
type Device interface {
GetIndex() (int, Return)
GetPciInfo() (PciInfo, Return)
GetUUID() (string, Return)
GetMinorNumber() (int, Return)
IsMigDeviceHandle() (bool, Return)
GetDeviceHandleFromMigDeviceHandle() (Device, Return)
// SetMigMode returns two Return values, mirroring the go-nvml call it wraps.
SetMigMode(Mode int) (Return, Return)
// GetMigMode returns two mode values, mirroring the go-nvml call it wraps.
GetMigMode() (int, int, Return)
GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return)
CreateGpuInstance(Info *GpuInstanceProfileInfo) (GpuInstance, Return)
GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return)
GetMaxMigDeviceCount() (int, Return)
GetMigDeviceHandleByIndex(Index int) (Device, Return)
GetGPUInstanceId() (int, Return)
GetComputeInstanceId() (int, Return)
}
// GpuInstance is the subset of NVML GPU instance functionality used by this
// project.
type GpuInstance interface {
GetInfo() (GpuInstanceInfo, Return)
GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return)
CreateComputeInstance(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return)
GetComputeInstances(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return)
Destroy() Return
}
// ComputeInstance is the subset of NVML compute instance functionality used by
// this project.
type ComputeInstance interface {
GetInfo() (ComputeInstanceInfo, Return)
Destroy() Return
}
// GpuInstanceInfo describes a GPU instance. It mirrors the go-nvml info struct
// but refers to the parent device through the local Device interface.
type GpuInstanceInfo struct {
// Device is the device the GPU instance belongs to.
Device Device
// Id is the ID of the GPU instance.
Id uint32
// ProfileId is the ID of the GPU instance profile the instance was created from.
ProfileId uint32
Placement nvml.GpuInstancePlacement
}
// ComputeInstanceInfo describes a compute instance. It mirrors the go-nvml info
// struct but refers to its parents through the local interfaces.
type ComputeInstanceInfo struct {
// Device is the device the compute instance belongs to.
Device Device
// GpuInstance is the GPU instance the compute instance belongs to.
GpuInstance GpuInstance
// Id is the ID of the compute instance.
Id uint32
// ProfileId is the ID of the compute instance profile the instance was created from.
ProfileId uint32
}
// PciInfo is a local name for the go-nvml PciInfo struct.
type PciInfo nvml.PciInfo
// GpuInstanceProfileInfo is a local name for the go-nvml GpuInstanceProfileInfo struct.
type GpuInstanceProfileInfo nvml.GpuInstanceProfileInfo
// ComputeInstanceProfileInfo is a local name for the go-nvml ComputeInstanceProfileInfo struct.
type ComputeInstanceProfileInfo nvml.ComputeInstanceProfileInfo

116
internal/proc/devices.go Normal file
View File

@@ -0,0 +1,116 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package proc
import (
"bufio"
"fmt"
"io"
"log"
"os"
"strings"
)
const (
// procDevicesPath is the file GetNvidiaDevices reads.
procDevicesPath = "/proc/devices"
// nvidiaDevicePrefix is the name prefix a device must have to be kept by
// processDeviceFile.
nvidiaDevicePrefix = "nvidia"
)
// Device represents a device as specified under /proc/devices
type Device struct {
// Name is the device name, the second column of a /proc/devices line.
Name string
// Major is the major number, the first column of a /proc/devices line.
Major int
}
// NvidiaDevices represents the set of nvidia owned devices under /proc/devices
type NvidiaDevices interface {
// Exists reports whether a device with the given name is in the set.
Exists(name string) bool
// Get returns the device with the given name; the bool is false (and the
// Device is the zero value) when no such device is in the set.
Get(name string) (Device, bool)
}
// nvidiaDevices is the map-backed NvidiaDevices implementation, keyed by
// device name.
type nvidiaDevices map[string]Device
// Compile-time check that nvidiaDevices satisfies NvidiaDevices.
var _ NvidiaDevices = nvidiaDevices(nil)
// Exists checks if a Device with a given name exists or not.
// It is safe to call on a nil map, in which case it reports false.
func (d nvidiaDevices) Exists(name string) bool {
_, exists := d[name]
return exists
}
// Get returns the Device with the given name from NvidiaDevices. The bool is
// false, and the Device is the zero value, when the name is not present.
func (d nvidiaDevices) Get(name string) (Device, bool) {
device, exists := d[name]
return device, exists
}
// add inserts the given devices keyed by name. A later device with the same
// name replaces an earlier one. The receiver must be a non-nil map.
func (d nvidiaDevices) add(devices ...Device) {
for _, device := range devices {
d[device.Name] = device
}
}
// NewMockNvidiaDevices returns NvidiaDevices populated from the devices passed in.
// Called without arguments it returns an empty, non-nil set.
func NewMockNvidiaDevices(devices ...Device) NvidiaDevices {
nvds := make(nvidiaDevices)
nvds.add(devices...)
return nvds
}
// GetNvidiaDevices returns the set of NvidiaDevices on the machine, as listed
// in /proc/devices.
//
// When the file does not exist the function returns (nil, nil): callers must
// be prepared for a nil NvidiaDevices alongside a nil error. Any other failure
// to open the file is returned as an error.
func GetNvidiaDevices() (NvidiaDevices, error) {
	file, err := os.Open(procDevicesPath)
	switch {
	case os.IsNotExist(err):
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("error opening devices file: %v", err)
	}
	defer file.Close()

	return processDeviceFile(file), nil
}
// processDeviceFile reads /proc/devices-formatted content line by line and
// returns the set of devices whose name starts with nvidiaDevicePrefix.
//
// Lines that cannot be parsed are logged and skipped. A read error reported by
// the scanner (for example an I/O failure or an over-long line) is logged as
// well; the devices collected up to that point are still returned because the
// signature has no way to report the error to the caller.
func processDeviceFile(devicesFile io.Reader) NvidiaDevices {
	found := make(nvidiaDevices)

	scanner := bufio.NewScanner(devicesFile)
	for scanner.Scan() {
		name, major, err := processProcDeviceLine(scanner.Text())
		if err != nil {
			log.Printf("Skipping line in devices file: %v", err)
			continue
		}
		if strings.HasPrefix(name, nvidiaDevicePrefix) {
			found.add(Device{Name: name, Major: major})
		}
	}
	// Scan returning false does not distinguish EOF from failure; Err does.
	if err := scanner.Err(); err != nil {
		log.Printf("Error reading devices file: %v", err)
	}
	return found
}
// processProcDeviceLine parses a single "<major> <name>" line from
// /proc/devices, ignoring surrounding whitespace. It returns the device name
// and major number, or ("", 0, error) when the line does not contain both
// fields in that form.
func processProcDeviceLine(line string) (string, int, error) {
	var (
		major int
		name  string
	)
	// Only the number of successfully scanned fields matters; the scan error
	// itself carries no extra information for the caller.
	if fields, _ := fmt.Sscanf(strings.TrimSpace(line), "%d %s", &major, &name); fields != 2 {
		return "", 0, fmt.Errorf("unparsable line: %v", line)
	}
	return name, major, nil
}

View File

@@ -0,0 +1,92 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package proc
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestNvidiaDevices checks that every device passed to NewMockNvidiaDevices
// can be retrieved by name with matching fields, and that an unknown name is
// reported as missing.
func TestNvidiaDevices(t *testing.T) {
	devices := []Device{
		{"nvidia-frontend", 195},
		{"nvidia-nvlink", 234},
		{"nvidia-caps", 235},
		{"nvidia-uvm", 510},
		{"nvidia-nvswitch", 511},
	}

	nvidiaDevices := NewMockNvidiaDevices(devices...)

	for _, d := range devices {
		device, exists := nvidiaDevices.Get(d.Name)
		require.True(t, exists, "Unexpected missing device")
		// require.Equal takes (expected, actual): d is the input we expect back.
		require.Equal(t, d.Name, device.Name, "Unexpected device name")
		require.Equal(t, d.Major, device.Major, "Unexpected device major")
	}

	_, exists := nvidiaDevices.Get("bogus")
	require.False(t, exists, "Unexpected 'bogus' device found")
}
// TestProcessDeviceFile feeds newline-joined lines to processDeviceFile and
// checks that only well-formed lines naming an nvidia-prefixed device end up
// in the resulting set.
func TestProcessDeviceFile(t *testing.T) {
	testCases := []struct {
		lines    []string
		expected []Device
	}{
		{
			lines:    []string{},
			expected: []Device{},
		},
		{
			lines:    []string{"Not a valid line:"},
			expected: []Device{},
		},
		{
			lines:    []string{"195 nvidia-frontend"},
			expected: []Device{{"nvidia-frontend", 195}},
		},
		{
			lines:    []string{"195 nvidia-frontend", "235 nvidia-caps"},
			expected: []Device{{"nvidia-frontend", 195}, {"nvidia-caps", 235}},
		},
		{
			lines:    []string{" 195 nvidia-frontend"},
			expected: []Device{{"nvidia-frontend", 195}},
		},
		{
			lines:    []string{"Not a valid line:", "", "195 nvidia-frontend"},
			expected: []Device{{"nvidia-frontend", 195}},
		},
		{
			lines:    []string{"195 not-nvidia-frontend"},
			expected: []Device{},
		},
	}

	for _, tc := range testCases {
		input := strings.NewReader(strings.Join(tc.lines, "\n"))
		want := NewMockNvidiaDevices(tc.expected...)

		got := processDeviceFile(input)

		require.Equalf(t, want, got, "testCase: %v", tc)
	}
}
// TestProcessDeviceFileLine checks processProcDeviceLine against malformed and
// well-formed lines, including one with leading whitespace.
func TestProcessDeviceFileLine(t *testing.T) {
	testCases := []struct {
		line  string
		name  string
		major int
		err   bool
	}{
		{line: "", err: true},
		{line: "0", err: true},
		{line: "notint nvidia-frontend", err: true},
		{line: "195 nvidia-frontend", name: "nvidia-frontend", major: 195},
		{line: " 195 nvidia-frontend", name: "nvidia-frontend", major: 195},
	}

	for _, tc := range testCases {
		gotName, gotMajor, err := processProcDeviceLine(tc.line)

		require.Equal(t, tc.name, gotName)
		require.Equal(t, tc.major, gotMajor)
		if !tc.err {
			require.NoError(t, err)
			continue
		}
		require.Error(t, err)
	}
}

51
pkg/discover/binaries.go Normal file
View File

@@ -0,0 +1,51 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
)
// NewBinaryMounts creates a discoverer for binaries using the specified root.
// It logs through the logrus standard logger.
func NewBinaryMounts(root string) Discover {
return NewBinaryMountsWithLogger(log.StandardLogger(), root)
}
// NewBinaryMountsWithLogger behaves like NewBinaryMounts but uses the supplied
// logger, both for the discoverer itself and for the path locator that
// resolves the binaries listed in requiredBinaries relative to root.
func NewBinaryMountsWithLogger(logger *log.Logger, root string) Discover {
	locator := lookup.NewPathLocatorWithLogger(logger, root)

	return &mounts{
		logger:   logger,
		lookup:   locator,
		required: requiredBinaries,
	}
}
// requiredBinaries defines a set of binaries and their labels. Each key is a
// label and each value lists the binary names that carry it.
// NOTE(review): the labels presumably correspond to NVIDIA driver capabilities
// ("utility", "compute") — confirm against the mounts discoverer.
var requiredBinaries = map[string][]string{
"utility": {
"nvidia-smi", /* System management interface */
"nvidia-debugdump", /* GPU coredump utility */
"nvidia-persistenced", /* Persistence mode utility */
},
"compute": {
"nvidia-cuda-mps-control", /* Multi process service CLI */
"nvidia-cuda-mps-server", /* Multi process service server */
},
}

View File

@@ -0,0 +1,73 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestBinaries replaces the binary discoverer's locator with a map-backed mock
// and checks that the expected number of mounts is reported and that no
// devices or hooks are. Two binaries resolve to the same "test-duplicates"
// path and nvidia-cuda-mps-server is absent from the lookup map.
func TestBinaries(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()
	d := NewBinaryMountsWithLogger(nullLogger, "")

	// Override lookup for testing
	d.(*mounts).lookup = NewLocatorMockFromMap(map[string]string{
		"nvidia-smi":              "/usr/bin/nvidia-smi",
		"nvidia-persistenced":     "/usr/bin/nvidia-persistenced",
		"nvidia-debugdump":        "test-duplicates",
		"nvidia-cuda-mps-control": "test-duplicates",
	})

	found, err := d.Mounts()
	require.NoError(t, err)

	// Only the number of mounts is compared, not their contents.
	expected := []Mount{
		{Path: "/usr/bin/nvidia-smi"},
		{Path: "/usr/bin/nvidia-persistenced"},
		{Path: "test-duplicates"},
	}
	require.Equal(t, len(expected), len(found))

	devices, err := d.Devices()
	require.NoError(t, err)
	require.Empty(t, devices)

	hooks, err := d.Hooks()
	require.NoError(t, err)
	require.Empty(t, hooks)
}
// TestNewBinariesConstructor checks that NewBinaryMounts returns a *mounts
// with both its logger and its locator set.
func TestNewBinariesConstructor(t *testing.T) {
b := NewBinaryMounts("").(*mounts)
require.NotNil(t, b.logger)
require.NotNil(t, b.lookup)
}

74
pkg/discover/composite.go Normal file
View File

@@ -0,0 +1,74 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import "fmt"
// composite is a discoverer that contains a list of Discoverers. The output of the
// Devices, Mounts, and Hooks functions is the concatenation of the output for each of the
// elements in the list, in list order. The first error from any element aborts the call.
type composite struct {
discoverers []Discover
}
// Compile-time check that *composite satisfies Discover.
var _ Discover = (*composite)(nil)
// Devices returns the devices of all contained discoverers, concatenated in
// list order. If any discoverer fails, nil is returned together with an error
// naming the index of the failing discoverer.
func (d composite) Devices() ([]Device, error) {
	var combined []Device
	for idx := range d.discoverers {
		found, err := d.discoverers[idx].Devices()
		if err != nil {
			return nil, fmt.Errorf("error discovering devices for discoverer %v: %v", idx, err)
		}
		combined = append(combined, found...)
	}
	return combined, nil
}
// Mounts returns the mounts of all contained discoverers, concatenated in list
// order. If any discoverer fails, nil is returned together with an error
// naming the index of the failing discoverer.
func (d composite) Mounts() ([]Mount, error) {
	var combined []Mount
	for idx := range d.discoverers {
		found, err := d.discoverers[idx].Mounts()
		if err != nil {
			return nil, fmt.Errorf("error discovering mounts for discoverer %v: %v", idx, err)
		}
		combined = append(combined, found...)
	}
	return combined, nil
}
// Hooks returns the hooks of all contained discoverers, concatenated in list
// order. If any discoverer fails, nil is returned together with an error
// naming the index of the failing discoverer.
func (d composite) Hooks() ([]Hook, error) {
	var combined []Hook
	for idx := range d.discoverers {
		found, err := d.discoverers[idx].Hooks()
		if err != nil {
			return nil, fmt.Errorf("error discovering hooks for discoverer %v: %v", idx, err)
		}
		combined = append(combined, found...)
	}
	return combined, nil
}
// add appends the given discoverers to the end of the list. It uses a pointer
// receiver because it modifies the composite.
func (d *composite) add(di ...Discover) {
d.discoverers = append(d.discoverers, di...)
}

65
pkg/discover/discover.go Normal file
View File

@@ -0,0 +1,65 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
// DevicePath is a path in /dev associated with a device
type DevicePath string
// ProcPath is a path in /proc associated with a device
type ProcPath string
// PCIBusID is the ID on the PCI bus of a device
type PCIBusID string
// DeviceNode represents a device on the file system
type DeviceNode struct {
// Path is the location of the device node under /dev.
Path DevicePath
// Major and Minor are the device numbers of the node.
Major int
Minor int
}
// Device represents a discovered device including identifiers (Index, UUID, PCI bus ID)
// for selection and paths in /dev and /proc associated with the device
type Device struct {
// Index, UUID and PCIBusID are the identifiers by which the device can be selected.
Index string
UUID string
PCIBusID PCIBusID
// DeviceNodes lists the device nodes in /dev associated with the device.
DeviceNodes []DeviceNode
// ProcPaths lists the paths in /proc associated with the device.
ProcPaths []ProcPath
}
// Mount represents a discovered mount. This includes a set of labels
// for selection and the mount path
type Mount struct {
// Path is the path to be mounted.
Path string
// Labels holds the labels used for selecting the mount.
Labels map[string]string
}
// Hook represents a discovered hook
type Hook struct {
// Path is the path of the hook and Args its arguments.
// NOTE(review): presumably Path is the executable to run — confirm.
Path string
Args []string
// HookName names the hook.
// NOTE(review): presumably the lifecycle stage of the hook — confirm.
HookName string
// Labels holds the labels associated with the hook.
Labels map[string]string
}
// Discover defines an interface for discovering the devices, mounts and hooks
// available on a system
type Discover interface {
Devices() ([]Device, error)
Mounts() ([]Mount, error)
Hooks() ([]Hook, error)
}

View File

@@ -0,0 +1,424 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvml"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/proc"
)
const (
// ControlDeviceUUID is used as the UUID for control devices such as nvidiactl or nvidia-modeset
ControlDeviceUUID = "CONTROL"
// MIGConfigDeviceUUID is used to indicate the MIG config control device
MIGConfigDeviceUUID = "CONFIG"
// MIGMonitorDeviceUUID is used to indicate the MIG monitor control device
MIGMonitorDeviceUUID = "MONITOR"
// NOTE(review): usage of nvidiaGPUDeviceName is not visible here; presumably
// it names an entry in /proc/devices like nvidiaCapsDeviceName — confirm.
nvidiaGPUDeviceName = "nvidia-frontend"
// nvidiaCapsDeviceName is the name under which the nvcaps device is looked
// up in the set of NVIDIA devices to obtain its major number.
nvidiaCapsDeviceName = "nvidia-caps"
// NOTE(review): usage of nvidiaUVMDeviceName is not visible here; presumably
// it names an entry in /proc/devices like nvidiaCapsDeviceName — confirm.
nvidiaUVMDeviceName = "nvidia-uvm"
)
// nvmlDiscover is a Discover implementation that uses NVML to enumerate devices.
type nvmlDiscover struct {
// NOTE(review): None is declared elsewhere; presumably it supplies default
// implementations of the Discover methods not overridden here — confirm.
None
logger *log.Logger
// nvml is the NVML implementation (real or mock) used for device queries.
nvml nvml.Interface
// migCaps holds the MIG capability device nodes keyed by /proc path. It is
// nil or empty when MIG capabilities could not be loaded.
migCaps map[ProcPath]DeviceNode
// nvidiaDevices is the set of NVIDIA devices read from /proc/devices.
nvidiaDevices proc.NvidiaDevices
}
// Compile-time check that *nvmlDiscover satisfies Discover.
var _ Discover = (*nvmlDiscover)(nil)
// NewNVMLDiscover constructs a discoverer that uses NVML to find the devices
// available on a system. It logs through the logrus standard logger.
func NewNVMLDiscover(nvml nvml.Interface) (Discover, error) {
return NewNVMLDiscoverWithLogger(log.StandardLogger(), nvml)
}
// NewNVMLDiscoverWithLogger constructs a discoverer as with NewNVMLDiscover,
// using the specified logger.
//
// An error is returned only if the NVIDIA devices cannot be loaded. A missing
// nvidia-caps device or a failure to load the MIG capabilities is logged as a
// warning and leaves the discoverer without MIG support.
func NewNVMLDiscoverWithLogger(logger *log.Logger, nvml nvml.Interface) (Discover, error) {
	nvidiaDevices, err := proc.GetNvidiaDevices()
	if err != nil {
		return nil, fmt.Errorf("error loading NVIDIA devices: %v", err)
	}

	// migCaps stays nil (MIG support disabled) unless the capabilities load.
	d := nvmlDiscover{
		logger:        logger,
		nvml:          nvml,
		nvidiaDevices: nvidiaDevices,
	}

	if nvcapsDevice, exists := nvidiaDevices.Get(nvidiaCapsDeviceName); exists {
		caps, err := getMigCaps(nvcapsDevice.Major)
		if err == nil {
			d.migCaps = caps
		} else {
			logger.Warnf("Could not load MIG capability devices: %v", err)
		}
	} else {
		logger.Warnf("%v nvcaps device could not be found", nvidiaCapsDeviceName)
	}

	return &d, nil
}
// hasMigSupport reports whether MIG device discovery is supported, which is
// the case when at least one MIG capability is known. It is disabled when,
// for example, no MIG minors file is present.
func (d nvmlDiscover) hasMigSupport() bool {
	return len(d.migCaps) != 0
}
// Devices initializes NVML, enumerates the device handles on the system, and
// returns the corresponding devices followed by the control devices.
//
// When MIG is supported and a GPU has MIG devices, the MIG device handles are
// used in place of the full GPU handle; otherwise the full GPU handle is used.
// NVML is shut down again before returning.
func (d *nvmlDiscover) Devices() ([]Device, error) {
	ret := d.nvml.Init()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error initializing NVML: %v", ret.Error())
	}
	defer d.tryShutdownNVML()

	c, ret := d.nvml.DeviceGetCount()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error getting device count: %v", ret.Error())
	}

	var handles []nvml.Device
	for i := 0; i < c; i++ {
		handle, ret := d.nvml.DeviceGetHandleByIndex(i)
		if ret.Value() != nvml.SUCCESS {
			return nil, fmt.Errorf("error getting device handle for device %v: %v", i, ret.Error())
		}

		if !d.hasMigSupport() {
			handles = append(handles, handle)
			continue
		}

		migHandles, err := getMIGHandlesForDevice(handle)
		if err != nil {
			return nil, fmt.Errorf("error getting MIG handles for device: %v", err)
		}
		// A GPU without MIG devices is represented by its own handle.
		if len(migHandles) == 0 {
			handles = append(handles, handle)
		}
		handles = append(handles, migHandles...)
	}

	return d.devicesByHandle(handles)
}
// devicesByHandle converts the given NVML handles to devices and appends the
// control devices. When MIG is supported, the MIG control devices are appended
// as well, using the largest minor number seen on the first device node of
// each device.
func (d *nvmlDiscover) devicesByHandle(handles []nvml.Device) ([]Device, error) {
	var (
		all      []Device
		maxMinor int
	)
	for i := range handles {
		dev, err := d.deviceFromNVMLHandle(handles[i])
		if err != nil {
			return nil, fmt.Errorf("error constructing device from handle %v: %v", handles[i], err)
		}
		all = append(all, dev)

		if minor := dev.DeviceNodes[0].Minor; minor > maxMinor {
			maxMinor = minor
		}
	}

	control, err := d.getControlDevices()
	if err != nil {
		return nil, fmt.Errorf("error getting control devices: %v", err)
	}
	all = append(all, control...)

	if !d.hasMigSupport() {
		return all, nil
	}

	migControl, err := d.getMigControlDevices(maxMinor)
	if err != nil {
		return nil, fmt.Errorf("error getting MIG control devices: %v", err)
	}
	return append(all, migControl...), nil
}
// deviceFromNVMLHandle builds a Device from an NVML handle. With MIG support
// enabled, MIG device handles are converted as MIG devices; every other handle
// is treated as a full device.
func (d *nvmlDiscover) deviceFromNVMLHandle(handle nvml.Device) (Device, error) {
	if !d.hasMigSupport() {
		return d.deviceFromFullDeviceHandle(handle)
	}

	isMig, ret := handle.IsMigDeviceHandle()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error checking device handle: %v", ret.Error())
	}
	if !isMig {
		return d.deviceFromFullDeviceHandle(handle)
	}
	return d.deviceFromMIGDeviceHandle(handle)
}
// deviceFromFullDeviceHandle builds a Device for a full (non-MIG) GPU handle.
// The device has a single /dev/nvidia<minor> node, whose major number comes
// from the nvidia-frontend entry, and a single proc path derived from the PCI
// bus ID.
func (d *nvmlDiscover) deviceFromFullDeviceHandle(handle nvml.Device) (Device, error) {
	gpuIndex, ret := handle.GetIndex()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting device index: %v", ret.Error())
	}

	gpuUUID, ret := handle.GetUUID()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting device UUID: %v", ret.Error())
	}

	pci, ret := handle.GetPciInfo()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting PCI info: %v", ret.Error())
	}
	pciBusID := NewPCIBusID(pci)

	gpuMinor, ret := handle.GetMinorNumber()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting minor number: %v", ret.Error())
	}

	frontend, found := d.nvidiaDevices.Get(nvidiaGPUDeviceName)
	if !found {
		return Device{}, fmt.Errorf("device for '%v' does not exist", nvidiaGPUDeviceName)
	}

	return Device{
		Index:    fmt.Sprintf("%d", gpuIndex),
		PCIBusID: pciBusID,
		UUID:     gpuUUID,
		DeviceNodes: []DeviceNode{
			{
				Path:  DevicePath(fmt.Sprintf("/dev/nvidia%d", gpuMinor)),
				Major: frontend.Major,
				Minor: gpuMinor,
			},
		},
		ProcPaths: []ProcPath{pciBusID.GetProcPath()},
	}, nil
}
// deviceFromMIGDeviceHandle builds a Device for a MIG device handle. The
// device extends the device nodes and proc paths of its parent GPU with the
// capability device nodes and proc paths of the MIG device, and its index has
// the form "<parent index>:<MIG index>".
func (d *nvmlDiscover) deviceFromMIGDeviceHandle(handle nvml.Device) (Device, error) {
	parentHandle, ret := handle.GetDeviceHandleFromMigDeviceHandle()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting parent device handle: %v", ret.Error())
	}

	parentMinor, ret := parentHandle.GetMinorNumber()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting GPU minor number: %v", ret.Error())
	}

	parent, err := d.deviceFromFullDeviceHandle(parentHandle)
	if err != nil {
		return Device{}, fmt.Errorf("error getting parent device: %v", err)
	}

	migIndex, ret := handle.GetIndex()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting device index: %v", ret.Error())
	}

	migUUID, ret := handle.GetUUID()
	if ret.Value() != nvml.SUCCESS {
		return Device{}, fmt.Errorf("error getting device UUID: %v", ret.Error())
	}

	procPaths, err := getProcPathsForMigDevice(parentMinor, handle)
	if err != nil {
		return Device{}, fmt.Errorf("error getting proc paths for MIG device: %v", err)
	}

	// Every MIG proc path must map to a known capability device node.
	nodes := parent.DeviceNodes
	for _, procPath := range procPaths {
		node, found := d.migCaps[procPath]
		if !found {
			return Device{}, fmt.Errorf("could not determine cap device path for %v", procPath)
		}
		nodes = append(nodes, node)
	}

	return Device{
		Index:       fmt.Sprintf("%s:%d", parent.Index, migIndex),
		UUID:        migUUID,
		DeviceNodes: nodes,
		ProcPaths:   append(parent.ProcPaths, procPaths...),
	}, nil
}
// getControlDevices returns the control devices (nvidia-modeset, nvidiactl,
// nvidia-uvm and nvidia-uvm-tools). A control device whose device name has no
// entry in nvidiaDevices is skipped with a warning. The returned error is
// always nil.
func (d *nvmlDiscover) getControlDevices() ([]Device, error) {
	type controlNode struct {
		name  string
		path  string
		minor int
	}

	// TODO: Where is the best place to find these device Minors programmatically?
	nodes := []controlNode{
		{name: nvidiaGPUDeviceName, path: "/dev/nvidia-modeset", minor: 254},
		{name: nvidiaGPUDeviceName, path: "/dev/nvidiactl", minor: 255},
		{name: nvidiaUVMDeviceName, path: "/dev/nvidia-uvm", minor: 0},
		{name: nvidiaUVMDeviceName, path: "/dev/nvidia-uvm-tools", minor: 1},
	}

	var controlDevices []Device
	for _, node := range nodes {
		entry, found := d.nvidiaDevices.Get(node.name)
		if !found {
			d.logger.Warnf("device name %v not defined; skipping control devices %v", node.name, node.path)
			continue
		}

		controlDevices = append(controlDevices, Device{
			UUID: ControlDeviceUUID,
			DeviceNodes: []DeviceNode{
				{
					Path:  DevicePath(node.path),
					Major: entry.Major,
					Minor: node.minor,
				},
			},
			ProcPaths: []ProcPath{},
		})
	}

	return controlDevices, nil
}
// getMigControlDevices returns the MIG config and MIG monitor control devices,
// always in that order.
//
// numGpus is treated as the largest GPU number, inclusive: each returned
// device lists the per-GPU MIG capability proc path for GPUs 0..numGpus,
// followed by the proc path of the control device itself.
//
// An error is returned if the device node for a control device is not present
// in migCaps.
func (d *nvmlDiscover) getMigControlDevices(numGpus int) ([]Device, error) {
	// A slice (rather than a map) keeps both the order of the returned devices
	// and the error reported when several nodes are missing deterministic.
	targets := []struct {
		uuid     string
		procPath ProcPath
	}{
		{MIGConfigDeviceUUID, ProcPath("/proc/driver/nvidia/capabilities/mig/config")},
		{MIGMonitorDeviceUUID, ProcPath("/proc/driver/nvidia/capabilities/mig/monitor")},
	}

	// The per-GPU proc paths are the same for every target, so build them once.
	var gpuProcPaths []ProcPath
	for gpu := 0; gpu <= numGpus; gpu++ {
		gpuProcPaths = append(gpuProcPaths, ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig", gpu)))
	}

	devices := make([]Device, 0, len(targets))
	for _, target := range targets {
		deviceNode, exists := d.migCaps[target.procPath]
		if !exists {
			return nil, fmt.Errorf("device node for '%v' is undefined", target.procPath)
		}

		// Each device gets its own slice so that the devices do not share a
		// backing array.
		procPaths := make([]ProcPath, 0, len(gpuProcPaths)+1)
		procPaths = append(procPaths, gpuProcPaths...)
		procPaths = append(procPaths, target.procPath)

		devices = append(devices, Device{
			UUID:        target.uuid,
			DeviceNodes: []DeviceNode{deviceNode},
			ProcPaths:   procPaths,
		})
	}

	return devices, nil
}
// getProcPathsForMigDevice returns the capability proc paths for the MIG
// device referred to by handle: the access path of its GPU instance followed
// by the access path of its compute instance. gpu is the number used in the
// "gpu%d" component of the paths.
func getProcPathsForMigDevice(gpu int, handle nvml.Device) ([]ProcPath, error) {
	gi, ret := handle.GetGPUInstanceId()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error getting GPU instance ID: %v", ret.Error())
	}

	ci, ret := handle.GetComputeInstanceId()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error getting compute instance ID: %v", ret.Error())
	}

	procPaths := []ProcPath{
		ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig/gi%d/access", gpu, gi)),
		ProcPath(fmt.Sprintf("/proc/driver/nvidia/capabilities/gpu%d/mig/gi%d/ci%d/access", gpu, gi, ci)),
	}

	return procPaths, nil
}
// getMIGHandlesForDevice returns the handles of the MIG devices on the given
// device. It returns no handles (and no error) when the device does not
// support MIG or when its current MIG mode is disabled. MIG indices that
// report ERROR_NOT_FOUND are skipped.
func getMIGHandlesForDevice(handle nvml.Device) ([]nvml.Device, error) {
	mode, _, ret := handle.GetMigMode()
	switch ret.Value() {
	case nvml.ERROR_NOT_SUPPORTED:
		return nil, nil
	case nvml.SUCCESS:
		// continue below
	default:
		return nil, fmt.Errorf("error getting MIG mode for device: %v", ret.Error())
	}
	if mode == nvml.DEVICE_MIG_DISABLE {
		return nil, nil
	}

	count, ret := handle.GetMaxMigDeviceCount()
	if ret.Value() != nvml.SUCCESS {
		return nil, fmt.Errorf("error getting number of MIG devices: %v", ret.Error())
	}

	var found []nvml.Device
	for idx := 0; idx < count; idx++ {
		migHandle, ret := handle.GetMigDeviceHandleByIndex(idx)
		switch ret.Value() {
		case nvml.ERROR_NOT_FOUND:
			// No MIG device at this index.
		case nvml.SUCCESS:
			found = append(found, migHandle)
		default:
			return nil, fmt.Errorf("error getting MIG device %v: %v", idx, ret.Error())
		}
	}
	return found, nil
}
// tryShutdownNVML shuts down NVML. A failure is only logged as a warning.
func (d *nvmlDiscover) tryShutdownNVML() {
	if ret := d.nvml.Shutdown(); ret.Value() != nvml.SUCCESS {
		d.logger.Warnf("Could not shut down NVML: %v", ret.Error())
	}
}
// NewPCIBusID converts the BusId field of the given PCI info into a PCIBusID.
// The characters are copied up to, but not including, the first NUL.
func NewPCIBusID(p nvml.PciInfo) PCIBusID {
	var sb strings.Builder
	for _, c := range p.BusId {
		ch := byte(c)
		if ch == 0 {
			break
		}
		sb.WriteByte(ch)
	}
	return PCIBusID(sb.String())
}
// GetProcPath returns the path under /proc/driver/nvidia/gpus associated with
// the PCI bus ID. The ID is lower-cased and a leading "0000" is dropped.
func (p PCIBusID) GetProcPath() ProcPath {
	lowered := strings.ToLower(p.String())
	trimmed := strings.TrimPrefix(lowered, "0000")
	return ProcPath(filepath.Join("/proc/driver/nvidia/gpus", trimmed))
}
// String returns the PCI bus ID as a plain string, unchanged.
func (p PCIBusID) String() string {
return string(p)
}

View File

@@ -0,0 +1,44 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvcaps"
)
// getMigCaps loads the MIG minors and returns a mapping of MIG capability proc
// path to device node, using capDeviceMajor as the major number of every node.
func getMigCaps(capDeviceMajor int) (map[ProcPath]DeviceNode, error) {
	minors, err := nvcaps.LoadMigMinors()
	if err == nil {
		return getMigCapsFromMigMinors(minors, capDeviceMajor), nil
	}
	return nil, fmt.Errorf("error loading MIG minors: %v", err)
}
// getMigCapsFromMigMinors converts a mapping of MIG capability to minor number
// into a mapping of the capability's proc path to its device node. Every node
// uses capDeviceMajor as its major number.
func getMigCapsFromMigMinors(migCaps map[nvcaps.MigCap]nvcaps.MigMinor, capDeviceMajor int) map[ProcPath]DeviceNode {
	nodes := make(map[ProcPath]DeviceNode, len(migCaps))
	for migCap, migMinor := range migCaps {
		node := DeviceNode{
			Path:  DevicePath(migMinor.DevicePath()),
			Major: capDeviceMajor,
			Minor: int(migMinor),
		}
		nodes[ProcPath(migCap.ProcPath())] = node
	}
	return nodes
}

View File

@@ -0,0 +1,41 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvcaps"
)
// TestGetMigCaps checks that every MIG minor yields exactly one device node
// and that each node carries the requested major number.
func TestGetMigCaps(t *testing.T) {
	major := 999
	minors := map[nvcaps.MigCap]nvcaps.MigMinor{
		"config":              1,
		"monitor":             2,
		"gpu0/gi0/access":     3,
		"gpu0/gi0/ci0/access": 4,
	}

	caps := getMigCapsFromMigMinors(minors, major)

	require.Len(t, caps, len(minors))
	for procPath := range caps {
		require.Equal(t, major, caps[procPath].Major)
	}
}

View File

@@ -0,0 +1,219 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvml"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/proc"
)
const (
// nvidiaGPUDeviceMajorDefault is the major number given to the mock
// nvidia-frontend device in these tests.
nvidiaGPUDeviceMajorDefault = 195
// nvidiaCapsDeviceMajorDefault is the major number given to the mock
// nvidia-caps device in these tests.
nvidiaCapsDeviceMajorDefault = 235
)
// newTestDiscover returns an nvmlDiscover backed by a mock NVML server with a
// single mock A100 device. It has no MIG capabilities, and only the
// nvidia-frontend device is defined.
func newTestDiscover() nvmlDiscover {
	nullLogger, _ := testlog.NewNullLogger()
	server := nvml.NewMockNVMLServer(nvml.NewMockA100Device(0))
	devices := proc.NewMockNvidiaDevices(
		proc.Device{Name: nvidiaGPUDeviceName, Major: nvidiaGPUDeviceMajorDefault},
	)

	return nvmlDiscover{
		logger:        nullLogger,
		nvml:          server,
		nvidiaDevices: devices,
	}
}
// newTestDiscoverWithMIG returns an nvmlDiscover backed by a mock NVML server
// with one MIG-enabled mock A100 device that has a single GPU instance and a
// single compute instance. The MIG capability map covers that instance pair
// plus the MIG config and monitor control devices.
func newTestDiscoverWithMIG() nvmlDiscover {
device := &nvml.MockA100Device{
Index: 0,
MigMode: nvml.DEVICE_MIG_ENABLE,
GpuInstances: make(map[*nvml.MockA100GpuInstance]struct{}),
GpuInstanceCounter: 0,
}
// Create a single gi and ci on the device
gpuInstanceProfileInfo := &nvml.GpuInstanceProfileInfo{
Id: nvml.GPU_INSTANCE_PROFILE_7_SLICE,
}
// The return codes of the mock calls are deliberately ignored here.
gi, _ := device.CreateGpuInstance(gpuInstanceProfileInfo)
computeInstanceProfileInfo := &nvml.ComputeInstanceProfileInfo{
Id: nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE,
}
_, _ = gi.CreateComputeInstance(computeInstanceProfileInfo)
logger, _ := testlog.NewNullLogger()
return nvmlDiscover{
logger: logger,
nvml: nvml.NewMockNVMLServer(device),
// Capability proc paths mapped to nvidia-caps device nodes; the major is
// the same value as nvidiaCapsDeviceMajorDefault.
migCaps: map[ProcPath]DeviceNode{
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"): {
Path: DevicePath("/dev/nvidia-caps/nvidia-cap3"),
Minor: 3,
Major: 235,
},
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"): {
Path: DevicePath("/dev/nvidia-caps/nvidia-cap4"),
Minor: 4,
Major: 235,
},
ProcPath("/proc/driver/nvidia/capabilities/mig/config"): {
Path: DevicePath("/dev/nvidia-caps/nvidia-cap1"),
Minor: 1,
Major: 235,
},
ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"): {
Path: DevicePath("/dev/nvidia-caps/nvidia-cap2"),
Minor: 2,
Major: 235,
},
},
nvidiaDevices: proc.NewMockNvidiaDevices(
proc.Device{Name: nvidiaGPUDeviceName, Major: nvidiaGPUDeviceMajorDefault},
proc.Device{Name: nvidiaCapsDeviceName, Major: nvidiaCapsDeviceMajorDefault},
),
}
}
// TestDiscoverNvmlDevices checks device discovery without MIG support for a
// single mock GPU.
func TestDiscoverNvmlDevices(t *testing.T) {
d := newTestDiscover()
devices, err := d.Devices()
require.NoError(t, err)
// One GPU plus the two control devices that use the nvidia-frontend entry;
// the nvidia-uvm control devices are skipped because the mock does not
// define nvidia-uvm.
require.Len(t, devices, 3)
device := devices[0]
require.Equal(t, "0", device.Index)
require.Equal(t, "GPU-0", device.UUID)
require.Equal(t, "0000FFFF:FF:FF.F", device.PCIBusID.String())
expectedDeviceNodes := []DeviceNode{
{
Path: DevicePath("/dev/nvidia0"),
Minor: 0,
Major: nvidiaGPUDeviceMajorDefault,
},
}
require.Equal(t, expectedDeviceNodes, device.DeviceNodes)
// The proc path is the lower-cased bus ID with the leading "0000" removed.
expectedProcPaths := []ProcPath{ProcPath("/proc/driver/nvidia/gpus/ffff:ff:ff.f")}
require.Equal(t, expectedProcPaths, device.ProcPaths)
}
// TestDiscoverNvmlMigDevices checks device discovery with MIG support for a
// single mock GPU that has one MIG device.
func TestDiscoverNvmlMigDevices(t *testing.T) {
d := newTestDiscoverWithMIG()
devices, err := d.Devices()
require.NoError(t, err)
// One MIG device, two control devices (nvidia-uvm is not defined in the
// mock), and the MIG config and monitor control devices.
require.Len(t, devices, 5)
mig := devices[0]
require.Equal(t, "0:0", mig.Index)
require.Equal(t, "MIG-0", mig.UUID)
require.Empty(t, mig.PCIBusID)
// The MIG device carries the parent GPU node followed by its capability nodes.
expectedDeviceNodes := []DeviceNode{
{
Path: DevicePath("/dev/nvidia0"),
Minor: 0,
Major: nvidiaGPUDeviceMajorDefault,
},
{
Path: DevicePath("/dev/nvidia-caps/nvidia-cap3"),
Minor: 3,
Major: nvidiaCapsDeviceMajorDefault,
},
{
Path: DevicePath("/dev/nvidia-caps/nvidia-cap4"),
Minor: 4,
Major: nvidiaCapsDeviceMajorDefault,
},
}
require.Equal(t, expectedDeviceNodes, mig.DeviceNodes)
expectedProcPaths := []ProcPath{
ProcPath("/proc/driver/nvidia/gpus/ffff:ff:ff.f"),
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/access"),
ProcPath("/proc/driver/nvidia/capabilities/gpu0/mig/gi0/ci0/access"),
}
require.Equal(t, expectedProcPaths, mig.ProcPaths)
// The MIG control devices are located by UUID rather than by position.
var config *Device
var monitor *Device
for i, d := range devices {
if d.UUID == "CONFIG" {
config = &devices[i]
}
if d.UUID == "MONITOR" {
monitor = &devices[i]
}
}
require.NotNil(t, config)
require.NotNil(t, monitor)
require.Equal(t, "CONFIG", config.UUID)
expectedDeviceNodes = []DeviceNode{
{
Path: DevicePath("/dev/nvidia-caps/nvidia-cap1"),
Minor: 1,
Major: nvidiaCapsDeviceMajorDefault,
},
}
require.Equal(t, expectedDeviceNodes, config.DeviceNodes)
require.Contains(t, config.ProcPaths, ProcPath("/proc/driver/nvidia/capabilities/mig/config"))
// Two proc paths: the per-GPU MIG path for gpu0 and the config path itself.
require.Len(t, config.ProcPaths, 2)
require.Equal(t, "MONITOR", monitor.UUID)
expectedDeviceNodes = []DeviceNode{
{
Path: DevicePath("/dev/nvidia-caps/nvidia-cap2"),
Minor: 2,
Major: nvidiaCapsDeviceMajorDefault,
},
}
require.Equal(t, expectedDeviceNodes, monitor.DeviceNodes)
require.Contains(t, monitor.ProcPaths, ProcPath("/proc/driver/nvidia/capabilities/mig/monitor"))
// Two proc paths: the per-GPU MIG path for gpu0 and the monitor path itself.
require.Len(t, monitor.ProcPaths, 2)
}
// TestPCIBusID checks the string form of a PCI bus ID and the proc path
// derived from it, with and without a leading "0000".
func TestPCIBusID(t *testing.T) {
	cases := []struct {
		busID    string
		procPath ProcPath
	}{
		{busID: "0000FFFF:FF:FF.F", procPath: "/proc/driver/nvidia/gpus/ffff:ff:ff.f"},
		{busID: "FFFFFFFF:FF:FF.F", procPath: "/proc/driver/nvidia/gpus/ffffffff:ff:ff.f"},
	}

	for _, tc := range cases {
		id := PCIBusID(tc.busID)

		require.Equal(t, tc.busID, id.String())
		require.Equal(t, tc.procPath, id.GetProcPath())
	}
}

71
pkg/discover/hooks.go Normal file
View File

@@ -0,0 +1,71 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
log "github.com/sirupsen/logrus"
)
// hooks is a Discover implementation that returns hooks only. It embeds None,
// so Devices and Mounts return empty lists.
type hooks struct {
None
// logger is the logger supplied at construction.
logger *log.Logger
}
// Compile-time check that *hooks implements Discover.
var _ Discover = (*hooks)(nil)
// NewHooks creates a hook discoverer for linux containers that logs through
// the logrus standard logger.
func NewHooks() Discover {
	logger := log.StandardLogger()
	return NewHooksWithLogger(logger)
}
// NewHooksWithLogger creates a discoverer as with NewHooks, using the
// specified logger.
func NewHooksWithLogger(logger *log.Logger) Discover {
	return &hooks{logger: logger}
}
// Hooks returns the discovered hooks, which is currently only the ldconfig
// hook. The returned error is always nil.
func (h hooks) Hooks() ([]Hook, error) {
	return []Hook{newLdconfigHook()}, nil
}
// newLdconfigHook returns the create-container hook that runs /sbin/ldconfig
// with "-r" set to the "@Root.Path@" placeholder.
func newLdconfigHook() Hook {
	const rootPattern = "@Root.Path@"

	// TODO: Testing seems to indicate that the -v flag is required
	args := []string{"-v", "-r", rootPattern}

	// TODO: CreateContainer hooks were only added to a later OCI spec version
	// We will have to find a way to deal with OCI versions before 1.0.2
	labels := map[string]string{"min-oci-version": "1.0.2"}

	return Hook{
		Path:     "/sbin/ldconfig",
		Args:     args,
		HookName: "create-container",
		Labels:   labels,
	}
}

52
pkg/discover/ipc.go Normal file
View File

@@ -0,0 +1,52 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
)
// NewIPCMounts creates a discoverer for IPC sockets located relative to root.
// It logs through the logrus standard logger.
func NewIPCMounts(root string) Discover {
	logger := log.StandardLogger()
	return NewIPCMountsWithLogger(logger, root)
}
// NewIPCMountsWithLogger creates a discoverer as with NewIPCMounts, using the
// specified logger. The IPC sockets in requiredIPCs are located with a file
// locator constructed for root.
func NewIPCMountsWithLogger(logger *log.Logger, root string) Discover {
	locator := lookup.NewFileLocatorWithLogger(logger, root)
	return &mounts{
		logger:   logger,
		lookup:   locator,
		required: requiredIPCs,
	}
}
// requiredIPCs lists the IPC paths that the IPC discoverer tries to locate,
// grouped by the name of the component they belong to.
var requiredIPCs = map[string][]string{
"nvidia-persistenced": {
"/var/run/nvidia-persistenced/socket",
},
"nvidia-fabricmanager": {
"/var/run/nvidia-fabricmanager/socket",
},
// TODO: This can be controlled by the NV_MPS_PIPE_DIR envvar
"nvidia-mps": {
"/tmp/nvidia-mps",
},
}

56
pkg/discover/ipc_test.go Normal file
View File

@@ -0,0 +1,56 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestIPCDiscover checks that the IPC discoverer returns a mount for every
// required IPC path that the locator resolves, and no devices or hooks.
func TestIPCDiscover(t *testing.T) {
// Only two of the three required IPC paths resolve; /tmp/nvidia-mps is
// absent, so the mock locator returns an error for it.
ipcLookup := map[string]string{
"/var/run/nvidia-persistenced/socket": "/var/run/nvidia-persistenced/socket",
"/var/run/nvidia-fabricmanager/socket": "fm-socket",
}
logger, _ := testlog.NewNullLogger()
d := NewIPCMountsWithLogger(logger, "")
// Override lookup for testing
mockLocator := NewLocatorMockFromMap(ipcLookup)
d.(*mounts).lookup = mockLocator
mounts, err := d.Mounts()
require.NoError(t, err)
require.ElementsMatch(t, []Mount{
{Path: "/var/run/nvidia-persistenced/socket", Labels: map[string]string{}},
{Path: "fm-socket", Labels: map[string]string{}}}, mounts)
// The locator is called once per required IPC path, including the one that
// cannot be resolved.
require.Len(t, mockLocator.LocateCalls(), 3)
devices, err := d.Devices()
require.NoError(t, err)
require.Empty(t, devices)
hooks, err := d.Hooks()
require.NoError(t, err)
require.Empty(t, hooks)
}

100
pkg/discover/libraries.go Normal file
View File

@@ -0,0 +1,100 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
)
// NewLibraries constructs a discoverer for libraries located relative to
// root. It logs through the logrus standard logger.
func NewLibraries(root string) (Discover, error) {
	logger := log.StandardLogger()
	return NewLibrariesWithLogger(logger, root)
}
// NewLibrariesWithLogger constructs a discoverer for the libraries in
// requiredLibraries, using the specified logger. An error is returned if the
// library locator for root cannot be constructed.
func NewLibrariesWithLogger(logger *log.Logger, root string) (Discover, error) {
	locator, err := lookup.NewLibraryLocatorWithLogger(logger, root)
	if err != nil {
		return nil, fmt.Errorf("error constructing locator: %v", err)
	}

	return &mounts{
		logger:   logger,
		lookup:   locator,
		required: requiredLibraries,
	}, nil
}
// requiredLibraries defines the set of libraries that the library discoverer
// tries to locate. The map key is the label (capability group) that the
// libraries belong to; entries that are commented out are not looked up.
var requiredLibraries = map[string][]string{
"utility": {
"libnvidia-ml.so", /* Management library */
"libnvidia-cfg.so", /* GPU configuration */
},
"compute": {
"libcuda.so", /* CUDA driver library */
"libnvidia-opencl.so", /* NVIDIA OpenCL ICD */
"libnvidia-ptxjitcompiler.so", /* PTX-SASS JIT compiler (used by libcuda) */
"libnvidia-fatbinaryloader.so", /* fatbin loader (used by libcuda) */
"libnvidia-allocator.so", /* NVIDIA allocator runtime library */
"libnvidia-compiler.so", /* NVVM-PTX compiler for OpenCL (used by libnvidia-opencl) */
},
"video": {
"libvdpau_nvidia.so", /* NVIDIA VDPAU ICD */
"libnvidia-encode.so", /* Video encoder */
"libnvidia-opticalflow.so", /* NVIDIA Opticalflow library */
"libnvcuvid.so", /* Video decoder */
},
"graphics": {
//"libnvidia-egl-wayland.so", /* EGL wayland platform extension (used by libEGL_nvidia) */
"libnvidia-eglcore.so", /* EGL core (used by libGLES*[_nvidia] and libEGL_nvidia) */
"libnvidia-glcore.so", /* OpenGL core (used by libGL or libGLX_nvidia) */
"libnvidia-tls.so", /* Thread local storage (used by libGL or libGLX_nvidia) */
"libnvidia-glsi.so", /* OpenGL system interaction (used by libEGL_nvidia) */
"libnvidia-fbc.so", /* Framebuffer capture */
"libnvidia-ifr.so", /* OpenGL framebuffer capture */
"libnvidia-rtcore.so", /* Optix */
"libnvoptix.so", /* Optix */
},
"glvnd": {
//"libGLX.so", /* GLX ICD loader */
//"libOpenGL.so", /* OpenGL ICD loader */
//"libGLdispatch.so", /* OpenGL dispatch (used by libOpenGL, libEGL and libGLES*) */
"libGLX_nvidia.so", /* OpenGL/GLX ICD */
"libEGL_nvidia.so", /* EGL ICD */
"libGLESv2_nvidia.so", /* OpenGL ES v2 ICD */
"libGLESv1_CM_nvidia.so", /* OpenGL ES v1 common profile ICD */
"libnvidia-glvkspirv.so", /* SPIR-V Lib for Vulkan */
"libnvidia-cbl.so", /* VK_NV_ray_tracing */
},
"compat32": {
"libGL.so", /* OpenGL/GLX legacy _or_ compatibility wrapper (GLVND) */
"libEGL.so", /* EGL legacy _or_ ICD loader (GLVND) */
"libGLESv1_CM.so", /* OpenGL ES v1 common profile legacy _or_ ICD loader (GLVND) */
"libGLESv2.so", /* OpenGL ES v2 legacy _or_ ICD loader (GLVND) */
},
"ngx": {
"libnvidia-ngx.so", /* NGX library */
},
"dxcore": {
"libdxcore.so", /* Core library for dxcore support */
},
}

View File

@@ -0,0 +1,56 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestLibraries checks that the library discoverer returns mounts only for
// required libraries that the locator resolves, and no devices or hooks.
func TestLibraries(t *testing.T) {
// Of these entries only libcuda.so is listed in requiredLibraries; the other
// two are never looked up.
libraryLookup := map[string]string{
"libcuda.so": "/lib/libcuda.so.999.99",
"libversion.so": "/lib/libversion.so.111.11",
"libother.so": "/lib/libother.so.999.99",
}
logger, _ := testlog.NewNullLogger()
d, err := NewLibrariesWithLogger(logger, "")
require.NoError(t, err)
// Override lookup for testing
mockLocator := NewLocatorMockFromMap(libraryLookup)
d.(*mounts).lookup = mockLocator
mounts, err := d.Mounts()
require.NoError(t, err)
require.ElementsMatch(t, []Mount{{Path: "/lib/libcuda.so.999.99", Labels: map[string]string{}}}, mounts)
devices, err := d.Devices()
require.NoError(t, err)
require.Empty(t, devices)
hooks, err := d.Hooks()
require.NoError(t, err)
require.Empty(t, hooks)
}

125
pkg/discover/mounts.go Normal file
View File

@@ -0,0 +1,125 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
)
const (
// capabilityLabel is the label key set by NewMountForCapability.
capabilityLabel = "capability"
// versionLabel is the label key set by NewMountForVersion.
versionLabel = "version"
)
// mounts is a generic discoverer for Mounts. It is customized by specifying the
// required entities as a key-value pair as well as a Locator that is used to
// identify the mounts that are to be included. It embeds None, so Devices and
// Hooks return empty lists.
type mounts struct {
None
// logger receives debug, info and warning messages during discovery.
logger *log.Logger
// lookup resolves a required key to zero or more paths; it must be set
// before Mounts is called.
lookup lookup.Locator
// required maps an identifier to the keys that are passed to lookup.
required map[string][]string
}
// Compile-time check that *mounts implements Discover.
var _ Discover = (*mounts)(nil)
// Mounts returns the mounts located for the required keys, with a single
// entry per path. An error is returned if discovery fails.
func (d mounts) Mounts() ([]Mount, error) {
	unique, err := d.uniqueMounts()
	if err == nil {
		return unique.Slice(), nil
	}
	return nil, fmt.Errorf("error discovering mounts: %v", err)
}
// uniqueMounts locates every required key and returns the resulting mounts
// keyed by path. A key that cannot be located is logged as a warning and
// skipped. An error is returned only when no lookup is defined.
func (d mounts) uniqueMounts() (mountsByPath, error) {
	if d.lookup == nil {
		return nil, fmt.Errorf("no lookup defined")
	}

	unique := mountsByPath{}
	for id, keys := range d.required {
		for _, key := range keys {
			d.logger.Debugf("Locating %v [%v]", key, id)

			paths, err := d.lookup.Locate(key)
			if err != nil {
				d.logger.Warnf("Could not locate %v [%v]: %v", key, id, err)
				continue
			}
			d.logger.Infof("Located %v [%v]: %v", key, id, paths)

			for i := range paths {
				// TODO: We need to add labels
				unique.Put(newMount(paths[i]))
			}
		}
	}
	return unique, nil
}
// mountsByPath is a set of mounts keyed by their path.
type mountsByPath map[string]Mount

// Slice returns the mounts as a slice in no particular order. The result is
// nil when there are no mounts.
func (m mountsByPath) Slice() []Mount {
	var out []Mount
	for path := range m {
		out = append(out, m[path])
	}
	return out
}
// Put adds the mount to the set, keyed by its path. If a mount with the same
// path is already present, the labels of value are merged into the existing
// mount, with the labels of value taking precedence.
//
// A nil receiver map, or an existing mount without labels, is initialized on
// demand rather than panicking on a write to a nil map.
func (m *mountsByPath) Put(value Mount) {
	if *m == nil {
		*m = make(mountsByPath)
	}

	key := value.Path
	mount, exists := (*m)[key]
	if !exists {
		(*m)[key] = value
		return
	}

	// The existing mount may have been created without a label map.
	if mount.Labels == nil && len(value.Labels) > 0 {
		mount.Labels = make(map[string]string, len(value.Labels))
	}
	for k, v := range value.Labels {
		mount.Labels[k] = v
	}
	(*m)[key] = mount
}
// NewMountForCapability creates a mount for path whose only label is the
// capability label set to the given capability.
func NewMountForCapability(path string, capability string) Mount {
	label := []string{capabilityLabel, capability}
	return newMount(path, label...)
}
// NewMountForVersion creates a mount for path whose only label is the version
// label set to the given version.
func NewMountForVersion(path string, version string) Mount {
	label := []string{versionLabel, version}
	return newMount(path, label...)
}
// newMount creates a mount for path. The labels are read as alternating keys
// and values; a trailing key without a value is ignored. The label map of the
// returned mount is never nil.
func newMount(path string, labels ...string) Mount {
	pairs := len(labels) / 2
	byKey := make(map[string]string)
	for n := 0; n < pairs; n++ {
		key, value := labels[2*n], labels[2*n+1]
		byKey[key] = value
	}
	return Mount{Path: path, Labels: byKey}
}

View File

@@ -0,0 +1,53 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
)
// TestMountsReturnsErrorForNoLookup checks that a mounts discoverer without a
// lookup returns an error from Mounts, while Devices and Hooks still succeed
// with empty results.
func TestMountsReturnsErrorForNoLookup(t *testing.T) {
// The zero value has no lookup (and no logger) defined.
d := mounts{}
mounts, err := d.Mounts()
require.Error(t, err)
require.Len(t, mounts, 0)
devices, err := d.Devices()
require.NoError(t, err)
require.Empty(t, devices)
hooks, err := d.Hooks()
require.NoError(t, err)
require.Empty(t, hooks)
}
// NewLocatorMockFromMap returns a locator mock whose Locate function resolves
// keys through lookupMap: a known key yields a single-element slice holding
// the mapped value, and an unknown key yields an error.
func NewLocatorMockFromMap(lookupMap map[string]string) *lookup.LocatorMock {
	locate := func(key string) ([]string, error) {
		if value, ok := lookupMap[key]; ok {
			return []string{value}, nil
		}
		return nil, fmt.Errorf("key %v not found", key)
	}
	return &lookup.LocatorMock{LocateFunc: locate}
}

38
pkg/discover/none.go Normal file
View File

@@ -0,0 +1,38 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
// None is a null discoverer: each of its methods succeeds and returns an
// empty, non-nil list.
type None struct{}

var _ Discover = (*None)(nil)

// Devices returns an empty list of devices
func (None) Devices() ([]Device, error) {
	return make([]Device, 0), nil
}

// Mounts returns an empty list of mounts
func (None) Mounts() ([]Mount, error) {
	return make([]Mount, 0), nil
}

// Hooks returns an empty list of hooks
func (None) Hooks() ([]Hook, error) {
	return make([]Hook, 0), nil
}

39
pkg/discover/none_test.go Normal file
View File

@@ -0,0 +1,39 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestNone checks that the None discoverer returns empty devices, mounts and
// hooks without error.
func TestNone(t *testing.T) {
	none := None{}

	devs, devicesErr := none.Devices()
	require.NoError(t, devicesErr)
	require.Empty(t, devs)

	mnts, mountsErr := none.Mounts()
	require.NoError(t, mountsErr)
	require.Empty(t, mnts)

	hks, hooksErr := none.Hooks()
	require.NoError(t, hooksErr)
	require.Empty(t, hks)
}

View File

@@ -0,0 +1,71 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"fmt"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvml"
)
// nvmlServer is a discoverer that embeds composite and additionally holds the
// logger it was constructed with.
type nvmlServer struct {
logger *log.Logger
composite
}
// Compile-time check that *nvmlServer implements Discover.
var _ Discover = (*nvmlServer)(nil)
// NewNVMLServer constructs a discoverer for server systems using NVML to discover devices.
// It is equivalent to calling NewNVMLServerWithLogger with the logrus standard logger.
func NewNVMLServer(root string) (Discover, error) {
return NewNVMLServerWithLogger(log.StandardLogger(), root)
}
// NewNVMLServerWithLogger constructs a discoverer for server systems using NVML to discover devices with
// the specified logger. The NVML implementation is obtained from nvml.New() and root is
// forwarded unchanged to createNVMLServer.
func NewNVMLServerWithLogger(logger *log.Logger, root string) (Discover, error) {
return createNVMLServer(logger, nvml.New(), root)
}
// createNVMLServer assembles the NVML server discoverer from the specified
// NVML implementation. The resulting discoverer combines, in this order: the
// NVML device discoverer, the library, binary and IPC mount discoverers, and
// the hook discoverer.
//
// An error is returned if the device or library discoverer cannot be
// constructed; the underlying error is wrapped so that callers can inspect it
// with errors.Is / errors.As.
func createNVMLServer(logger *log.Logger, nvml nvml.Interface, root string) (Discover, error) {
	devices, err := NewNVMLDiscoverWithLogger(logger, nvml)
	if err != nil {
		return nil, fmt.Errorf("error constructing NVML device discoverer: %w", err)
	}

	libraries, err := NewLibrariesWithLogger(logger, root)
	if err != nil {
		return nil, fmt.Errorf("error constructing library discoverer: %w", err)
	}

	// Only build the composite once all fallible constructors have succeeded.
	d := nvmlServer{
		logger: logger,
	}
	d.add(
		// Device discovery
		devices,
		// Mounts discovery
		libraries,
		NewBinaryMountsWithLogger(logger, root),
		NewIPCMountsWithLogger(logger, root),
		// Hook discovery
		NewHooksWithLogger(logger),
	)

	return &d, nil
}

View File

@@ -0,0 +1,66 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package discover
import (
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/nvml"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/proc"
)
const (
// testMajor is the major device number reported by mockNvidiaDevices.Get.
testMajor = 999
)
// TestNVMLServerConstructor builds an NVML server discoverer from the mocked
// NVML implementation and checks both its composition and that devices and
// mounts can be discovered through it.
func TestNVMLServerConstructor(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()
	mockNVML := nvml.NewMockNVMLOnLunaServer()

	discoverer, err := createNVMLServer(nullLogger, mockNVML, "")
	require.NoError(t, err)

	server := discoverer.(*nvmlServer)
	// One device discoverer, three mount discoverers and one hook discoverer.
	require.Len(t, server.discoverers, 1+3+1)

	// We need to override the nvidiaDevices member of the nvmlDiscovery
	// TODO: Use a mock instead, or allow for injection into a constructor
	deviceDiscoverer := server.discoverers[0].(*nvmlDiscover)
	deviceDiscoverer.nvidiaDevices = &mockNvidiaDevices{}

	found, err := discoverer.Devices()
	require.NoError(t, err)
	require.NotEmpty(t, found)

	_, err = discoverer.Mounts()
	require.NoError(t, err)
}
// mockNvidiaDevices is a stub implementation of proc.NvidiaDevices for tests.
type mockNvidiaDevices struct{}
// Compile-time check that *mockNvidiaDevices implements proc.NvidiaDevices.
var _ proc.NvidiaDevices = (*mockNvidiaDevices)(nil)
// Get reports every requested name as a known device with major number testMajor.
func (d mockNvidiaDevices) Get(name string) (proc.Device, bool) {
return proc.Device{Name: name, Major: testMajor}, true
}
// Exists always returns false, regardless of the name passed in.
func (d mockNvidiaDevices) Exists(string) bool {
return false
}

129
pkg/ensure/devices.go Normal file
View File

@@ -0,0 +1,129 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package ensure
import (
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/internal/lookup"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// ensureDevices wraps a discoverer and, when its devices are listed, attempts
// to create the device nodes that those devices refer to.
type ensureDevices struct {
logger *log.Logger
// Discover is the wrapped discoverer; methods not overridden by ensureDevices
// are promoted from it.
discover.Discover
// lookup locates the executables that are run to create device nodes.
lookup lookup.Locator
// root is an optional additional root under which device nodes are also
// created; "" and "/" disable this.
root string
}
// NewEnsureDevices creates a discoverer that wraps the specified discoverer and ensures that the
// device nodes for the discoverer are created. If a root is specified, the device nodes
// rooted there are also created. The logrus standard logger is used.
func NewEnsureDevices(d discover.Discover, root string) discover.Discover {
return NewEnsureDevicesWithLogger(log.StandardLogger(), d, root)
}
// NewEnsureDevicesWithLogger wraps the specified discoverer so that the device
// nodes of the devices it reports are created when they are listed. When a
// root is specified, the same nodes are also created below that root. The
// specified logger is used for logging and for the path locator that resolves
// the commands being run.
func NewEnsureDevicesWithLogger(logger *log.Logger, d discover.Discover, root string) discover.Discover {
	locator := lookup.NewPathLocatorWithLogger(logger, root)
	return &ensureDevices{
		logger:   logger,
		Discover: d,
		lookup:   locator,
		root:     root,
	}
}
// Devices returns the devices reported by the wrapped discoverer after
// attempting to create every device node of every device.
//
// Node creation is best effort: failures are logged by deviceNode and do not
// cause Devices to fail. An error is only returned if the wrapped discoverer
// fails, in which case the underlying error is wrapped.
func (d ensureDevices) Devices() ([]discover.Device, error) {
	devices, err := d.Discover.Devices()
	if err != nil {
		return nil, fmt.Errorf("error discovering devices: %w", err)
	}

	for _, di := range devices {
		for _, dn := range di.DeviceNodes {
			// deviceNode logs failures itself; its return value is deliberately
			// ignored so that one bad node does not hide the remaining devices.
			_ = d.deviceNode(dn)
		}
	}

	return devices, nil
}
// deviceNode ensures that the specified device node exists at its own path
// and, if a root other than "" or "/" is configured, also at the same path
// below that root.
//
// Creation is best effort: failures are logged and nil is always returned.
func (d ensureDevices) deviceNode(dn discover.DeviceNode) error {
	if err := d.device(dn.Path, dn.Major, dn.Minor); err != nil {
		d.logger.Errorf("Error creating device node %+v: %v", dn, err)
	}

	if d.root == "" || d.root == "/" {
		return nil
	}

	// Log the rooted node on failure; logging dn here would report the
	// un-rooted path, which is not the path that failed.
	rooted := dn
	rooted.Path = discover.DevicePath(filepath.Join(d.root, string(dn.Path)))
	if err := d.device(rooted.Path, rooted.Major, rooted.Minor); err != nil {
		d.logger.Errorf("Error creating device node %+v: %v", rooted, err)
	}

	return nil
}
// device makes sure that something exists at path. If os.Stat succeeds nothing
// is done; if it fails with anything other than "does not exist" that error is
// logged and returned wrapped. Otherwise a character device with the specified
// major and minor numbers and mode 666 is created by running mknod.
func (d ensureDevices) device(path discover.DevicePath, major int, minor int) error {
	node := string(path)

	// TODO: We may want to check that the device node has the required permissions
	_, statErr := os.Stat(node)
	switch {
	case statErr == nil:
		d.logger.Infof("Device node %v already exists", path)
		return nil
	case !errors.Is(statErr, os.ErrNotExist):
		d.logger.Errorf("Error getting info for device node %v: %v", path, statErr)
		return fmt.Errorf("error getting device node info: %w", statErr)
	}

	// See: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#runfile-verifications
	// TODO: We should use nvidia-modprobe or a tool based off that instead
	mknodArgs := []string{"-m", "666", node, "c", fmt.Sprint(major), fmt.Sprint(minor)}
	return d.run("mknod", mknodArgs...)
}
// run resolves cmd through the configured locator and executes the first
// match with the specified arguments, waiting for it to finish.
//
// An error is returned if the command cannot be located or if it fails to
// run; underlying errors are wrapped so that callers can inspect them with
// errors.Is / errors.As (for example to extract an *exec.ExitError).
func (d ensureDevices) run(cmd string, args ...string) error {
	paths, err := d.lookup.Locate(cmd)
	if err != nil {
		return fmt.Errorf("error finding command %v: %w", cmd, err)
	}
	if len(paths) == 0 {
		return fmt.Errorf("command %v not found in path", cmd)
	}

	path := paths[0]
	// Build the full command line once; it is used for logging and errors only.
	cmdline := append([]string{path}, args...)
	d.logger.Debugf("Running %v", cmdline)

	if err := exec.Command(path, args...).Run(); err != nil {
		return fmt.Errorf("error running %v: %w", cmdline, err)
	}

	return nil
}

22
pkg/ensure/ensure.go Normal file
View File

@@ -0,0 +1,22 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package ensure
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
// Ensure is a named interface type whose underlying type is discover.Discover.
// Note that this is a type definition and not a Go type alias (which would be
// written as "type Ensure = discover.Discover").
type Ensure discover.Discover

41
pkg/filter/all.go Normal file
View File

@@ -0,0 +1,41 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
// all is a composite selector that matches a device only if every one of its
// selectors matches.
type all struct {
	selectors []Selector
}

// All returns a selector that selects a device if EACH of the specified
// selectors selects it. With no selectors, every device is selected.
func All(selectors ...Selector) Selector {
	return &all{selectors: selectors}
}

// Selected reports whether every selector selects the device. Evaluation stops
// at the first selector that rejects the device.
func (s all) Selected(device discover.Device) bool {
	for i := range s.selectors {
		if !s.selectors[i].Selected(device) {
			return false
		}
	}
	return true
}

76
pkg/filter/all_test.go Normal file
View File

@@ -0,0 +1,76 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// TestAll checks the truth table of the All selector, including the case of
// no selectors at all.
func TestAll(t *testing.T) {
	alwaysTrue := &SelectorMock{
		SelectedFunc: func(discover.Device) bool { return true },
	}
	alwaysFalse := &SelectorMock{
		SelectedFunc: func(discover.Device) bool { return false },
	}

	device := discover.Device{}

	// Ensure that the mocks are set up correctly:
	require.True(t, alwaysTrue.Selected(device))
	require.False(t, alwaysFalse.Selected(device))

	cases := []struct {
		selectors []Selector
		expected  bool
	}{
		{selectors: nil, expected: true},
		{selectors: []Selector{alwaysFalse, alwaysFalse}, expected: false},
		{selectors: []Selector{alwaysFalse, alwaysTrue}, expected: false},
		{selectors: []Selector{alwaysTrue, alwaysFalse}, expected: false},
		{selectors: []Selector{alwaysTrue, alwaysTrue}, expected: true},
	}
	for _, tc := range cases {
		require.Equal(t, tc.expected, All(tc.selectors...).Selected(device))
	}
}
// discoverMock is a test discoverer that returns preconfigured devices and
// mounts (and errors). It embeds discover.None, so any method not overridden
// below is promoted from None.
type discoverMock struct {
discover.None
// devices and devicesError are returned verbatim by Devices.
devices []discover.Device
devicesError error
// mounts and mountsError are returned verbatim by Mounts.
mounts []discover.Mount
mountsError error
}
// Compile-time check that *discoverMock implements discover.Discover.
var _ discover.Discover = (*discoverMock)(nil)
// Devices returns the configured devices and devices error.
func (m discoverMock) Devices() ([]discover.Device, error) {
return m.devices, m.devicesError
}
// Mounts returns the configured mounts and mounts error.
func (m discoverMock) Mounts() ([]discover.Mount, error) {
return m.mounts, m.mountsError
}

41
pkg/filter/any.go Normal file
View File

@@ -0,0 +1,41 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
// anySelector is a composite selector that matches a device if at least one
// of its selectors matches.
//
// The type is deliberately not called "any": since Go 1.18 that name is a
// predeclared identifier, and a package-level type of the same name would
// shadow it for every file in this package.
type anySelector struct {
	selectors []Selector
}

// Any returns a selector that selects a device if ANY of the specified
// selectors selects it. With no selectors, no device is selected.
func Any(selectors ...Selector) Selector {
	return &anySelector{selectors: selectors}
}

// Selected reports whether at least one selector selects the device.
// Evaluation stops at the first selector that accepts the device.
func (s anySelector) Selected(device discover.Device) bool {
	for _, si := range s.selectors {
		if si.Selected(device) {
			return true
		}
	}
	return false
}

58
pkg/filter/any_test.go Normal file
View File

@@ -0,0 +1,58 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// TestAny checks the truth table of the Any selector, including the case of
// no selectors at all.
func TestAny(t *testing.T) {
	alwaysTrue := &SelectorMock{
		SelectedFunc: func(discover.Device) bool { return true },
	}
	alwaysFalse := &SelectorMock{
		SelectedFunc: func(discover.Device) bool { return false },
	}

	device := discover.Device{}

	// Ensure that the mocks are set up correctly:
	require.True(t, alwaysTrue.Selected(device))
	require.False(t, alwaysFalse.Selected(device))

	cases := []struct {
		selectors []Selector
		expected  bool
	}{
		{selectors: nil, expected: false},
		{selectors: []Selector{alwaysFalse, alwaysFalse}, expected: false},
		{selectors: []Selector{alwaysFalse, alwaysTrue}, expected: true},
		{selectors: []Selector{alwaysTrue, alwaysFalse}, expected: true},
		{selectors: []Selector{alwaysTrue, alwaysTrue}, expected: true},
	}
	for _, tc := range cases {
		require.Equal(t, tc.expected, Any(tc.selectors...).Selected(device))
	}
}

60
pkg/filter/by_id.go Normal file
View File

@@ -0,0 +1,60 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// devicesByID is a set of device identifiers; only the keys are meaningful.
type devicesByID map[string]struct{}
// Compile-time check that devicesByID implements Selector.
var _ Selector = (*devicesByID)(nil)
// NewDeviceSelector returns a selector that matches devices by identifier.
// The specified IDs are stored as a set, so duplicates are collapsed.
func NewDeviceSelector(ids ...string) Selector {
	set := make(devicesByID, len(ids))
	for i := range ids {
		set[ids[i]] = struct{}{}
	}
	return set
}
// Selected reports whether the device is a member of the set of device IDs.
// The UUID, the index and the PCI bus ID of the device are tried in that
// order, and the first one found in the set selects the device.
func (d devicesByID) Selected(device discover.Device) bool {
	contains := func(id string) bool {
		_, found := d[id]
		return found
	}
	// The short-circuit evaluation keeps the lookups lazy and in order.
	return contains(device.UUID) ||
		contains(device.Index) ||
		contains(device.PCIBusID.String())
}

47
pkg/filter/by_id_test.go Normal file
View File

@@ -0,0 +1,47 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// TestDeviceByID checks that a device selector matches a device by index,
// UUID or PCI bus ID, and only if at least one of these is in the ID set.
func TestDeviceByID(t *testing.T) {
	device := discover.Device{
		Index:    "index",
		UUID:     "uuid",
		PCIBusID: discover.PCIBusID("pcibusid"),
	}

	cases := []struct {
		ids      []string
		expected bool
	}{
		{ids: nil, expected: false},
		{ids: []string{"notindex", "notuuid", "notpcibusid"}, expected: false},
		{ids: []string{"index"}, expected: true},
		{ids: []string{"notindex", "index"}, expected: true},
		{ids: []string{"uuid"}, expected: true},
		{ids: []string{"notuuid", "uuid"}, expected: true},
		{ids: []string{"pcibusid"}, expected: true},
		{ids: []string{"notpcibusid", "pcibusid"}, expected: true},
		{ids: []string{"index", "uuid", "pcibusid"}, expected: true},
	}
	for _, tc := range cases {
		require.Equal(t, tc.expected, NewDeviceSelector(tc.ids...).Selected(device))
	}
}

90
pkg/filter/devices.go Normal file
View File

@@ -0,0 +1,90 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"fmt"
"strings"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// Special values of the visible devices string handled by
// NewSelectDevicesFromWithLogger.
const (
// visibleDevicesAll selects every standard device.
visibleDevicesAll = "all"
// visibleDevicesNone and visibleDevicesVoid (like the empty string) select
// no devices at all.
visibleDevicesNone = "none"
visibleDevicesVoid = "void"
)
// devices is a discoverer that wraps another discoverer and filters the
// devices it reports through a selector.
type devices struct {
// Discover is the wrapped discoverer; methods not overridden by devices are
// promoted from it.
discover.Discover
logger *log.Logger
// selector decides which of the discovered devices are returned by Devices.
selector Selector
}
// Compile-time check that *devices implements discover.Discover.
var _ discover.Discover = (*devices)(nil)
// NewSelectDevicesFrom creates a filter that selects devices based on the value of the
// visible devices string. It is equivalent to calling NewSelectDevicesFromWithLogger
// with the logrus standard logger.
func NewSelectDevicesFrom(d discover.Discover, visibleDevices string, env EnvLookup) discover.Discover {
return NewSelectDevicesFromWithLogger(log.StandardLogger(), d, visibleDevices, env)
}
// NewSelectDevicesFromWithLogger creates a filter over the devices of d using
// the specified logger.
//
// If visibleDevices is empty, "none" or "void", a discover.None is returned.
// Otherwise a device is selected if it is either
//   - a standard device and, unless visibleDevices is "all", one of the
//     comma-separated IDs in visibleDevices, or
//   - a control device whose ID is among the control device IDs derived
//     from env.
func NewSelectDevicesFromWithLogger(logger *log.Logger, d discover.Discover, visibleDevices string, env EnvLookup) discover.Discover {
	switch visibleDevices {
	case "", visibleDevicesNone, visibleDevicesVoid:
		return &discover.None{}
	}

	var visible Selector = StandardDevice()
	if visibleDevices != visibleDevicesAll {
		requested := strings.Split(visibleDevices, ",")
		visible = All(visible, NewDeviceSelector(requested...))
	}

	controlIDs := getControlDeviceIDsFromEnvWithLogger(logger, env)
	control := All(ControlDevice(), NewDeviceSelector(controlIDs...))

	return &devices{
		Discover: d,
		logger:   logger,
		selector: Any(visible, control),
	}
}
// Devices returns the devices of the wrapped discoverer that are accepted by
// the configured selector, in the order in which they were discovered. Each
// selected device is logged at info level.
//
// If the wrapped discoverer fails, its error is returned wrapped so that
// callers can inspect it with errors.Is / errors.As. If no device is selected
// the returned slice is nil.
func (d devices) Devices() ([]discover.Device, error) {
	// Named discovered (not devices) to avoid shadowing the receiver's type.
	discovered, err := d.Discover.Devices()
	if err != nil {
		return nil, fmt.Errorf("error discovering devices: %w", err)
	}

	var selected []discover.Device
	for _, di := range discovered {
		if d.selector.Selected(di) {
			d.logger.Infof("selecting device=%v", di)
			selected = append(selected, di)
		}
	}

	return selected, nil
}

104
pkg/filter/devices_test.go Normal file
View File

@@ -0,0 +1,104 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
log "github.com/sirupsen/logrus"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// TestConstructor checks which discoverer the device filter constructors
// return for the special visible-devices values, and which devices are
// selected for "all" and for explicit ID lists.
func TestConstructor(t *testing.T) {
	nullLogger, _ := testlog.NewNullLogger()

	// standard builds a device whose index, UUID and PCI bus ID all equal id.
	standard := func(id string) discover.Device {
		return discover.Device{
			Index:    id,
			UUID:     id,
			PCIBusID: discover.PCIBusID(id),
		}
	}
	device0 := standard("0")
	device1 := standard("1")
	device2 := standard("2")
	device3 := standard("3")
	controlDevice := discover.Device{UUID: "CONTROL"}

	mockDevices := []discover.Device{device0, device1, device2, device3, controlDevice}
	mock := discoverMock{devices: mockDevices}

	withDefaultLogger, ok := NewSelectDevicesFrom(mock, "all", nil).(*devices)
	require.True(t, ok)
	require.Same(t, log.StandardLogger(), withDefaultLogger.logger)

	// Values that disable device selection yield the None discoverer.
	for _, disabled := range []string{"", "void", "none"} {
		_, ok = NewSelectDevicesFromWithLogger(nullLogger, mock, disabled, nil).(*discover.None)
		require.True(t, ok)
	}

	selectAll, ok := NewSelectDevicesFromWithLogger(nullLogger, mock, "all", nil).(*devices)
	require.True(t, ok)

	devs, err := selectAll.Devices()
	require.NoError(t, err)
	require.ElementsMatch(t, mockDevices, devs)

	filtered, ok := NewSelectDevicesFromWithLogger(nullLogger, mock, "0", nil).(*devices)
	require.True(t, ok)

	devs, err = filtered.Devices()
	require.NoError(t, err)
	require.Len(t, devs, 2)
	require.ElementsMatch(t, devs, []discover.Device{device0, controlDevice})

	filtered, ok = NewSelectDevicesFromWithLogger(nullLogger, mock, "0,2", nil).(*devices)
	require.True(t, ok)

	devs, err = filtered.Devices()
	require.NoError(t, err)
	require.Len(t, devs, 3)
	require.ElementsMatch(t, devs, []discover.Device{device0, device2, controlDevice})
}

26
pkg/filter/filter.go Normal file
View File

@@ -0,0 +1,26 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
//go:generate moq -stub -out filter_mock.go . EnvLookup
// EnvLookup defines an interface that supports the LookupEnv function for getting
// environment variable values. LookupEnv takes the name of a variable and returns
// its value together with a flag indicating whether the variable is set.
// TODO: This belongs in a different package
type EnvLookup interface {
LookupEnv(string) (string, bool)
}

77
pkg/filter/filter_mock.go Normal file
View File

@@ -0,0 +1,77 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package filter
import (
"sync"
)
// Ensure, that EnvLookupMock does implement EnvLookup.
// If this is not the case, regenerate this file with moq.
var _ EnvLookup = &EnvLookupMock{}
// EnvLookupMock is a mock implementation of EnvLookup.
// Every call to LookupEnv is recorded and can be retrieved with LookupEnvCalls.
//
// func TestSomethingThatUsesEnvLookup(t *testing.T) {
//
// // make and configure a mocked EnvLookup
// mockedEnvLookup := &EnvLookupMock{
// LookupEnvFunc: func(s string) (string, bool) {
// panic("mock out the LookupEnv method")
// },
// }
//
// // use mockedEnvLookup in code that requires EnvLookup
// // and then make assertions.
//
// }
type EnvLookupMock struct {
// LookupEnvFunc mocks the LookupEnv method.
LookupEnvFunc func(s string) (string, bool)
// calls tracks calls to the methods.
calls struct {
// LookupEnv holds details about calls to the LookupEnv method.
LookupEnv []struct {
// S is the s argument value.
S string
}
}
// lockLookupEnv guards calls.LookupEnv.
lockLookupEnv sync.RWMutex
}
// LookupEnv calls LookupEnvFunc.
// The call is recorded before it is dispatched. If LookupEnvFunc is nil, the
// zero values ("" and false) are returned instead.
func (mock *EnvLookupMock) LookupEnv(s string) (string, bool) {
callInfo := struct {
S string
}{
S: s,
}
mock.lockLookupEnv.Lock()
mock.calls.LookupEnv = append(mock.calls.LookupEnv, callInfo)
mock.lockLookupEnv.Unlock()
if mock.LookupEnvFunc == nil {
var (
sOut string
bOut bool
)
return sOut, bOut
}
return mock.LookupEnvFunc(s)
}
// LookupEnvCalls gets all the calls that were made to LookupEnv.
// Check the length with:
// len(mockedEnvLookup.LookupEnvCalls())
// The returned slice is the recorded slice itself, not a copy of its elements.
func (mock *EnvLookupMock) LookupEnvCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockLookupEnv.RLock()
calls = mock.calls.LookupEnv
mock.lockLookupEnv.RUnlock()
return calls
}

View File

@@ -0,0 +1,107 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
const (
// devicesAll is the value of a MIG control envvar that selects the
// corresponding MIG control device; any other value is skipped.
devicesAll = "all"
)
// controlDevices pairs a wrapped discoverer with a logger and a selector.
// NOTE(review): no constructor or method for this type is visible in this part
// of the file; confirm that it is still used.
type controlDevices struct {
discover.Discover
logger *log.Logger
selector Selector
}
// Compile-time check that *controlDevices implements discover.Discover.
var _ discover.Discover = (*controlDevices)(nil)
// NewControlDevicesFrom returns a selector for the control device IDs that are
// derived from the specified environment. It is equivalent to calling
// NewControlDevicesFromWithLogger with the logrus standard logger.
func NewControlDevicesFrom(d discover.Discover, env EnvLookup) Selector {
return NewControlDevicesFromWithLogger(log.StandardLogger(), d, env)
}
// NewControlDevicesFromWithLogger returns a selector that matches the control
// device IDs derived from env, using the specified logger while evaluating the
// environment. The discoverer d is accepted but not consulted.
func NewControlDevicesFromWithLogger(logger *log.Logger, d discover.Discover, env EnvLookup) Selector {
	return NewDeviceSelector(getControlDeviceIDsFromEnvWithLogger(logger, env)...)
}
// controlDevice is a selector that matches control devices.
type controlDevice struct{}

// ControlDevice returns a selector for control devices
func ControlDevice() Selector {
	return &controlDevice{}
}

// Selected returns true for control devices and false for standard devices.
// A control device has an empty index, an empty PCI bus ID and a non-empty
// UUID.
func (controlDevice) Selected(device discover.Device) bool {
	hasStandardID := device.Index != "" || device.PCIBusID != ""
	return !hasStandardID && device.UUID != ""
}
// getControlDeviceIDsFromEnvWithLogger returns the IDs of the control devices
// to select: always discover.ControlDeviceUUID, followed by the MIG control
// devices that are enabled through env.
func getControlDeviceIDsFromEnvWithLogger(logger *log.Logger, env EnvLookup) []string {
	ids := []string{discover.ControlDeviceUUID}
	ids = append(ids, getMIGControlDevicesFromEnvWithLogger(logger, env)...)
	return ids
}
// getMIGControlDevicesFromEnvWithLogger returns the UUIDs of the MIG control
// devices that are enabled in env. A MIG control device is enabled if its
// environment variable (NVIDIA_MIG_CONFIG_DEVICES or
// NVIDIA_MIG_MONITOR_DEVICES) is set to "all"; unset variables and any other
// value are skipped.
//
// The variables are evaluated in a fixed order (config, then monitor), so the
// result and the emitted log lines are deterministic. If env is nil, an empty
// slice is returned.
func getMIGControlDevicesFromEnvWithLogger(logger *log.Logger, env EnvLookup) []string {
	if env == nil {
		logger.Debugf("Environment not specified; no MIG Control devices selected")
		return []string{}
	}

	// A slice rather than a map: map iteration order is random, which made the
	// order of the returned UUIDs differ from call to call.
	migControls := []struct {
		uuid   string
		envvar string
	}{
		{uuid: discover.MIGConfigDeviceUUID, envvar: "NVIDIA_MIG_CONFIG_DEVICES"},
		{uuid: discover.MIGMonitorDeviceUUID, envvar: "NVIDIA_MIG_MONITOR_DEVICES"},
	}

	var controlDevices []string
	for _, mc := range migControls {
		config, exists := env.LookupEnv(mc.envvar)
		if !exists {
			logger.Debugf("Envvar %v not set", mc.envvar)
			continue
		}
		if config != devicesAll {
			logger.Debugf("Found %v=%v; Skipping MIG %v devices (%v != %v)", mc.envvar, config, mc.uuid, config, devicesAll)
			continue
		}
		logger.Infof("Found %v=%v; selecting MIG %v devices", mc.envvar, config, mc.uuid)
		controlDevices = append(controlDevices, mc.uuid)
	}

	return controlDevices
}

View File

@@ -0,0 +1,123 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"fmt"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// TestControlDevice checks that only a device with a UUID but without an index
// and a PCI bus ID is treated as a control device.
func TestControlDevice(t *testing.T) {
	selector := ControlDevice()
	busID := discover.PCIBusID("pcibusid")

	cases := []struct {
		device   discover.Device
		expected bool
	}{
		{device: discover.Device{Index: "index", UUID: "uuid", PCIBusID: busID}, expected: false},
		{device: discover.Device{UUID: "uuid", PCIBusID: busID}, expected: false},
		{device: discover.Device{Index: "index", PCIBusID: busID}, expected: false},
		{device: discover.Device{Index: "index", UUID: "uuid"}, expected: false},
		{device: discover.Device{}, expected: false},
		{device: discover.Device{UUID: "uuid"}, expected: true},
	}
	for _, tc := range cases {
		require.Equal(t, tc.expected, selector.Selected(tc.device))
	}
}
// TestGetControlDevicesFromEnv checks which control device IDs are selected for
// different values of the MIG config / monitor environment variables. The
// environment is simulated with a map-backed EnvLookupMock.
func TestGetControlDevicesFromEnv(t *testing.T) {
	scenarios := []struct {
		name        string
		env         map[string]string
		expectedIds []string
	}{
		{name: "nil environment", env: nil, expectedIds: []string{"CONTROL"}},
		{name: "empty environment", env: map[string]string{}, expectedIds: []string{"CONTROL"}},
		{name: "MIG monitor blank", env: map[string]string{"NVIDIA_MIG_MONITOR_DEVICES": ""}, expectedIds: []string{"CONTROL"}},
		{name: "MIG config blank", env: map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": ""}, expectedIds: []string{"CONTROL"}},
		{name: "MIG monitor not all", env: map[string]string{"NVIDIA_MIG_MONITOR_DEVICES": "not-all"}, expectedIds: []string{"CONTROL"}},
		{name: "MIG config not all", env: map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": "not-all"}, expectedIds: []string{"CONTROL"}},
		{name: "MIG monitor all", env: map[string]string{"NVIDIA_MIG_MONITOR_DEVICES": "all"}, expectedIds: []string{"CONTROL", "MONITOR"}},
		{name: "MIG config all", env: map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": "all"}, expectedIds: []string{"CONTROL", "CONFIG"}},
		{
			name:        "MIG config and monitor all",
			env:         map[string]string{"NVIDIA_MIG_CONFIG_DEVICES": "all", "NVIDIA_MIG_MONITOR_DEVICES": "all"},
			expectedIds: []string{"CONTROL", "CONFIG", "MONITOR"},
		},
	}

	for idx, scenario := range scenarios {
		scenario := scenario
		nullLogger, _ := testlog.NewNullLogger()
		// Lookups are answered from the scenario's map; reading a nil map is safe.
		lookup := &EnvLookupMock{
			LookupEnvFunc: func(key string) (string, bool) {
				value, exists := scenario.env[key]
				return value, exists
			},
		}

		t.Run(fmt.Sprintf("%d: %s", idx, scenario.name), func(t *testing.T) {
			got := getControlDeviceIDsFromEnvWithLogger(nullLogger, lookup)
			require.ElementsMatch(t, scenario.expectedIds, got)
		})
	}
}

View File

@@ -0,0 +1,41 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
// standardDevice is the stateless Selector implementation for regular devices.
type standardDevice struct{}

// StandardDevice returns a selector for regular (non-control) devices
func StandardDevice() Selector {
	return new(standardDevice)
}
// Selected returns true for a standard device and false for control devices. A regular device
// is expected to have an index, a UUID, and a PCI bus ID; if any of the three is empty the
// device is not selected.
func (s standardDevice) Selected(device discover.Device) bool {
	return device.Index != "" && device.PCIBusID != "" && device.UUID != ""
}

View File

@@ -0,0 +1,52 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import (
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
)
// TestStandardDevice checks that the standard-device selector only accepts a device
// with an index, a UUID, and a PCI bus ID all set.
func TestStandardDevice(t *testing.T) {
	selector := StandardDevice()
	busID := discover.PCIBusID("pcibusid")

	cases := []struct {
		device   discover.Device
		selected bool
	}{
		{device: discover.Device{Index: "index", UUID: "uuid", PCIBusID: busID}, selected: true},
		{device: discover.Device{UUID: "uuid", PCIBusID: busID}, selected: false},
		{device: discover.Device{Index: "index", PCIBusID: busID}, selected: false},
		{device: discover.Device{Index: "index", UUID: "uuid"}, selected: false},
		{device: discover.Device{}, selected: false},
	}

	for _, c := range cases {
		require.Equal(t, c.selected, selector.Selected(c.device))
	}
}

27
pkg/filter/selector.go Normal file
View File

@@ -0,0 +1,27 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package filter
import "gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
//go:generate moq -stub -out selector_mock.go . Selector
// Selector defines an interface for determining whether a specified Device is selected
// by a particular configuration.
type Selector interface {
// Selected reports whether the given device is selected.
Selected(discover.Device) bool
}

View File

@@ -0,0 +1,77 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package filter
import (
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
"sync"
)
// Ensure, that SelectorMock does implement Selector.
// If this is not the case, regenerate this file with moq.
var _ Selector = &SelectorMock{}
// SelectorMock is a mock implementation of Selector.
//
// func TestSomethingThatUsesSelector(t *testing.T) {
//
// // make and configure a mocked Selector
// mockedSelector := &SelectorMock{
// SelectedFunc: func(device discover.Device) bool {
// panic("mock out the Selected method")
// },
// }
//
// // use mockedSelector in code that requires Selector
// // and then make assertions.
//
// }
type SelectorMock struct {
// SelectedFunc mocks the Selected method.
// When it is nil, Selected returns false (the zero value).
SelectedFunc func(device discover.Device) bool
// calls tracks calls to the methods.
calls struct {
// Selected holds details about calls to the Selected method.
Selected []struct {
// Device is the device argument value.
Device discover.Device
}
}
// lockSelected guards calls.Selected.
lockSelected sync.RWMutex
}
// Selected calls SelectedFunc.
// Each invocation is first recorded in mock.calls.Selected while holding the write lock.
func (mock *SelectorMock) Selected(device discover.Device) bool {
callInfo := struct {
Device discover.Device
}{
Device: device,
}
mock.lockSelected.Lock()
mock.calls.Selected = append(mock.calls.Selected, callInfo)
mock.lockSelected.Unlock()
if mock.SelectedFunc == nil {
// No stub configured: return the zero value (false).
var (
bOut bool
)
return bOut
}
return mock.SelectedFunc(device)
}
// SelectedCalls gets all the calls that were made to Selected.
// Check the length with:
// len(mockedSelector.SelectedCalls())
func (mock *SelectorMock) SelectedCalls() []struct {
Device discover.Device
} {
var calls []struct {
Device discover.Device
}
// The recorded slice is read under the read lock; the elements themselves are not copied.
mock.lockSelected.RLock()
calls = mock.calls.Selected
mock.lockSelected.RUnlock()
return calls
}

118
pkg/modify/device.go Normal file
View File

@@ -0,0 +1,118 @@
package modify
import (
"fmt"
"os"
"github.com/opencontainers/runtime-spec/specs-go"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
)
// Device wraps a discover.Device (embedded) together with a logger so that a Modify
// method can be attached to it. It is a distinct struct type, not a Go type alias.
type Device struct {
logger *log.Logger
discover.Device
}
// ProcMount wraps a discover.ProcPath (embedded) together with a logger so that a Modify
// method can be attached to the proc paths associated with devices.
type ProcMount struct {
logger *log.Logger
discover.ProcPath
}
// Compile-time checks that Device and ProcMount implement Modifier.
var _ Modifier = (*Device)(nil)
var _ Modifier = (*ProcMount)(nil)
// Modify applies the modifications required by a Device to the specified OCI specification.
// All device nodes are injected first, followed by the proc paths; processing stops at the
// first failure, which is returned with the offending node or path in the message.
func (d Device) Modify(spec oci.Spec) error {
	for _, node := range d.DeviceNodes {
		nodeModifier := deviceNode{logger: d.logger, DeviceNode: node}
		if err := nodeModifier.Modify(spec); err != nil {
			return fmt.Errorf("could not inject device node %v: %v", node, err)
		}
	}

	for _, procPath := range d.ProcPaths {
		pathModifier := ProcMount{logger: d.logger, ProcPath: procPath}
		if err := pathModifier.Modify(spec); err != nil {
			return fmt.Errorf("could not inject proc path %v: %v", procPath, err)
		}
	}

	return nil
}
// deviceNode wraps a discover.DeviceNode (embedded) together with a logger so that it
// can be applied to an OCI specification through its Modify method.
type deviceNode struct {
logger *log.Logger
discover.DeviceNode
}
// Modify applies the device node to the specified OCI specification by handing
// specModifier to the spec's Modify method.
func (d deviceNode) Modify(spec oci.Spec) error {
	apply := d.specModifier
	return spec.Modify(apply)
}
// specModifier adds the device node to the supplied OCI spec: it appends a character
// device entry to spec.Linux.Devices and a matching "rwm" allow rule to
// spec.Linux.Resources.Devices, creating spec.Linux and spec.Linux.Resources first
// when they are nil. It always returns nil.
func (d deviceNode) specModifier(spec *specs.Spec) error {
	if spec.Linux == nil {
		d.logger.Debugf("Initializing spec.Linux")
		spec.Linux = &specs.Linux{}
	}
	if spec.Linux.Resources == nil {
		d.logger.Debugf("Initializing spec.LinuxResources")
		spec.Linux.Resources = &specs.LinuxResources{}
	}

	// TODO: These need to be configurable
	// 8630 is 020666 in octal — presumably the character-device type bit plus
	// rw-rw-rw- permissions; TODO confirm this is the intended encoding.
	var (
		mode  = os.FileMode(8630)
		uid   = uint32(0)
		gid   = uint32(0)
		major = int64(d.Major)
		minor = int64(d.Minor)
	)

	d.logger.Infof("Adding device %v", d.Path)
	spec.Linux.Devices = append(spec.Linux.Devices, specs.LinuxDevice{
		Path:     string(d.Path),
		Type:     "c",
		Major:    major,
		Minor:    minor,
		FileMode: &mode,
		UID:      &uid,
		GID:      &gid,
	})

	// TODO: We have to handle the case where we are updating the cgroups for multiple devices
	// leading to duplicates in the spec
	spec.Linux.Resources.Devices = append(spec.Linux.Resources.Devices, specs.LinuxDeviceCgroup{
		Allow:  true,
		Type:   "c",
		Major:  &major,
		Minor:  &minor,
		Access: "rwm",
	})

	return nil
}
// Modify applies the modifications required for a ProcMount to the specified OCI
// specification by handing specModifier to the spec's Modify method.
func (m ProcMount) Modify(spec oci.Spec) error {
	apply := m.specModifier
	return spec.Modify(apply)
}
// specModifier appends the proc path to spec.Linux.ReadonlyPaths. spec.Linux is created
// when it is nil, so the modifier does not panic when applied to a spec that has no
// Linux section yet (for example for a device with proc paths but no device nodes).
// It always returns nil.
func (m ProcMount) specModifier(spec *specs.Spec) error {
	if spec.Linux == nil {
		m.logger.Debugf("Initializing spec.Linux")
		spec.Linux = &specs.Linux{}
	}

	m.logger.Infof("Mounting read-only proc path %v", m.ProcPath)
	spec.Linux.ReadonlyPaths = append(spec.Linux.ReadonlyPaths, string(m.ProcPath))

	return nil
}

155
pkg/modify/discover.go Normal file
View File

@@ -0,0 +1,155 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package modify
import (
"fmt"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
)
// discoverModifier is a Modifier that derives its OCI spec modifications from the
// devices, mounts, and hooks reported by a discover.Discover instance.
type discoverModifier struct {
logger *log.Logger
discover discover.Discover
// root is passed on to every Mount modifier that is created.
root string
// bundleDir is passed on to every Hook modifier that is created.
bundleDir string
}
// Compile-time check that discoverModifier implements Modifier.
var _ Modifier = (*discoverModifier)(nil)
// NewModifierFor creates a Modifier that can be used to apply the modifications to an OCI specification
// required by the specified Discover instance. Logging goes to the logrus standard logger.
func NewModifierFor(discover discover.Discover, root string, bundleDir string) Modifier {
	standardLogger := log.StandardLogger()
	return NewModifierWithLoggerFor(standardLogger, discover, root, bundleDir)
}
// NewModifierWithLoggerFor creates a Modifier that can be used to apply the modifications to an OCI specification
// required by the specified Discover instance, logging through the supplied logger.
func NewModifierWithLoggerFor(logger *log.Logger, discover discover.Discover, root string, bundleDir string) Modifier {
	return &discoverModifier{
		logger:    logger,
		discover:  discover,
		root:      root,
		bundleDir: bundleDir,
	}
}
// Modify applies the modifications for the discovered devices, mounts, etc. to the specified
// OCI spec. The modifiers are constructed first and then applied in order; the first
// failure aborts processing and is returned.
func (m discoverModifier) Modify(spec oci.Spec) error {
	m.logger.Infof("Determining required OCI spec modifications")
	pending, err := m.modifiers()
	if err != nil {
		return fmt.Errorf("error constructing modifiers: %v", err)
	}

	m.logger.Infof("Applying %v modifications", len(pending))
	for idx := range pending {
		if applyErr := pending[idx].Modify(spec); applyErr != nil {
			return fmt.Errorf("could not apply modifier %v: %v", pending[idx], applyErr)
		}
	}

	return nil
}
// modifiers gathers the device, mount, and hook modifiers (in that order) into a single
// list. The first collector that fails aborts the construction and its error is returned.
func (m discoverModifier) modifiers() ([]Modifier, error) {
	collectors := []func() ([]Modifier, error){
		m.deviceModifiers,
		m.mountModifiers,
		m.hookModifiers,
	}

	var all []Modifier
	for _, collect := range collectors {
		batch, err := collect()
		if err != nil {
			return nil, err
		}
		all = append(all, batch...)
	}

	return all, nil
}
// deviceModifiers returns one Device modifier for each device reported by the
// Discover instance.
func (m discoverModifier) deviceModifiers() ([]Modifier, error) {
	devices, err := m.discover.Devices()
	if err != nil {
		return nil, fmt.Errorf("error discovering devices: %v", err)
	}

	var result []Modifier
	for _, device := range devices {
		result = append(result, Device{logger: m.logger, Device: device})
	}

	return result, nil
}
// mountModifiers returns one Mount modifier for each mount reported by the Discover
// instance; every modifier carries the configured root.
func (m discoverModifier) mountModifiers() ([]Modifier, error) {
	mounts, err := m.discover.Mounts()
	if err != nil {
		return nil, fmt.Errorf("error discovering mounts: %v", err)
	}

	var result []Modifier
	for _, mount := range mounts {
		result = append(result, Mount{logger: m.logger, Mount: mount, root: m.root})
	}

	return result, nil
}
// hookModifiers returns one Hook modifier for each hook reported by the Discover
// instance; every modifier carries the configured bundle directory.
func (m discoverModifier) hookModifiers() ([]Modifier, error) {
	hooks, err := m.discover.Hooks()
	if err != nil {
		return nil, fmt.Errorf("error discovering hooks: %v", err)
	}

	var result []Modifier
	for _, hook := range hooks {
		result = append(result, Hook{logger: m.logger, Hook: hook, bundleDir: m.bundleDir})
	}

	return result, nil
}

81
pkg/modify/hooks.go Normal file
View File

@@ -0,0 +1,81 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package modify
import (
"fmt"
"path/filepath"
"strings"
"github.com/opencontainers/runtime-spec/specs-go"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
)
// Hook wraps a discover.Hook (embedded) together with a logger and the bundle directory
// so that a Modify method can be attached to it. It is a distinct struct type, not a Go
// type alias.
type Hook struct {
logger *log.Logger
discover.Hook
// bundleDir is joined with a relative spec.Root.Path when resolving hook arguments.
bundleDir string
}
// Compile-time check that Hook implements Modifier.
var _ Modifier = (*Hook)(nil)
// Modify applies the modifications required by a Hook to the specified OCI specification
// by handing specModifier to the spec's Modify method.
func (h Hook) Modify(spec oci.Spec) error {
	apply := h.specModifier
	return spec.Modify(apply)
}
// specModifier adds the hook to the supplied OCI spec. Every occurrence of the
// "@Root.Path@" placeholder in the hook arguments is replaced by the container root path,
// which is made absolute by joining it with the bundle directory when necessary.
// Only "create-container" hooks are supported; any other hook name yields an error.
func (h Hook) specModifier(spec *specs.Spec) error {
	if spec.Hooks == nil {
		h.logger.Debugf("Initializing spec.Hooks")
		spec.Hooks = &specs.Hooks{}
	}

	// TODO: This is duplicated in the hook specification
	const rootPattern = "@Root.Path@"

	// NOTE(review): spec.Root is dereferenced without a nil check — confirm that callers
	// guarantee it is set.
	containerRoot := spec.Root.Path
	if !filepath.IsAbs(containerRoot) {
		containerRoot = filepath.Join(h.bundleDir, containerRoot)
	}

	var resolvedArgs []string
	for _, arg := range h.Args {
		// ReplaceAll returns arg unchanged when the placeholder does not occur.
		resolvedArgs = append(resolvedArgs, strings.ReplaceAll(arg, rootPattern, containerRoot))
	}

	hook := specs.Hook{
		Path: h.Path,
		Args: resolvedArgs,
	}
	h.logger.Infof("Adding %v hook %+v", h.HookName, hook)

	if h.HookName != "create-container" {
		return fmt.Errorf("unexpected hook name: %v", h.HookName)
	}
	spec.Hooks.CreateContainer = append(spec.Hooks.CreateContainer, hook)

	return nil
}

28
pkg/modify/modify.go Normal file
View File

@@ -0,0 +1,28 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package modify
import (
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
)
//go:generate moq -stub -out modify_mock.go . Modifier
// Modifier defines an interface for modifying an OCI Specification.
type Modifier interface {
// Modify applies this modifier's changes to the given spec and reports any failure.
Modify(oci.Spec) error
}

77
pkg/modify/modify_mock.go Normal file
View File

@@ -0,0 +1,77 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package modify
import (
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
"sync"
)
// Ensure, that ModifierMock does implement Modifier.
// If this is not the case, regenerate this file with moq.
var _ Modifier = &ModifierMock{}
// ModifierMock is a mock implementation of Modifier.
//
// func TestSomethingThatUsesModifier(t *testing.T) {
//
// // make and configure a mocked Modifier
// mockedModifier := &ModifierMock{
// ModifyFunc: func(spec oci.Spec) error {
// panic("mock out the Modify method")
// },
// }
//
// // use mockedModifier in code that requires Modifier
// // and then make assertions.
//
// }
type ModifierMock struct {
// ModifyFunc mocks the Modify method.
// When it is nil, Modify returns a nil error (the zero value).
ModifyFunc func(spec oci.Spec) error
// calls tracks calls to the methods.
calls struct {
// Modify holds details about calls to the Modify method.
Modify []struct {
// Spec is the spec argument value.
Spec oci.Spec
}
}
// lockModify guards calls.Modify.
lockModify sync.RWMutex
}
// Modify calls ModifyFunc.
// Each invocation is first recorded in mock.calls.Modify while holding the write lock.
func (mock *ModifierMock) Modify(spec oci.Spec) error {
callInfo := struct {
Spec oci.Spec
}{
Spec: spec,
}
mock.lockModify.Lock()
mock.calls.Modify = append(mock.calls.Modify, callInfo)
mock.lockModify.Unlock()
if mock.ModifyFunc == nil {
// No stub configured: return the zero value (a nil error).
var (
errOut error
)
return errOut
}
return mock.ModifyFunc(spec)
}
// ModifyCalls gets all the calls that were made to Modify.
// Check the length with:
// len(mockedModifier.ModifyCalls())
func (mock *ModifierMock) ModifyCalls() []struct {
Spec oci.Spec
} {
var calls []struct {
Spec oci.Spec
}
// The recorded slice is read under the read lock; the elements themselves are not copied.
mock.lockModify.RLock()
calls = mock.calls.Modify
mock.lockModify.RUnlock()
return calls
}

53
pkg/modify/mounts.go Normal file
View File

@@ -0,0 +1,53 @@
package modify
import (
"strings"
"github.com/opencontainers/runtime-spec/specs-go"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/discover"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
)
// Mount wraps a discover.Mount (embedded) together with a logger and a root prefix so
// that a Modify method can be attached to it. It is a distinct struct type, not a Go
// type alias.
type Mount struct {
logger *log.Logger
discover.Mount
// root is the prefix trimmed from the mount's path to obtain the mount destination.
root string
}
// Compile-time check that Mount implements Modifier.
var _ Modifier = (*Mount)(nil)
// Modify applies the modifications required for a Mount to the specified OCI specification
// by handing specModifier to the spec's Modify method.
func (d Mount) Modify(spec oci.Spec) error {
	apply := d.specModifier
	return spec.Modify(apply)
}
// TODO: We need to ensure that we are correctly mounting the proc paths.
// Also, it is not yet clear how this is done, but we will need a new tmpfs mounted at
// /proc/driver/nvidia/ underneath which all of these other mounted directories get put.
// Maybe the MaskedPaths field of the OCI runtime spec can be used for this:
// https://github.com/opencontainers/runtime-spec/blob/master/specs-go/config.go#L175
// specs-go/config.go:175
// MaskedPaths []string `json:"maskedPaths,omitempty"`
// Probably, given:
// https://github.com/opencontainers/runtime-spec/blob/master/config-linux.md#masked-paths
// TODO: We can try masking all of /proc/driver/nvidia and then mounting the paths read-only
// specModifier appends a recursive, private bind mount for the Mount to spec.Mounts.
// The mount source is the discovered path; the destination is the same path with the
// configured root prefix removed. A leading "/" is restored when trimming the root
// removed it (for example when root is "/" or ends in "/"), so that the destination is
// always an absolute path inside the container. It always returns nil.
func (d Mount) specModifier(spec *specs.Spec) error {
	source := d.Path
	destination := strings.TrimPrefix(source, d.root)
	if !strings.HasPrefix(destination, "/") {
		// Trimming the root consumed the leading separator; mount destinations must be absolute.
		destination = "/" + destination
	}
	d.logger.Infof("Mounting %v -> %v", source, destination)

	mount := specs.Mount{
		Destination: destination,
		Source:      source,
		Type:        "bind",
		Options: []string{
			"rbind",
			"rprivate",
		},
	}
	spec.Mounts = append(spec.Mounts, mount)

	return nil
}

135
pkg/oci/args.go Normal file
View File

@@ -0,0 +1,135 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"os"
"path/filepath"
"strings"
)
const (
// specFileName is the file name that GetSpecFilePath appends to a bundle directory.
specFileName = "config.json"
)
// GetBundleDir returns the bundle directory or default depending on the
// supplied command line arguments. A bundle directory given through a bundle flag takes
// precedence; when none is given the default bundle directory is used.
func GetBundleDir(args []string) (string, error) {
	fromArgs, err := GetBundleDirFromArgs(args)
	if err != nil {
		return "", fmt.Errorf("error getting bundle dir from args: %v", err)
	}
	if fromArgs != "" {
		return fromArgs, nil
	}

	fallback, err := GetDefaultBundleDir()
	if err != nil {
		return "", fmt.Errorf("error getting default bundle dir: %v", err)
	}
	return fallback, nil
}
// GetBundleDirFromArgs checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc.
// The following are supported:
// --bundle{{SEP}}BUNDLE_PATH
// -bundle{{SEP}}BUNDLE_PATH
// -b{{SEP}}BUNDLE_PATH
// where {{SEP}} is either ' ' or '='
// When the flag occurs more than once the last occurrence wins. An empty string is
// returned when no bundle flag is present, and an error when a bundle flag is the last
// argument and has no value.
func GetBundleDirFromArgs(args []string) (string, error) {
	var bundleDir string

	for idx := 0; idx < len(args); idx++ {
		// Split "flag=value" at the first '='; later '=' characters belong to the value.
		flag, value := args[idx], ""
		eq := strings.Index(flag, "=")
		hasInlineValue := eq >= 0
		if hasInlineValue {
			flag, value = flag[:eq], flag[eq+1:]
		}

		if !IsBundleFlag(flag) {
			continue
		}

		switch {
		case hasInlineValue:
			// The flag has the format --bundle=/path
			bundleDir = value
		case idx+1 < len(args):
			// The flag has the format --bundle /path; consume the value as well.
			idx++
			bundleDir = args[idx]
		default:
			// --bundle / -b was the last element of args
			return "", fmt.Errorf("bundle option requires an argument")
		}
	}

	return bundleDir, nil
}
// GetDefaultBundleDir returns the bundle directory that is to be used if no alternative is
// specified via the command line, for example. This is the current working directory.
func GetDefaultBundleDir() (string, error) {
	cwd, err := os.Getwd()
	if err == nil {
		return cwd, nil
	}
	return "", fmt.Errorf("error getting working directory: %v", err)
}
// GetSpecFilePath returns the expected path to the OCI specification file for the given
// bundle directory, i.e. the bundle directory joined with the spec file name.
func GetSpecFilePath(bundleDir string) string {
	return filepath.Join(bundleDir, specFileName)
}
// IsBundleFlag is a helper function that checks whether the specified argument represents
// a bundle flag. The flag name ("b" or "bundle") must be preceded by one or two dashes,
// so -b, --b, -bundle and --bundle are accepted. Arguments with three or more leading
// dashes (e.g. ---b), without a dash, or with an attached "=value" are not bundle flags.
func IsBundleFlag(arg string) bool {
	name := strings.TrimPrefix(arg, "-")
	if name == arg {
		// No leading dash at all.
		return false
	}
	// A second dash is optional; anything beyond that is not valid flag syntax.
	name = strings.TrimPrefix(name, "-")

	return name == "b" || name == "bundle"
}
// HasCreateSubcommand checks the supplied arguments for a 'create' subcommand.
// The argument that directly follows a bundle flag is treated as the flag's value and
// is never interpreted as the subcommand.
func HasCreateSubcommand(args []string) bool {
	skipNext := false
	for _, arg := range args {
		switch {
		case skipNext:
			// We check for '--bundle create' explicitly to ensure that we
			// don't inadvertently trigger a modification if the bundle directory
			// is specified as `create`
			skipNext = false
		case IsBundleFlag(arg):
			skipNext = true
		case arg == "create":
			return true
		}
	}

	return false
}

198
pkg/oci/args_test.go Normal file
View File

@@ -0,0 +1,198 @@
package oci
import (
"os"
"testing"
"github.com/stretchr/testify/require"
)
// TestGetBundleDir checks the bundle directory resolved from different argument lists:
// the working directory when no bundle flag is given, the flag's value otherwise, and an
// error when a bundle flag has no value.
func TestGetBundleDir(t *testing.T) {
	cwd, err := os.Getwd()
	require.NoError(t, err)

	type expected struct {
		bundle  string
		isError bool
	}

	// Expectations shared by several cases.
	inCwd := expected{bundle: cwd}
	failure := expected{isError: true}
	fooBar := expected{bundle: "/foo/bar"}
	createDir := expected{bundle: "create"}

	testCases := []struct {
		argv     []string
		expected expected
	}{
		{argv: []string{}, expected: inCwd},
		{argv: []string{"create"}, expected: inCwd},
		{argv: []string{"--bundle"}, expected: failure},
		{argv: []string{"-b"}, expected: failure},
		{argv: []string{"--bundle", "/foo/bar"}, expected: fooBar},
		{argv: []string{"--not-bundle", "/foo/bar"}, expected: inCwd},
		{argv: []string{"--"}, expected: inCwd},
		{argv: []string{"-bundle", "/foo/bar"}, expected: fooBar},
		{argv: []string{"--bundle=/foo/bar"}, expected: fooBar},
		{argv: []string{"-b=/foo/bar"}, expected: fooBar},
		{argv: []string{"-b=/foo/=bar"}, expected: expected{bundle: "/foo/=bar"}},
		{argv: []string{"-b", "/foo/bar"}, expected: fooBar},
		{argv: []string{"create", "-b", "/foo/bar"}, expected: fooBar},
		{argv: []string{"-b", "create", "create"}, expected: createDir},
		{argv: []string{"-b=create", "create"}, expected: createDir},
		{argv: []string{"-b", "create"}, expected: createDir},
	}

	for idx, testCase := range testCases {
		got, gotErr := GetBundleDir(testCase.argv)

		if !testCase.expected.isError {
			require.NoErrorf(t, gotErr, "%d: %v", idx, testCase)
		} else {
			require.Errorf(t, gotErr, "%d: %v", idx, testCase)
		}

		require.Equalf(t, testCase.expected.bundle, got, "%d: %v", idx, testCase)
	}
}
// TestGetDefaultBundleDir checks that the default bundle directory is the current
// working directory.
func TestGetDefaultBundleDir(t *testing.T) {
	want, wantErr := os.Getwd()
	require.NoError(t, wantErr)

	got, gotErr := GetDefaultBundleDir()
	require.NoError(t, gotErr)

	require.Equal(t, want, got)
}
// TestGetSpecFilePathAppendsFilename checks that config.json is appended to empty,
// absolute, and relative bundle directories.
func TestGetSpecFilePathAppendsFilename(t *testing.T) {
	testCases := []struct {
		bundleDir string
		expected  string
	}{
		{bundleDir: "", expected: "config.json"},
		{bundleDir: "/not/empty/", expected: "/not/empty/config.json"},
		{bundleDir: "not/absolute", expected: "not/absolute/config.json"},
	}

	for idx, testCase := range testCases {
		got := GetSpecFilePath(testCase.bundleDir)
		require.Equalf(t, testCase.expected, got, "%d: %v", idx, testCase)
	}
}
// TestHasCreateSubcommand checks the detection of the 'create' subcommand, including the
// cases where 'create' is only the value of a bundle flag and where the subcommand
// follows a bundle directory that is itself named 'create'.
func TestHasCreateSubcommand(t *testing.T) {
	testCases := []struct {
		args         []string
		shouldModify bool
	}{
		{
			shouldModify: false,
		},
		{
			args:         []string{"create"},
			shouldModify: true,
		},
		{
			args:         []string{"--bundle=create"},
			shouldModify: false,
		},
		{
			args:         []string{"--bundle", "create"},
			shouldModify: false,
		},
		{
			// The first 'create' is the bundle directory, the second the subcommand.
			args:         []string{"--bundle", "create", "create"},
			shouldModify: true,
		},
	}

	for i, tc := range testCases {
		require.Equalf(t, tc.shouldModify, HasCreateSubcommand(tc.args), "%d: %v", i, tc)
	}
}

25
pkg/oci/runtime.go Normal file
View File

@@ -0,0 +1,25 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
//go:generate moq -stub -out runtime_mock.go . Runtime
// Runtime is an interface for a runtime shim. The Exec method accepts a list
// of command line arguments, and returns an error on failure or nil otherwise.
type Runtime interface {
Exec([]string) error
}

View File

@@ -0,0 +1,61 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable.
// The executable specified is taken from the list of supplied candidates, with the first match
// present in the PATH being selected. The logrus standard logger is used.
func NewLowLevelRuntime(candidates ...string) (Runtime, error) {
	standardLogger := log.StandardLogger()
	return NewLowLevelRuntimeWithLogger(standardLogger, candidates...)
}
// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger.
// An error is returned when none of the candidates can be located.
func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) {
	resolved, err := findRuntime(candidates)
	if err == nil {
		return NewRuntimeForPathWithLogger(logger, resolved)
	}
	return nil, fmt.Errorf("error locating runtime: %v", err)
}
// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH.
// The path reported by exec.LookPath for the first match is returned. An error is returned
// when the candidate list is empty or none of the candidates can be found.
//
// NOTE(review): progress is logged through the package-level logrus functions rather than
// a caller-supplied logger — confirm whether that is intended.
func findRuntime(candidates []string) (string, error) {
	if len(candidates) == 0 {
		return "", fmt.Errorf("at least one runtime candidate must be specified")
	}

	for _, name := range candidates {
		log.Infof("Looking for runtime binary '%v'", name)
		resolved, lookupErr := exec.LookPath(name)
		if lookupErr != nil {
			log.Warnf("Runtime binary '%v' not found: %v", name, lookupErr)
			continue
		}
		log.Infof("Found runtime binary '%v'", resolved)
		return resolved, nil
	}

	return "", fmt.Errorf("no runtime binary found from candidate list: %v", candidates)
}

76
pkg/oci/runtime_mock.go Normal file
View File

@@ -0,0 +1,76 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package oci
import (
"sync"
)
// Ensure, that RuntimeMock does implement Runtime.
// If this is not the case, regenerate this file with moq.
var _ Runtime = &RuntimeMock{}
// RuntimeMock is a mock implementation of Runtime.
//
// func TestSomethingThatUsesRuntime(t *testing.T) {
//
// // make and configure a mocked Runtime
// mockedRuntime := &RuntimeMock{
// ExecFunc: func(strings []string) error {
// panic("mock out the Exec method")
// },
// }
//
// // use mockedRuntime in code that requires Runtime
// // and then make assertions.
//
// }
type RuntimeMock struct {
// ExecFunc mocks the Exec method.
// When it is nil, Exec returns a nil error (the zero value).
ExecFunc func(strings []string) error
// calls tracks calls to the methods.
calls struct {
// Exec holds details about calls to the Exec method.
Exec []struct {
// Strings is the strings argument value.
Strings []string
}
}
// lockExec guards calls.Exec.
lockExec sync.RWMutex
}
// Exec calls ExecFunc.
// Each invocation is first recorded in mock.calls.Exec while holding the write lock.
func (mock *RuntimeMock) Exec(strings []string) error {
callInfo := struct {
Strings []string
}{
Strings: strings,
}
mock.lockExec.Lock()
mock.calls.Exec = append(mock.calls.Exec, callInfo)
mock.lockExec.Unlock()
if mock.ExecFunc == nil {
// No stub configured: return the zero value (a nil error).
var (
errOut error
)
return errOut
}
return mock.ExecFunc(strings)
}
// ExecCalls gets all the calls that were made to Exec.
// Check the length with:
// len(mockedRuntime.ExecCalls())
func (mock *RuntimeMock) ExecCalls() []struct {
Strings []string
} {
var calls []struct {
Strings []string
}
// The recorded slice is read under the read lock; the elements themselves are not copied.
mock.lockExec.RLock()
calls = mock.calls.Exec
mock.lockExec.RUnlock()
return calls
}

70
pkg/oci/runtime_path.go Normal file
View File

@@ -0,0 +1,70 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
)
// pathRuntime wraps the path to a binary and defines the semantics for how to exec into it.
// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the
// Runtime interface.
type pathRuntime struct {
// logger is the logger supplied when the runtime was constructed.
logger *log.Logger
// path is the path of the wrapped binary; Exec passes it as the first argument.
path string
// execRuntime is the Runtime that Exec delegates to with the rewritten arguments.
execRuntime Runtime
}
// Compile-time check that *pathRuntime implements Runtime.
var _ Runtime = (*pathRuntime)(nil)
// NewRuntimeForPath creates a Runtime for the specified path, logging through
// the logrus standard logger.
func NewRuntimeForPath(path string) (Runtime, error) {
	logger := log.StandardLogger()
	return NewRuntimeForPathWithLogger(logger, path)
}
// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path.
//
// The path is checked with os.Stat. An error is returned if the path cannot be
// stat-ed, if it refers to a directory, or if none of its execute permission
// bits are set. A stat failure is wrapped (%w) so that callers can inspect the
// underlying cause with errors.Is / errors.As.
func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) {
	info, err := os.Stat(path)
	if err != nil {
		return nil, fmt.Errorf("invalid path '%v': %w", path, err)
	}
	// 0111 selects the execute bit for user, group, and other.
	if info.IsDir() || info.Mode()&0111 == 0 {
		return nil, fmt.Errorf("specified path '%v' is not an executable file", path)
	}

	shim := pathRuntime{
		logger:      logger,
		path:        path,
		execRuntime: syscallExec{},
	}
	return &shim, nil
}
// Exec execs into the binary at the path stored in the pathRuntime struct.
// The first element of args (if any) is replaced by that path; the remaining
// arguments are forwarded unchanged to the underlying Runtime.
func (s pathRuntime) Exec(args []string) error {
	var rest []string
	if len(args) > 1 {
		rest = args[1:]
	}

	// Build a fresh slice so that the caller's args are never modified.
	argv := append([]string{s.path}, rest...)
	return s.execRuntime.Exec(argv)
}

View File

@@ -0,0 +1,97 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
// TestPathRuntimeConstructor checks that NewRuntimeForPath rejects a missing
// path, a directory, and a non-executable file, and accepts /bin/sh.
func TestPathRuntimeConstructor(t *testing.T) {
	rejected := []string{
		"////an/invalid/path",
		"/tmp",
		"/dev/null",
	}
	for _, p := range rejected {
		r, err := NewRuntimeForPath(p)
		require.Error(t, err)
		require.Nil(t, r)
	}

	r, err := NewRuntimeForPath("/bin/sh")
	require.NoError(t, err)

	f, ok := r.(*pathRuntime)
	require.True(t, ok)
	require.Equal(t, "/bin/sh", f.path)
}
// TestPathRuntimeForwardsArgs checks that pathRuntime.Exec replaces the first
// argument with its own path, forwards the remaining arguments, and returns
// the error reported by the underlying runtime.
func TestPathRuntimeForwardsArgs(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	testCases := []struct {
		execRuntimeError error
		args             []string
	}{
		{},
		{args: []string{"shouldBeReplaced"}},
		{args: []string{"shouldBeReplaced", "arg1"}},
		{execRuntimeError: fmt.Errorf("exec error")},
	}

	for _, tc := range testCases {
		wantErr := tc.execRuntimeError
		mockedRuntime := &RuntimeMock{
			ExecFunc: func([]string) error {
				return wantErr
			},
		}
		r := pathRuntime{logger: logger, path: "runtime", execRuntime: mockedRuntime}

		require.ErrorIs(t, r.Exec(tc.args), wantErr)

		calls := mockedRuntime.ExecCalls()
		require.Len(t, calls, 1)
		forwarded := calls[0].Strings

		// The forwarded arguments always contain at least the runtime path.
		expectedLen := 1
		if len(tc.args) > 0 {
			expectedLen = len(tc.args)
		}
		require.Len(t, forwarded, expectedLen)
		require.Equal(t, "runtime", forwarded[0])
		if expectedLen > 1 {
			require.EqualValues(t, tc.args[1:], forwarded[1:])
		}
	}
}

View File

@@ -0,0 +1,38 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"fmt"
"os"
"syscall"
)
// syscallExec is a stateless Runtime whose Exec method calls syscall.Exec.
type syscallExec struct{}
// Compile-time check that *syscallExec implements Runtime.
var _ Runtime = (*syscallExec)(nil)
// Exec calls syscall.Exec with args[0] as the binary, args as its argument
// vector, and the current process environment.
//
// On success syscall.Exec does not return, so any return from this method is an
// error. An empty args slice is rejected up front instead of panicking on
// args[0]. A failure reported by syscall.Exec is wrapped (%w) so that callers
// can inspect the underlying errno with errors.Is / errors.As.
func (r syscallExec) Exec(args []string) error {
	if len(args) == 0 {
		return fmt.Errorf("could not exec: no arguments specified")
	}

	err := syscall.Exec(args[0], args, os.Environ())
	if err != nil {
		return fmt.Errorf("could not exec '%v': %w", args[0], err)
	}

	// syscall.Exec is not expected to return. This is an error state regardless of whether
	// err is nil or not.
	return fmt.Errorf("unexpected return from exec '%v'", args[0])
}

35
pkg/oci/spec.go Normal file
View File

@@ -0,0 +1,35 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
oci "github.com/opencontainers/runtime-spec/specs-go"
)
// SpecModifier is a function that accepts a pointer to an OCI Spec and returns an
// error. The intention is that the function would modify the spec in-place.
type SpecModifier func(*oci.Spec) error
//go:generate moq -stub -out spec_mock.go . Spec
// Spec defines the operations to be performed on an OCI specification.
type Spec interface {
// Load reads the OCI specification so that the other methods can operate on it.
Load() error
// Flush writes the stored OCI specification back out.
Flush() error
// Modify applies the supplied SpecModifier to the stored OCI specification.
Modify(SpecModifier) error
// LookupEnv returns the value of the named environment variable from the
// specification and whether the variable was present.
LookupEnv(string) (string, bool)
}

153
pkg/oci/spec_file.go Normal file
View File

@@ -0,0 +1,153 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"encoding/json"
"fmt"
"io"
"os"
"strings"
oci "github.com/opencontainers/runtime-spec/specs-go"
)
// fileSpec is a Spec backed by a file on disk.
type fileSpec struct {
// Spec is the in-memory OCI specification; it is nil until a spec has been loaded.
*oci.Spec
// path is the file that Load reads from and Flush writes to.
path string
}
// Compile-time check that *fileSpec implements Spec.
var _ Spec = (*fileSpec)(nil)
// NewSpecFromArgs creates a fileSpec based on the command line arguments passed
// to the application.
//
// The bundle directory is obtained from args via GetBundleDir and the spec file
// path is derived from it via GetSpecFilePath. It returns the (not yet loaded)
// Spec together with the bundle directory. A failure to determine the bundle
// directory is wrapped (%w) so that callers can inspect the underlying cause.
func NewSpecFromArgs(args []string) (Spec, string, error) {
	bundleDir, err := GetBundleDir(args)
	if err != nil {
		return nil, "", fmt.Errorf("error getting bundle directory: %w", err)
	}

	ociSpecPath := GetSpecFilePath(bundleDir)
	ociSpec := NewSpecFromFile(ociSpecPath)

	return ociSpec, bundleDir, nil
}
// NewSpecFromFile returns a Spec that is backed by the file at the given path.
// The returned object can be used to read the OCI spec from the file, modify
// it, and write it back to the same file. Nothing is read at construction time.
func NewSpecFromFile(filepath string) Spec {
	return &fileSpec{path: filepath}
}
// Load reads the contents of an OCI spec from file to be referenced internally.
// The file is opened "read-only" and closed again before Load returns.
//
// A failure to open the file is wrapped (%w) so that callers can test for
// conditions such as os.ErrNotExist with errors.Is.
func (s *fileSpec) Load() error {
	specFile, err := os.Open(s.path)
	if err != nil {
		return fmt.Errorf("error opening OCI specification file: %w", err)
	}
	// The file is only read from, so the result of Close is not checked.
	defer specFile.Close()

	return s.loadFrom(specFile)
}
// loadFrom decodes a JSON-encoded OCI spec from the specified io.Reader and
// stores it in s.Spec. On failure s.Spec is left untouched and the decoder
// error is wrapped (%w) so that callers can inspect it with errors.Is /
// errors.As.
func (s *fileSpec) loadFrom(reader io.Reader) error {
	decoder := json.NewDecoder(reader)

	var spec oci.Spec
	err := decoder.Decode(&spec)
	if err != nil {
		return fmt.Errorf("error reading OCI specification: %w", err)
	}

	s.Spec = &spec
	return nil
}
// Modify runs the given SpecModifier against the stored OCI specification and
// returns its result. An error is returned if no specification is stored.
func (s *fileSpec) Modify(f SpecModifier) error {
	if s.Spec != nil {
		return f(s.Spec)
	}
	return fmt.Errorf("no spec loaded for modification")
}
// Flush writes the stored OCI specification to the filepath specified by the path member.
// The file is truncated upon opening, overwriting any existing contents.
//
// An error is returned if no specification is stored, if the file cannot be
// created, if encoding fails, or if closing the written file fails. Errors
// from the filesystem are wrapped (%w) so that callers can inspect them with
// errors.Is / errors.As.
func (s fileSpec) Flush() (err error) {
	if s.Spec == nil {
		return fmt.Errorf("no OCI specification loaded")
	}

	specFile, err := os.Create(s.path)
	if err != nil {
		return fmt.Errorf("error opening OCI specification file: %w", err)
	}
	// A failed Close on a written file can mean the data never reached disk,
	// so report it unless an earlier error is already being returned.
	defer func() {
		if closeErr := specFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing OCI specification file: %w", closeErr)
		}
	}()

	return s.flushTo(specFile)
}
// flushTo JSON-encodes the stored OCI specification to the specified io.Writer.
// If no specification is stored, nothing is written and nil is returned. An
// encoding or write failure is wrapped (%w) so that callers can inspect the
// underlying cause with errors.Is / errors.As.
func (s fileSpec) flushTo(writer io.Writer) error {
	if s.Spec == nil {
		return nil
	}

	encoder := json.NewEncoder(writer)
	err := encoder.Encode(s.Spec)
	if err != nil {
		return fmt.Errorf("error writing OCI specification: %w", err)
	}
	return nil
}
// LookupEnv is the OCI specification counterpart of os.LookupEnv. It searches
// the process environment of the stored specification for the variable named
// by key. If the variable is present, its value (which may be empty) is
// returned and the boolean is true. Otherwise the returned value is empty and
// the boolean is false. An entry without an "=" is treated as a variable with
// an empty value.
func (s fileSpec) LookupEnv(key string) (string, bool) {
	if s.Spec == nil || s.Spec.Process == nil {
		return "", false
	}

	for _, entry := range s.Spec.Process.Env {
		// Split the entry at the first "=" only; the value may itself contain "=".
		name, value := entry, ""
		if sep := strings.Index(entry, "="); sep >= 0 {
			name, value = entry[:sep], entry[sep+1:]
		}
		if name == key {
			return value, true
		}
	}

	return "", false
}

252
pkg/oci/spec_file_test.go Normal file
View File

@@ -0,0 +1,252 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package oci
import (
"bytes"
"fmt"
"testing"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require"
)
// TestLookupEnv exercises fileSpec.LookupEnv against specs with missing
// sections, unrelated variables, and the various ways TEST_ENV can be set.
func TestLookupEnv(t *testing.T) {
	const envName = "TEST_ENV"

	// withEnv builds a spec whose process carries the given environment.
	withEnv := func(env []string) *specs.Spec {
		return &specs.Spec{Process: &specs.Process{Env: env}}
	}

	testCases := []struct {
		spec           *specs.Spec
		expectedValue  string
		expectedExists bool
	}{
		// nil spec
		{spec: nil, expectedValue: "", expectedExists: false},
		// nil process
		{spec: &specs.Spec{}, expectedValue: "", expectedExists: false},
		// nil env
		{spec: withEnv(nil), expectedValue: "", expectedExists: false},
		// empty env
		{spec: withEnv([]string{}), expectedValue: "", expectedExists: false},
		// different env set
		{spec: withEnv([]string{"SOMETHING_ELSE=foo"}), expectedValue: "", expectedExists: false},
		// same prefix
		{spec: withEnv([]string{"TEST_ENV_BUT_NOT=foo"}), expectedValue: "", expectedExists: false},
		// same suffix
		{spec: withEnv([]string{"NOT_TEST_ENV=foo"}), expectedValue: "", expectedExists: false},
		// set blank
		{spec: withEnv([]string{"TEST_ENV="}), expectedValue: "", expectedExists: true},
		// set no-equals
		{spec: withEnv([]string{"TEST_ENV"}), expectedValue: "", expectedExists: true},
		// set value
		{spec: withEnv([]string{"TEST_ENV=something"}), expectedValue: "something", expectedExists: true},
		// set with equals
		{
			spec:           withEnv([]string{"TEST_ENV=something=somethingelse"}),
			expectedValue:  "something=somethingelse",
			expectedExists: true,
		},
	}

	for i, tc := range testCases {
		underTest := fileSpec{Spec: tc.spec}

		value, exists := underTest.LookupEnv(envName)

		require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc)
		require.Equal(t, tc.expectedExists, exists, "%d: %v", i, tc)
	}
}
// TestLoadFrom checks that loadFrom rejects empty input and decodes a minimal
// JSON document into an empty spec.
func TestLoadFrom(t *testing.T) {
	testCases := []struct {
		contents []byte
		isError  bool
		spec     *specs.Spec
	}{
		{contents: []byte{}, isError: true},
		{contents: []byte("{}"), isError: false, spec: &specs.Spec{}},
	}

	for i, tc := range testCases {
		var loaded fileSpec
		err := loaded.loadFrom(bytes.NewReader(tc.contents))

		if !tc.isError {
			require.NoError(t, err, "%d: %v", i, tc)
		} else {
			require.Error(t, err, "%d: %v", i, tc)
		}

		if tc.spec != nil {
			require.EqualValues(t, tc.spec, loaded.Spec, "%d: %v", i, tc)
		} else {
			require.Nil(t, loaded.Spec, "%d: %v", i, tc)
		}
	}
}
// TestFlushTo checks the output of flushTo for a nil and an empty spec, and
// that an error from the writer is reported.
func TestFlushTo(t *testing.T) {
	testCases := []struct {
		isError  bool
		spec     *specs.Spec
		contents string
	}{
		{spec: nil},
		{spec: &specs.Spec{}, contents: "{\"ociVersion\":\"\"}\n"},
	}

	for i, tc := range testCases {
		var out bytes.Buffer
		underTest := fileSpec{Spec: tc.spec}

		err := underTest.flushTo(&out)

		if !tc.isError {
			require.NoError(t, err, "%d: %v", i, tc)
		} else {
			require.Error(t, err, "%d: %v", i, tc)
		}
		require.EqualValues(t, tc.contents, out.String(), "%d: %v", i, tc)
	}

	// Add a simple test for a writer that returns an error when writing
	failing := fileSpec{Spec: &specs.Spec{}}
	require.Error(t, failing.flushTo(errorWriter{}))
}
// TestModify checks that Modify fails without a loaded spec, propagates the
// modifier's error while leaving the spec unchanged, and otherwise applies
// the modification.
func TestModify(t *testing.T) {
	testCases := []struct {
		spec          *specs.Spec
		modifierError error
	}{
		{spec: nil},
		{spec: &specs.Spec{}},
		{spec: &specs.Spec{}, modifierError: fmt.Errorf("error in modifier")},
	}

	for i, tc := range testCases {
		underTest := fileSpec{Spec: tc.spec}

		// The modifier only touches the spec when it is going to succeed.
		modifier := func(s *specs.Spec) error {
			if tc.modifierError != nil {
				return tc.modifierError
			}
			s.Version = "updated"
			return tc.modifierError
		}

		err := underTest.Modify(modifier)

		switch {
		case tc.spec == nil:
			require.Error(t, err, "%d: %v", i, tc)
		case tc.modifierError != nil:
			require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
			require.EqualValues(t, &specs.Spec{}, underTest.Spec, "%d: %v", i, tc)
		default:
			require.NoError(t, err, "%d: %v", i, tc)
			require.Equal(t, "updated", underTest.Spec.Version, "%d: %v", i, tc)
		}
	}
}
// errorWriter is an io.Writer whose Write never accepts any data: it reports
// zero bytes written together with an error on every call.
type errorWriter struct{}

// Write discards p and always fails.
func (errorWriter) Write(p []byte) (int, error) {
	err := fmt.Errorf("error writing")
	return 0, err
}

201
pkg/oci/spec_mock.go Normal file
View File

@@ -0,0 +1,201 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package oci
import (
"sync"
)
// Ensure that SpecMock does implement Spec.
// If this is not the case, regenerate this file with moq.
var _ Spec = &SpecMock{}
// SpecMock is a mock implementation of Spec.
//
// Example usage:
//
//	func TestSomethingThatUsesSpec(t *testing.T) {
//
//		// make and configure a mocked Spec
//		mockedSpec := &SpecMock{
//			FlushFunc: func() error {
//				panic("mock out the Flush method")
//			},
//			LoadFunc: func() error {
//				panic("mock out the Load method")
//			},
//			LookupEnvFunc: func(s string) (string, bool) {
//				panic("mock out the LookupEnv method")
//			},
//			ModifyFunc: func(specModifier SpecModifier) error {
//				panic("mock out the Modify method")
//			},
//		}
//
//		// use mockedSpec in code that requires Spec
//		// and then make assertions.
//
//	}
type SpecMock struct {
// FlushFunc mocks the Flush method. When it is nil, Flush returns a nil error.
FlushFunc func() error
// LoadFunc mocks the Load method. When it is nil, Load returns a nil error.
LoadFunc func() error
// LookupEnvFunc mocks the LookupEnv method. When it is nil, LookupEnv returns ("", false).
LookupEnvFunc func(s string) (string, bool)
// ModifyFunc mocks the Modify method. When it is nil, Modify returns a nil error.
ModifyFunc func(specModifier SpecModifier) error
// calls tracks calls to the methods.
calls struct {
// Flush holds details about calls to the Flush method.
Flush []struct {
}
// Load holds details about calls to the Load method.
Load []struct {
}
// LookupEnv holds details about calls to the LookupEnv method.
LookupEnv []struct {
// S is the s argument value.
S string
}
// Modify holds details about calls to the Modify method.
Modify []struct {
// SpecModifier is the specModifier argument value.
SpecModifier SpecModifier
}
}
// Each lock guards the corresponding slice in calls.
lockFlush sync.RWMutex
lockLoad sync.RWMutex
lockLookupEnv sync.RWMutex
lockModify sync.RWMutex
}
// Flush records the invocation and then delegates to FlushFunc.
// If FlushFunc has not been set, a nil error is returned.
func (mock *SpecMock) Flush() error {
	mock.lockFlush.Lock()
	mock.calls.Flush = append(mock.calls.Flush, struct{}{})
	mock.lockFlush.Unlock()

	if mock.FlushFunc != nil {
		return mock.FlushFunc()
	}
	return nil
}
// FlushCalls returns the recorded invocations of Flush.
// Check the length with:
//
//	len(mockedSpec.FlushCalls())
func (mock *SpecMock) FlushCalls() []struct{} {
	mock.lockFlush.RLock()
	defer mock.lockFlush.RUnlock()
	return mock.calls.Flush
}
// Load records the invocation and then delegates to LoadFunc.
// If LoadFunc has not been set, a nil error is returned.
func (mock *SpecMock) Load() error {
	mock.lockLoad.Lock()
	mock.calls.Load = append(mock.calls.Load, struct{}{})
	mock.lockLoad.Unlock()

	if mock.LoadFunc != nil {
		return mock.LoadFunc()
	}
	return nil
}
// LoadCalls returns the recorded invocations of Load.
// Check the length with:
//
//	len(mockedSpec.LoadCalls())
func (mock *SpecMock) LoadCalls() []struct{} {
	mock.lockLoad.RLock()
	defer mock.lockLoad.RUnlock()
	return mock.calls.Load
}
// LookupEnv records the invocation and then delegates to LookupEnvFunc.
// If LookupEnvFunc has not been set, an empty string and false are returned.
func (mock *SpecMock) LookupEnv(s string) (string, bool) {
	record := struct {
		S string
	}{s}

	mock.lockLookupEnv.Lock()
	mock.calls.LookupEnv = append(mock.calls.LookupEnv, record)
	mock.lockLookupEnv.Unlock()

	if mock.LookupEnvFunc != nil {
		return mock.LookupEnvFunc(s)
	}
	return "", false
}
// LookupEnvCalls returns the recorded invocations of LookupEnv.
// Check the length with:
//
//	len(mockedSpec.LookupEnvCalls())
func (mock *SpecMock) LookupEnvCalls() []struct {
	S string
} {
	mock.lockLookupEnv.RLock()
	defer mock.lockLookupEnv.RUnlock()
	return mock.calls.LookupEnv
}
// Modify records the invocation and then delegates to ModifyFunc.
// If ModifyFunc has not been set, a nil error is returned.
func (mock *SpecMock) Modify(specModifier SpecModifier) error {
	record := struct {
		SpecModifier SpecModifier
	}{specModifier}

	mock.lockModify.Lock()
	mock.calls.Modify = append(mock.calls.Modify, record)
	mock.lockModify.Unlock()

	if mock.ModifyFunc != nil {
		return mock.ModifyFunc(specModifier)
	}
	return nil
}
// ModifyCalls returns the recorded invocations of Modify.
// Check the length with:
//
//	len(mockedSpec.ModifyCalls())
func (mock *SpecMock) ModifyCalls() []struct {
	SpecModifier SpecModifier
} {
	mock.lockModify.RLock()
	defer mock.lockModify.RUnlock()
	return mock.calls.Modify
}

View File

@@ -0,0 +1,82 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"fmt"
log "github.com/sirupsen/logrus"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/modify"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
)
// modifyingRuntimeWrapper is an oci.Runtime that can apply a modifier to an OCI
// specification before forwarding a command to the wrapped runtime.
type modifyingRuntimeWrapper struct {
// logger receives informational messages about the steps taken by Exec.
logger *log.Logger
// runtime is the wrapped runtime that commands are forwarded to.
runtime oci.Runtime
// ociSpec is the specification that is loaded, modified, and flushed.
ociSpec oci.Spec
// modifier is applied to ociSpec when a modification is required.
modifier modify.Modifier
}
// Compile-time check that *modifyingRuntimeWrapper implements oci.Runtime.
var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil)
// NewModifyingRuntimeWrapperWithLogger returns an oci.Runtime that wraps the
// given runtime and applies the specified modifier to the OCI specification
// before invoking it.
func NewModifyingRuntimeWrapperWithLogger(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier modify.Modifier) oci.Runtime {
	return &modifyingRuntimeWrapper{
		logger:   logger,
		runtime:  runtime,
		ociSpec:  spec,
		modifier: modifier,
	}
}
// Exec checks whether a modification of the OCI specification is required and modifies it accordingly before exec-ing
// into the wrapped runtime.
//
// The specification is only modified when args contain the create subcommand
// (as reported by oci.HasCreateSubcommand). If the modification fails, the
// wrapped runtime is not invoked and the failure is returned wrapped (%w) so
// that callers can inspect the underlying cause with errors.Is / errors.As.
func (r *modifyingRuntimeWrapper) Exec(args []string) error {
	if oci.HasCreateSubcommand(args) {
		err := r.modify()
		if err != nil {
			return fmt.Errorf("could not apply required modification to OCI specification: %w", err)
		}
		r.logger.Infof("Applied required modification to OCI specification")
	} else {
		r.logger.Infof("No modification of OCI specification required")
	}

	r.logger.Infof("Forwarding command to runtime")
	return r.runtime.Exec(args)
}
// modify loads, modifies, and flushes the OCI specification using the defined Modifier.
//
// The three steps run in order and the first failure aborts the sequence, so a
// specification that could not be loaded or modified is never flushed. Each
// failure is wrapped (%w) so that callers can inspect the underlying cause with
// errors.Is / errors.As.
func (r *modifyingRuntimeWrapper) modify() error {
	err := r.ociSpec.Load()
	if err != nil {
		return fmt.Errorf("error loading OCI specification for modification: %w", err)
	}

	err = r.modifier.Modify(r.ociSpec)
	if err != nil {
		return fmt.Errorf("error modifying OCI spec: %w", err)
	}

	err = r.ociSpec.Flush()
	if err != nil {
		return fmt.Errorf("error writing modified OCI specification: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,160 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"fmt"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/modify"
"gitlab.com/nvidia/cloud-native/container-toolkit/pkg/oci"
)
// TestRuntimeModifier checks that the spec is loaded, modified, and flushed
// only for the create subcommand, and that the command is always forwarded to
// the wrapped runtime.
func TestRuntimeModifier(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	testCases := []struct {
		args         []string
		shouldModify bool
	}{
		{},
		{args: []string{"create"}, shouldModify: true},
	}

	for _, tc := range testCases {
		runtimeMock := &oci.RuntimeMock{}
		specMock := &oci.SpecMock{}
		modifierMock := &modify.ModifierMock{}

		r := NewModifyingRuntimeWrapperWithLogger(logger, runtimeMock, specMock, modifierMock)

		require.NoError(t, r.Exec(tc.args))

		var expectedCalls int
		if tc.shouldModify {
			expectedCalls = 1
		}

		require.Len(t, specMock.LoadCalls(), expectedCalls)
		require.Len(t, modifierMock.ModifyCalls(), expectedCalls)
		require.Len(t, specMock.FlushCalls(), expectedCalls)
		require.Len(t, runtimeMock.ExecCalls(), 1)
	}
}
// TestRuntimeModiferWithLoadError checks that a failure to load the spec stops
// processing: nothing is modified or flushed, and the runtime is not invoked.
func TestRuntimeModiferWithLoadError(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	runtimeMock := &oci.RuntimeMock{}
	specMock := &oci.SpecMock{LoadFunc: specErrorFunc}
	modifierMock := &modify.ModifierMock{}

	r := NewModifyingRuntimeWrapperWithLogger(logger, runtimeMock, specMock, modifierMock)

	require.Error(t, r.Exec([]string{"create"}))

	require.Len(t, specMock.LoadCalls(), 1)
	require.Empty(t, modifierMock.ModifyCalls())
	require.Empty(t, specMock.FlushCalls())
	require.Empty(t, runtimeMock.ExecCalls())
}
// TestRuntimeModiferWithFlushError checks that a failure to flush the modified
// spec is reported and that the runtime is not invoked.
func TestRuntimeModiferWithFlushError(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	runtimeMock := &oci.RuntimeMock{}
	specMock := &oci.SpecMock{FlushFunc: specErrorFunc}
	modifierMock := &modify.ModifierMock{}

	r := NewModifyingRuntimeWrapperWithLogger(logger, runtimeMock, specMock, modifierMock)

	require.Error(t, r.Exec([]string{"create"}))

	require.Len(t, specMock.LoadCalls(), 1)
	require.Len(t, modifierMock.ModifyCalls(), 1)
	require.Len(t, specMock.FlushCalls(), 1)
	require.Empty(t, runtimeMock.ExecCalls())
}
// TestRuntimeModiferWithModifyError checks that a failing modifier prevents the
// spec from being flushed and the runtime from being invoked.
func TestRuntimeModiferWithModifyError(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	runtimeMock := &oci.RuntimeMock{}
	specMock := &oci.SpecMock{}
	modifierMock := &modify.ModifierMock{ModifyFunc: modifierErrorFunc}

	r := NewModifyingRuntimeWrapperWithLogger(logger, runtimeMock, specMock, modifierMock)

	require.Error(t, r.Exec([]string{"create"}))

	require.Len(t, specMock.LoadCalls(), 1)
	require.Len(t, modifierMock.ModifyCalls(), 1)
	require.Empty(t, specMock.FlushCalls())
	require.Empty(t, runtimeMock.ExecCalls())
}
// specErrorFunc always fails; it is used as the LoadFunc / FlushFunc of a
// SpecMock to simulate an error.
func specErrorFunc() error {
	err := fmt.Errorf("error")
	return err
}
// modifierErrorFunc always fails; it is used as the ModifyFunc of a
// ModifierMock to simulate an error. The supplied spec is ignored.
func modifierErrorFunc(_ oci.Spec) error {
	err := fmt.Errorf("error")
	return err
}