mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-24 13:05:17 +00:00
Copy code from nvidia-container-runtime
This change copies the cmd/nvidia-container-runtime, internal, and test folders from github.com/NVIDIA/nvidia-container-runtime@8a63b4b34f3ce3b4167f0516aa3f7207ca280dfb Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
58e707fed6
commit
b6a585c77d
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,4 +1,5 @@
|
||||
dist
|
||||
*.swp
|
||||
*.swo
|
||||
/coverage.out
|
||||
/coverage.out
|
||||
/test/output/
|
||||
|
79
cmd/nvidia-container-runtime/logger.go
Normal file
79
cmd/nvidia-container-runtime/logger.go
Normal file
@ -0,0 +1,79 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tsaikd/KDGoLib/logrusutil"
|
||||
)
|
||||
|
||||
// Logger adds a way to manage output to a log file to a logrus.Logger
|
||||
type Logger struct {
|
||||
*logrus.Logger
|
||||
previousOutput io.Writer
|
||||
logFile *os.File
|
||||
}
|
||||
|
||||
// NewLogger constructs a Logger with a preddefined formatter
|
||||
func NewLogger() *Logger {
|
||||
logrusLogger := logrus.New()
|
||||
|
||||
formatter := &logrusutil.ConsoleLogFormatter{
|
||||
TimestampFormat: "2006/01/02 15:04:07",
|
||||
Flag: logrusutil.Ltime,
|
||||
}
|
||||
|
||||
logger := &Logger{
|
||||
Logger: logrusLogger,
|
||||
}
|
||||
logger.SetFormatter(formatter)
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
// LogToFile opens the specified file for appending and sets the logger to
|
||||
// output to the opened file. A reference to the file pointer is stored to
|
||||
// allow this to be closed.
|
||||
func (l *Logger) LogToFile(filename string) error {
|
||||
logFile, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening debug log file: %v", err)
|
||||
}
|
||||
|
||||
l.logFile = logFile
|
||||
l.previousOutput = l.Out
|
||||
l.SetOutput(logFile)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseFile closes the log file (if any) and resets the logger output to what it
|
||||
// was before LogToFile was called.
|
||||
func (l *Logger) CloseFile() error {
|
||||
if l.logFile == nil {
|
||||
return nil
|
||||
}
|
||||
logFile := l.logFile
|
||||
l.SetOutput(l.previousOutput)
|
||||
l.logFile = nil
|
||||
|
||||
return logFile.Close()
|
||||
}
|
89
cmd/nvidia-container-runtime/main.go
Normal file
89
cmd/nvidia-container-runtime/main.go
Normal file
@ -0,0 +1,89 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/pelletier/go-toml"
|
||||
)
|
||||
|
||||
const (
|
||||
configOverride = "XDG_CONFIG_HOME"
|
||||
configFilePath = "nvidia-container-runtime/config.toml"
|
||||
|
||||
hookDefaultFilePath = "/usr/bin/nvidia-container-runtime-hook"
|
||||
)
|
||||
|
||||
var (
|
||||
configDir = "/etc/"
|
||||
)
|
||||
|
||||
var logger = NewLogger()
|
||||
|
||||
func main() {
|
||||
err := run(os.Args)
|
||||
if err != nil {
|
||||
logger.Errorf("Error running %v: %v", os.Args, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// run is an entry point that allows for idiomatic handling of errors
|
||||
// when calling from the main function.
|
||||
func run(argv []string) (err error) {
|
||||
cfg, err := getConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %v", err)
|
||||
}
|
||||
|
||||
err = logger.LogToFile(cfg.debugFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening debug log file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
// We capture and log a returning error before closing the log file.
|
||||
if err != nil {
|
||||
logger.Errorf("Error running %v: %v", argv, err)
|
||||
}
|
||||
logger.CloseFile()
|
||||
}()
|
||||
|
||||
r, err := newRuntime(argv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating runtime: %v", err)
|
||||
}
|
||||
|
||||
logger.Printf("Running %s\n", argv[0])
|
||||
return r.Exec(argv)
|
||||
}
|
||||
|
||||
type config struct {
|
||||
debugFilePath string
|
||||
}
|
||||
|
||||
// getConfig sets up the config struct. Values are read from a toml file
|
||||
// or set via the environment.
|
||||
func getConfig() (*config, error) {
|
||||
cfg := &config{}
|
||||
|
||||
if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 {
|
||||
configDir = XDGConfigDir
|
||||
}
|
||||
|
||||
configFilePath := path.Join(configDir, configFilePath)
|
||||
|
||||
tomlContent, err := os.ReadFile(configFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toml, err := toml.Load(string(tomlContent))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg.debugFilePath = toml.GetDefault("nvidia-container-runtime.debug", "/dev/null").(string)
|
||||
|
||||
return cfg, nil
|
||||
}
|
293
cmd/nvidia-container-runtime/main_test.go
Normal file
293
cmd/nvidia-container-runtime/main_test.go
Normal file
@ -0,0 +1,293 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
nvidiaRuntime = "nvidia-container-runtime"
|
||||
nvidiaHook = "nvidia-container-runtime-hook"
|
||||
bundlePathSuffix = "test/output/bundle/"
|
||||
specFile = "config.json"
|
||||
unmodifiedSpecFileSuffix = "test/input/test_spec.json"
|
||||
)
|
||||
|
||||
type testConfig struct {
|
||||
root string
|
||||
binPath string
|
||||
}
|
||||
|
||||
var cfg *testConfig
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// TEST SETUP
|
||||
// Determine the module root and the test binary path
|
||||
var err error
|
||||
moduleRoot, err := getModuleRoot()
|
||||
if err != nil {
|
||||
logger.Fatalf("error in test setup: could not get module root: %v", err)
|
||||
}
|
||||
testBinPath := filepath.Join(moduleRoot, "test", "bin")
|
||||
testInputPath := filepath.Join(moduleRoot, "test", "input")
|
||||
|
||||
// Set the environment variables for the test
|
||||
os.Setenv("PATH", prependToPath(testBinPath, moduleRoot))
|
||||
os.Setenv("XDG_CONFIG_HOME", testInputPath)
|
||||
|
||||
// Confirm that the environment is configured correctly
|
||||
runcPath, err := exec.LookPath(runcExecutableName)
|
||||
if err != nil || filepath.Join(testBinPath, runcExecutableName) != runcPath {
|
||||
logger.Fatalf("error in test setup: mock runc path set incorrectly in TestMain(): %v", err)
|
||||
}
|
||||
hookPath, err := exec.LookPath(nvidiaHook)
|
||||
if err != nil || filepath.Join(testBinPath, nvidiaHook) != hookPath {
|
||||
logger.Fatalf("error in test setup: mock hook path set incorrectly in TestMain(): %v", err)
|
||||
}
|
||||
|
||||
// Store the root and binary paths in the test Config
|
||||
cfg = &testConfig{
|
||||
root: moduleRoot,
|
||||
binPath: testBinPath,
|
||||
}
|
||||
|
||||
// RUN TESTS
|
||||
exitCode := m.Run()
|
||||
|
||||
// TEST CLEANUP
|
||||
os.Remove(specFile)
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func getModuleRoot() (string, error) {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
|
||||
return hasGoMod(filename)
|
||||
}
|
||||
|
||||
func hasGoMod(dir string) (string, error) {
|
||||
if dir == "" || dir == "/" {
|
||||
return "", fmt.Errorf("module root not found")
|
||||
}
|
||||
|
||||
_, err := os.Stat(filepath.Join(dir, "go.mod"))
|
||||
if err != nil {
|
||||
return hasGoMod(filepath.Dir(dir))
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func prependToPath(additionalPaths ...string) string {
|
||||
paths := strings.Split(os.Getenv("PATH"), ":")
|
||||
paths = append(additionalPaths, paths...)
|
||||
|
||||
return strings.Join(paths, ":")
|
||||
}
|
||||
|
||||
// case 1) nvidia-container-runtime run --bundle
|
||||
// case 2) nvidia-container-runtime create --bundle
|
||||
// - Confirm the runtime handles bad input correctly
|
||||
func TestBadInput(t *testing.T) {
|
||||
err := cfg.generateNewRuntimeSpec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmdRun := exec.Command(nvidiaRuntime, "run", "--bundle")
|
||||
t.Logf("executing: %s\n", strings.Join(cmdRun.Args, " "))
|
||||
output, err := cmdRun.CombinedOutput()
|
||||
require.Errorf(t, err, "runtime should return an error", "output=%v", string(output))
|
||||
|
||||
cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle")
|
||||
t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " "))
|
||||
err = cmdCreate.Run()
|
||||
require.Error(t, err, "runtime should return an error")
|
||||
}
|
||||
|
||||
// case 1) nvidia-container-runtime run --bundle <bundle-name> <ctr-name>
|
||||
// - Confirm the runtime runs with no errors
|
||||
// case 2) nvidia-container-runtime create --bundle <bundle-name> <ctr-name>
|
||||
// - Confirm the runtime inserts the NVIDIA prestart hook correctly
|
||||
func TestGoodInput(t *testing.T) {
|
||||
err := cfg.generateNewRuntimeSpec()
|
||||
if err != nil {
|
||||
t.Fatalf("error generating runtime spec: %v", err)
|
||||
}
|
||||
|
||||
cmdRun := exec.Command(nvidiaRuntime, "run", "--bundle", cfg.bundlePath(), "testcontainer")
|
||||
t.Logf("executing: %s\n", strings.Join(cmdRun.Args, " "))
|
||||
output, err := cmdRun.CombinedOutput()
|
||||
require.NoErrorf(t, err, "runtime should not return an error", "output=%v", string(output))
|
||||
|
||||
// Check config.json and confirm there are no hooks
|
||||
spec, err := cfg.getRuntimeSpec()
|
||||
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
|
||||
require.Empty(t, spec.Hooks, "there should be no hooks in config.json")
|
||||
|
||||
cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle", cfg.bundlePath(), "testcontainer")
|
||||
t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " "))
|
||||
err = cmdCreate.Run()
|
||||
require.NoError(t, err, "runtime should not return an error")
|
||||
|
||||
// Check config.json for NVIDIA prestart hook
|
||||
spec, err = cfg.getRuntimeSpec()
|
||||
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
|
||||
require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json")
|
||||
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
|
||||
}
|
||||
|
||||
// NVIDIA prestart hook already present in config file
|
||||
func TestDuplicateHook(t *testing.T) {
|
||||
err := cfg.generateNewRuntimeSpec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var spec specs.Spec
|
||||
spec, err = cfg.getRuntimeSpec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("inserting nvidia prestart hook to config.json")
|
||||
if err = addNVIDIAHook(&spec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonOutput, err := json.MarshalIndent(spec, "", "\t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonFile, err := os.OpenFile(cfg.specFilePath(), os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = jsonFile.WriteAt(jsonOutput, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test how runtime handles already existing prestart hook in config.json
|
||||
cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle", cfg.bundlePath(), "testcontainer")
|
||||
t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " "))
|
||||
output, err := cmdCreate.CombinedOutput()
|
||||
require.NoErrorf(t, err, "runtime should not return an error", "output=%v", string(output))
|
||||
|
||||
// Check config.json for NVIDIA prestart hook
|
||||
spec, err = cfg.getRuntimeSpec()
|
||||
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
|
||||
require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json")
|
||||
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
|
||||
}
|
||||
|
||||
// addNVIDIAHook is a basic wrapper for nvidiaContainerRunime.addNVIDIAHook that is used for
|
||||
// testing.
|
||||
func addNVIDIAHook(spec *specs.Spec) error {
|
||||
r := nvidiaContainerRuntime{logger: logger.Logger}
|
||||
return r.addNVIDIAHook(spec)
|
||||
}
|
||||
|
||||
func (c testConfig) getRuntimeSpec() (specs.Spec, error) {
|
||||
filePath := c.specFilePath()
|
||||
|
||||
var spec specs.Spec
|
||||
jsonFile, err := os.OpenFile(filePath, os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
return spec, err
|
||||
}
|
||||
defer jsonFile.Close()
|
||||
|
||||
jsonContent, err := ioutil.ReadAll(jsonFile)
|
||||
if err != nil {
|
||||
return spec, err
|
||||
} else if json.Valid(jsonContent) {
|
||||
err = json.Unmarshal(jsonContent, &spec)
|
||||
if err != nil {
|
||||
return spec, err
|
||||
}
|
||||
} else {
|
||||
err = json.NewDecoder(bytes.NewReader(jsonContent)).Decode(&spec)
|
||||
if err != nil {
|
||||
return spec, err
|
||||
}
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
func (c testConfig) bundlePath() string {
|
||||
return filepath.Join(c.root, bundlePathSuffix)
|
||||
}
|
||||
|
||||
func (c testConfig) specFilePath() string {
|
||||
return filepath.Join(c.bundlePath(), specFile)
|
||||
}
|
||||
|
||||
func (c testConfig) unmodifiedSpecFile() string {
|
||||
return filepath.Join(c.root, unmodifiedSpecFileSuffix)
|
||||
}
|
||||
|
||||
func (c testConfig) generateNewRuntimeSpec() error {
|
||||
var err error
|
||||
|
||||
err = os.MkdirAll(c.bundlePath(), 0755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("cp", c.unmodifiedSpecFile(), c.specFilePath())
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return number of valid NVIDIA prestart hooks in runtime spec
|
||||
func nvidiaHookCount(hooks *specs.Hooks) int {
|
||||
if hooks == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, hook := range hooks.Prestart {
|
||||
if strings.Contains(hook.Path, nvidiaHook) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func TestGetConfigWithCustomConfig(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
// By default debug is disabled
|
||||
contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"")
|
||||
testDir := filepath.Join(wd, "test")
|
||||
filename := filepath.Join(testDir, configFilePath)
|
||||
|
||||
os.Setenv(configOverride, testDir)
|
||||
|
||||
require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766))
|
||||
require.NoError(t, ioutil.WriteFile(filename, contents, 0766))
|
||||
|
||||
defer func() { require.NoError(t, os.RemoveAll(testDir)) }()
|
||||
|
||||
cfg, err := getConfig()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cfg.debugFilePath, "/nvidia-container-toolkit.log")
|
||||
}
|
145
cmd/nvidia-container-runtime/nvcr.go
Normal file
145
cmd/nvidia-container-runtime/nvcr.go
Normal file
@ -0,0 +1,145 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-runtime/internal/oci"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// nvidiaContainerRuntime encapsulates the NVIDIA Container Runtime. It wraps the specified runtime, conditionally
|
||||
// modifying the specified OCI specification before invoking the runtime.
|
||||
type nvidiaContainerRuntime struct {
|
||||
logger *log.Logger
|
||||
runtime oci.Runtime
|
||||
ociSpec oci.Spec
|
||||
}
|
||||
|
||||
var _ oci.Runtime = (*nvidiaContainerRuntime)(nil)
|
||||
|
||||
// newNvidiaContainerRuntime is a constructor for a standard runtime shim.
|
||||
func newNvidiaContainerRuntimeWithLogger(logger *log.Logger, runtime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
|
||||
r := nvidiaContainerRuntime{
|
||||
logger: logger,
|
||||
runtime: runtime,
|
||||
ociSpec: ociSpec,
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// Exec defines the entrypoint for the NVIDIA Container Runtime. A check is performed to see whether modifications
|
||||
// to the OCI spec are required -- and applicable modifcations applied. The supplied arguments are then
|
||||
// forwarded to the underlying runtime's Exec method.
|
||||
func (r nvidiaContainerRuntime) Exec(args []string) error {
|
||||
if r.modificationRequired(args) {
|
||||
err := r.modifyOCISpec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error modifying OCI spec: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Println("Forwarding command to runtime")
|
||||
return r.runtime.Exec(args)
|
||||
}
|
||||
|
||||
// modificationRequired checks the intput arguments to determine whether a modification
|
||||
// to the OCI spec is required.
|
||||
func (r nvidiaContainerRuntime) modificationRequired(args []string) bool {
|
||||
var previousWasBundle bool
|
||||
for _, a := range args {
|
||||
// We check for '--bundle create' explicitly to ensure that we
|
||||
// don't inadvertently trigger a modification if the bundle directory
|
||||
// is specified as `create`
|
||||
if !previousWasBundle && isBundleFlag(a) {
|
||||
previousWasBundle = true
|
||||
continue
|
||||
}
|
||||
|
||||
if !previousWasBundle && a == "create" {
|
||||
r.logger.Infof("'create' command detected; modification required")
|
||||
return true
|
||||
}
|
||||
|
||||
previousWasBundle = false
|
||||
}
|
||||
|
||||
r.logger.Infof("No modification required")
|
||||
return false
|
||||
}
|
||||
|
||||
// modifyOCISpec loads and modifies the OCI spec specified in the nvidiaContainerRuntime
|
||||
// struct. The spec is modified in-place and written to the same file as the input after
|
||||
// modifcationas are applied.
|
||||
func (r nvidiaContainerRuntime) modifyOCISpec() error {
|
||||
err := r.ociSpec.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading OCI specification for modification: %v", err)
|
||||
}
|
||||
|
||||
err = r.ociSpec.Modify(r.addNVIDIAHook)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error injecting NVIDIA Container Runtime hook: %v", err)
|
||||
}
|
||||
|
||||
err = r.ociSpec.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing modified OCI specification: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNVIDIAHook modifies the specified OCI specification in-place, inserting a
|
||||
// prestart hook.
|
||||
func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
|
||||
path, err := exec.LookPath("nvidia-container-runtime-hook")
|
||||
if err != nil {
|
||||
path = hookDefaultFilePath
|
||||
_, err = os.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Printf("prestart hook path: %s\n", path)
|
||||
|
||||
args := []string{path}
|
||||
if spec.Hooks == nil {
|
||||
spec.Hooks = &specs.Hooks{}
|
||||
} else if len(spec.Hooks.Prestart) != 0 {
|
||||
for _, hook := range spec.Hooks.Prestart {
|
||||
if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") {
|
||||
continue
|
||||
}
|
||||
r.logger.Println("existing nvidia prestart hook in OCI spec file")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
spec.Hooks.Prestart = append(spec.Hooks.Prestart, specs.Hook{
|
||||
Path: path,
|
||||
Args: append(args, "prestart"),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
230
cmd/nvidia-container-runtime/nvcr_test.go
Normal file
230
cmd/nvidia-container-runtime/nvcr_test.go
Normal file
@ -0,0 +1,230 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-runtime/internal/oci"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestArgsGetConfigFilePath(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
bundleDir string
|
||||
ociSpecPath string
|
||||
}{
|
||||
{
|
||||
ociSpecPath: fmt.Sprintf("%v/config.json", wd),
|
||||
},
|
||||
{
|
||||
bundleDir: "/foo/bar",
|
||||
ociSpecPath: "/foo/bar/config.json",
|
||||
},
|
||||
{
|
||||
bundleDir: "/foo/bar/",
|
||||
ociSpecPath: "/foo/bar/config.json",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
cp, err := getOCISpecFilePath(tc.bundleDir)
|
||||
|
||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
||||
require.Equalf(t, tc.ociSpecPath, cp, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddNvidiaHook(t *testing.T) {
|
||||
logger, logHook := testlog.NewNullLogger()
|
||||
shim := nvidiaContainerRuntime{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
spec *specs.Spec
|
||||
errorPrefix string
|
||||
shouldNotAdd bool
|
||||
}{
|
||||
{
|
||||
spec: &specs.Spec{},
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{
|
||||
Hooks: &specs.Hooks{},
|
||||
},
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{
|
||||
Hooks: &specs.Hooks{
|
||||
Prestart: []specs.Hook{{
|
||||
Path: "some-hook",
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{
|
||||
Hooks: &specs.Hooks{
|
||||
Prestart: []specs.Hook{{
|
||||
Path: "nvidia-container-runtime-hook",
|
||||
}},
|
||||
},
|
||||
},
|
||||
shouldNotAdd: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
logHook.Reset()
|
||||
|
||||
var numPrestartHooks int
|
||||
if tc.spec.Hooks != nil {
|
||||
numPrestartHooks = len(tc.spec.Hooks.Prestart)
|
||||
}
|
||||
|
||||
err := shim.addNVIDIAHook(tc.spec)
|
||||
|
||||
if tc.errorPrefix == "" {
|
||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.Truef(t, strings.HasPrefix(err.Error(), tc.errorPrefix), "%d: %v", i, tc)
|
||||
|
||||
require.NotNilf(t, tc.spec.Hooks, "%d: %v", i, tc)
|
||||
require.Equalf(t, 1, nvidiaHookCount(tc.spec.Hooks), "%d: %v", i, tc)
|
||||
|
||||
if tc.shouldNotAdd {
|
||||
require.Equal(t, numPrestartHooks+1, len(tc.spec.Hooks.Poststart), "%d: %v", i, tc)
|
||||
} else {
|
||||
require.Equal(t, numPrestartHooks+1, len(tc.spec.Hooks.Poststart), "%d: %v", i, tc)
|
||||
|
||||
nvidiaHook := tc.spec.Hooks.Poststart[len(tc.spec.Hooks.Poststart)-1]
|
||||
|
||||
// TODO: This assumes that the hook has been set up in the makefile
|
||||
expectedPath := "/usr/bin/nvidia-container-runtime-hook"
|
||||
require.Equalf(t, expectedPath, nvidiaHook.Path, "%d: %v", i, tc)
|
||||
require.Equalf(t, []string{expectedPath, "prestart"}, nvidiaHook.Args, "%d: %v", i, tc)
|
||||
require.Emptyf(t, nvidiaHook.Env, "%d: %v", i, tc)
|
||||
require.Nilf(t, nvidiaHook.Timeout, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNvidiaContainerRuntime(t *testing.T) {
|
||||
logger, hook := testlog.NewNullLogger()
|
||||
|
||||
testCases := []struct {
|
||||
shim nvidiaContainerRuntime
|
||||
shouldModify bool
|
||||
args []string
|
||||
modifyError error
|
||||
writeError error
|
||||
}{
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
shouldModify: false,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"create"},
|
||||
shouldModify: true,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"--bundle=create"},
|
||||
shouldModify: false,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"--bundle", "create"},
|
||||
shouldModify: false,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"create"},
|
||||
shouldModify: true,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"create"},
|
||||
modifyError: fmt.Errorf("error modifying"),
|
||||
shouldModify: true,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"create"},
|
||||
writeError: fmt.Errorf("error writing"),
|
||||
shouldModify: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
tc.shim.logger = logger
|
||||
hook.Reset()
|
||||
|
||||
spec := &specs.Spec{}
|
||||
ociMock := oci.NewMockSpec(spec, tc.writeError, tc.modifyError)
|
||||
|
||||
require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc)
|
||||
|
||||
tc.shim.ociSpec = ociMock
|
||||
tc.shim.runtime = &MockShim{}
|
||||
|
||||
err := tc.shim.Exec(tc.args)
|
||||
if tc.modifyError != nil || tc.writeError != nil {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoError(t, err, "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
if tc.shouldModify {
|
||||
require.Equal(t, 1, ociMock.MockModify.Callcount, "%d: %v", i, tc)
|
||||
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "%d: %v", i, tc)
|
||||
} else {
|
||||
require.Equal(t, 0, ociMock.MockModify.Callcount, "%d: %v", i, tc)
|
||||
require.Nil(t, spec.Hooks, "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
writeExpected := tc.shouldModify && tc.modifyError == nil
|
||||
if writeExpected {
|
||||
require.Equal(t, 1, ociMock.MockFlush.Callcount, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.Equal(t, 0, ociMock.MockFlush.Callcount, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type MockShim struct {
|
||||
called bool
|
||||
args []string
|
||||
returnError error
|
||||
}
|
||||
|
||||
func (m *MockShim) Exec(args []string) error {
|
||||
m.called = true
|
||||
m.args = args
|
||||
return m.returnError
|
||||
}
|
176
cmd/nvidia-container-runtime/runtime_factory.go
Normal file
176
cmd/nvidia-container-runtime/runtime_factory.go
Normal file
@ -0,0 +1,176 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-runtime/internal/oci"
|
||||
)
|
||||
|
||||
const (
|
||||
ociSpecFileName = "config.json"
|
||||
dockerRuncExecutableName = "docker-runc"
|
||||
runcExecutableName = "runc"
|
||||
)
|
||||
|
||||
// newRuntime is a factory method that constructs a runtime based on the selected configuration.
|
||||
func newRuntime(argv []string) (oci.Runtime, error) {
|
||||
ociSpec, err := newOCISpec(argv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
|
||||
}
|
||||
|
||||
runc, err := newRuncRuntime()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing runc runtime: %v", err)
|
||||
}
|
||||
|
||||
r, err := newNvidiaContainerRuntimeWithLogger(logger.Logger, runc, ociSpec)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// newOCISpec constructs an OCI spec for the provided arguments
|
||||
func newOCISpec(argv []string) (oci.Spec, error) {
|
||||
bundlePath, err := getBundlePath(argv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing command line arguments: %v", err)
|
||||
}
|
||||
|
||||
ociSpecPath, err := getOCISpecFilePath(bundlePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting OCI specification file path: %v", err)
|
||||
}
|
||||
ociSpec := oci.NewSpecFromFile(ociSpecPath)
|
||||
|
||||
return ociSpec, nil
|
||||
}
|
||||
|
||||
// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime
|
||||
func newRuncRuntime() (oci.Runtime, error) {
|
||||
runtimePath, err := findRunc()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error locating runtime: %v", err)
|
||||
}
|
||||
|
||||
runc, err := oci.NewSyscallExecRuntimeWithLogger(logger.Logger, runtimePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing runtime: %v", err)
|
||||
}
|
||||
|
||||
return runc, nil
|
||||
}
|
||||
|
||||
// getBundlePath checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc.
|
||||
// The following are supported:
|
||||
// --bundle{{SEP}}BUNDLE_PATH
|
||||
// -bundle{{SEP}}BUNDLE_PATH
|
||||
// -b{{SEP}}BUNDLE_PATH
|
||||
// where {{SEP}} is either ' ' or '='
|
||||
func getBundlePath(argv []string) (string, error) {
|
||||
var bundlePath string
|
||||
|
||||
for i := 0; i < len(argv); i++ {
|
||||
param := argv[i]
|
||||
|
||||
parts := strings.SplitN(param, "=", 2)
|
||||
if !isBundleFlag(parts[0]) {
|
||||
continue
|
||||
}
|
||||
|
||||
// The flag has the format --bundle=/path
|
||||
if len(parts) == 2 {
|
||||
bundlePath = parts[1]
|
||||
continue
|
||||
}
|
||||
|
||||
// The flag has the format --bundle /path
|
||||
if i+1 < len(argv) {
|
||||
bundlePath = argv[i+1]
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// --bundle / -b was the last element of argv
|
||||
return "", fmt.Errorf("bundle option requires an argument")
|
||||
}
|
||||
|
||||
return bundlePath, nil
|
||||
}
|
||||
|
||||
// findRunc locates runc in the path, returning the full path to the
|
||||
// binary or an error.
|
||||
func findRunc() (string, error) {
|
||||
runtimeCandidates := []string{
|
||||
dockerRuncExecutableName,
|
||||
runcExecutableName,
|
||||
}
|
||||
|
||||
return findRuntime(runtimeCandidates)
|
||||
}
|
||||
|
||||
func findRuntime(runtimeCandidates []string) (string, error) {
|
||||
for _, candidate := range runtimeCandidates {
|
||||
logger.Infof("Looking for runtime binary '%v'", candidate)
|
||||
runcPath, err := exec.LookPath(candidate)
|
||||
if err == nil {
|
||||
logger.Infof("Found runtime binary '%v'", runcPath)
|
||||
return runcPath, nil
|
||||
}
|
||||
logger.Warnf("Runtime binary '%v' not found: %v", candidate, err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no runtime binary found from candidate list: %v", runtimeCandidates)
|
||||
}
|
||||
|
||||
func isBundleFlag(arg string) bool {
|
||||
if !strings.HasPrefix(arg, "-") {
|
||||
return false
|
||||
}
|
||||
|
||||
trimmed := strings.TrimLeft(arg, "-")
|
||||
return trimmed == "b" || trimmed == "bundle"
|
||||
}
|
||||
|
||||
// getOCISpecFilePath returns the expected path to the OCI specification file for the given
|
||||
// bundle directory or the current working directory if not specified.
|
||||
func getOCISpecFilePath(bundleDir string) (string, error) {
|
||||
if bundleDir == "" {
|
||||
logger.Infof("Bundle directory path is empty, using working directory.")
|
||||
workingDirectory, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting working directory: %v", err)
|
||||
}
|
||||
bundleDir = workingDirectory
|
||||
}
|
||||
|
||||
logger.Infof("Using bundle directory: %v", bundleDir)
|
||||
|
||||
OCISpecFilePath := filepath.Join(bundleDir, ociSpecFileName)
|
||||
|
||||
logger.Infof("Using OCI specification file path: %v", OCISpecFilePath)
|
||||
|
||||
return OCISpecFilePath, nil
|
||||
}
|
192
cmd/nvidia-container-runtime/runtime_factory_test.go
Normal file
192
cmd/nvidia-container-runtime/runtime_factory_test.go
Normal file
@ -0,0 +1,192 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConstructor(t *testing.T) {
|
||||
shim, err := newRuntime([]string{})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, shim)
|
||||
}
|
||||
|
||||
func TestGetBundlePath(t *testing.T) {
|
||||
type expected struct {
|
||||
bundle string
|
||||
isError bool
|
||||
}
|
||||
testCases := []struct {
|
||||
argv []string
|
||||
expected expected
|
||||
}{
|
||||
{
|
||||
argv: []string{},
|
||||
},
|
||||
{
|
||||
argv: []string{"create"},
|
||||
},
|
||||
{
|
||||
argv: []string{"--bundle"},
|
||||
expected: expected{
|
||||
isError: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b"},
|
||||
expected: expected{
|
||||
isError: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--bundle", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--not-bundle", "/foo/bar"},
|
||||
},
|
||||
{
|
||||
argv: []string{"--"},
|
||||
},
|
||||
{
|
||||
argv: []string{"-bundle", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--bundle=/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b=/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b=/foo/=bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/=bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"create", "-b", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b", "create", "create"},
|
||||
expected: expected{
|
||||
bundle: "create",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b=create", "create"},
|
||||
expected: expected{
|
||||
bundle: "create",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b", "create"},
|
||||
expected: expected{
|
||||
bundle: "create",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
bundle, err := getBundlePath(tc.argv)
|
||||
|
||||
if tc.expected.isError {
|
||||
require.Errorf(t, err, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindRunc(t *testing.T) {
|
||||
testLogger, _ := testlog.NewNullLogger()
|
||||
logger.Logger = testLogger
|
||||
|
||||
runcPath, err := findRunc()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, filepath.Join(cfg.binPath, runcExecutableName), runcPath)
|
||||
}
|
||||
|
||||
func TestFindRuntime(t *testing.T) {
|
||||
testLogger, _ := testlog.NewNullLogger()
|
||||
logger.Logger = testLogger
|
||||
|
||||
testCases := []struct {
|
||||
candidates []string
|
||||
expectedPath string
|
||||
}{
|
||||
{
|
||||
candidates: []string{},
|
||||
},
|
||||
{
|
||||
candidates: []string{"not-runc"},
|
||||
},
|
||||
{
|
||||
candidates: []string{"not-runc", "also-not-runc"},
|
||||
},
|
||||
{
|
||||
candidates: []string{runcExecutableName},
|
||||
expectedPath: filepath.Join(cfg.binPath, runcExecutableName),
|
||||
},
|
||||
{
|
||||
candidates: []string{runcExecutableName, "not-runc"},
|
||||
expectedPath: filepath.Join(cfg.binPath, runcExecutableName),
|
||||
},
|
||||
{
|
||||
candidates: []string{"not-runc", runcExecutableName},
|
||||
expectedPath: filepath.Join(cfg.binPath, runcExecutableName),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
runcPath, err := findRuntime(tc.candidates)
|
||||
if tc.expectedPath == "" {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoError(t, err, "%d: %v", i, tc)
|
||||
}
|
||||
require.Equal(t, tc.expectedPath, runcPath, "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
}
|
23
internal/oci/runtime.go
Normal file
23
internal/oci/runtime.go
Normal file
@ -0,0 +1,23 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
// Runtime is an interface for a runtime shim. The Exec method accepts a list
|
||||
// of command line arguments, and returns an error / nil.
|
||||
type Runtime interface {
|
||||
Exec([]string) error
|
||||
}
|
79
internal/oci/runtime_exec.go
Normal file
79
internal/oci/runtime_exec.go
Normal file
@ -0,0 +1,79 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SyscallExecRuntime wraps the path that a binary and defines the semanitcs for how to exec into it.
|
||||
// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the
|
||||
// Runtime internface.
|
||||
type SyscallExecRuntime struct {
|
||||
logger *log.Logger
|
||||
path string
|
||||
// exec is used for testing. This defaults to syscall.Exec
|
||||
exec func(argv0 string, argv []string, envv []string) error
|
||||
}
|
||||
|
||||
var _ Runtime = (*SyscallExecRuntime)(nil)
|
||||
|
||||
// NewSyscallExecRuntime creates a SyscallExecRuntime for the specified path with the standard logger
|
||||
func NewSyscallExecRuntime(path string) (Runtime, error) {
|
||||
return NewSyscallExecRuntimeWithLogger(log.StandardLogger(), path)
|
||||
}
|
||||
|
||||
// NewSyscallExecRuntimeWithLogger creates a SyscallExecRuntime for the specified logger and path
|
||||
func NewSyscallExecRuntimeWithLogger(logger *log.Logger, path string) (Runtime, error) {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid path '%v': %v", path, err)
|
||||
}
|
||||
if info.IsDir() || info.Mode()&0111 == 0 {
|
||||
return nil, fmt.Errorf("specified path '%v' is not an executable file", path)
|
||||
}
|
||||
|
||||
shim := SyscallExecRuntime{
|
||||
logger: logger,
|
||||
path: path,
|
||||
exec: syscall.Exec,
|
||||
}
|
||||
|
||||
return &shim, nil
|
||||
}
|
||||
|
||||
// Exec exces into the binary at the path from the SyscallExecRuntime struct, passing it the supplied arguments
|
||||
// after ensuring that the first argument is the path of the target binary.
|
||||
func (s SyscallExecRuntime) Exec(args []string) error {
|
||||
runtimeArgs := []string{s.path}
|
||||
if len(args) > 1 {
|
||||
runtimeArgs = append(runtimeArgs, args[1:]...)
|
||||
}
|
||||
|
||||
err := s.exec(s.path, runtimeArgs, os.Environ())
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not exec '%v': %v", s.path, err)
|
||||
}
|
||||
|
||||
// syscall.Exec is not expected to return. This is an error state regardless of whether
|
||||
// err is nil or not.
|
||||
return fmt.Errorf("unexpected return from exec '%v'", s.path)
|
||||
}
|
100
internal/oci/runtime_exec_test.go
Normal file
100
internal/oci/runtime_exec_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
package oci
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSyscallExecConstructor(t *testing.T) {
|
||||
r, err := NewSyscallExecRuntime("////an/invalid/path")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, r)
|
||||
|
||||
r, err = NewSyscallExecRuntime("/tmp")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, r)
|
||||
|
||||
r, err = NewSyscallExecRuntime("/dev/null")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, r)
|
||||
|
||||
r, err = NewSyscallExecRuntime("/bin/sh")
|
||||
require.NoError(t, err)
|
||||
|
||||
f, ok := r.(*SyscallExecRuntime)
|
||||
require.True(t, ok)
|
||||
|
||||
require.Equal(t, "/bin/sh", f.path)
|
||||
}
|
||||
|
||||
func TestSyscallExecForwardsArgs(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
f := SyscallExecRuntime{
|
||||
logger: logger,
|
||||
path: "runtime",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
returnError error
|
||||
args []string
|
||||
errorPrefix string
|
||||
}{
|
||||
{
|
||||
returnError: nil,
|
||||
errorPrefix: "unexpected return from exec",
|
||||
},
|
||||
{
|
||||
returnError: fmt.Errorf("error from exec"),
|
||||
errorPrefix: "could not exec",
|
||||
},
|
||||
{
|
||||
returnError: nil,
|
||||
args: []string{"otherargv0"},
|
||||
errorPrefix: "unexpected return from exec",
|
||||
},
|
||||
{
|
||||
returnError: nil,
|
||||
args: []string{"otherargv0", "arg1", "arg2", "arg3"},
|
||||
errorPrefix: "unexpected return from exec",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
execMock := WithMockExec(f, tc.returnError)
|
||||
|
||||
err := execMock.Exec(tc.args)
|
||||
|
||||
require.Errorf(t, err, "%d: %v", i, tc)
|
||||
require.Truef(t, strings.HasPrefix(err.Error(), tc.errorPrefix), "%d: %v", i, tc)
|
||||
if tc.returnError != nil {
|
||||
require.Truef(t, strings.HasSuffix(err.Error(), tc.returnError.Error()), "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
require.Equalf(t, f.path, execMock.argv0, "%d: %v", i, tc)
|
||||
require.Equalf(t, f.path, execMock.argv[0], "%d: %v", i, tc)
|
||||
|
||||
require.LessOrEqualf(t, len(tc.args), len(execMock.argv), "%d: %v", i, tc)
|
||||
if len(tc.args) > 1 {
|
||||
require.Equalf(t, tc.args[1:], execMock.argv[1:], "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
49
internal/oci/runtime_mock.go
Normal file
49
internal/oci/runtime_mock.go
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
// MockExecRuntime wraps a SyscallExecRuntime, intercepting the exec call for testing
|
||||
type MockExecRuntime struct {
|
||||
SyscallExecRuntime
|
||||
execMock
|
||||
}
|
||||
|
||||
// WithMockExec wraps a specified SyscallExecRuntime with a mocked exec function for testing
|
||||
func WithMockExec(e SyscallExecRuntime, execResult error) *MockExecRuntime {
|
||||
m := MockExecRuntime{
|
||||
SyscallExecRuntime: e,
|
||||
execMock: execMock{result: execResult},
|
||||
}
|
||||
// overrdie the exec function to the mocked exec function.
|
||||
m.SyscallExecRuntime.exec = m.execMock.exec
|
||||
return &m
|
||||
}
|
||||
|
||||
type execMock struct {
|
||||
argv0 string
|
||||
argv []string
|
||||
envv []string
|
||||
result error
|
||||
}
|
||||
|
||||
func (m *execMock) exec(argv0 string, argv []string, envv []string) error {
|
||||
m.argv0 = argv0
|
||||
m.argv = argv
|
||||
m.envv = envv
|
||||
|
||||
return m.result
|
||||
}
|
102
internal/oci/spec.go
Normal file
102
internal/oci/spec.go
Normal file
@ -0,0 +1,102 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
oci "github.com/opencontainers/runtime-spec/specs-go"
|
||||
)
|
||||
|
||||
// SpecModifier is a function that accepts a pointer to an OCI Srec and returns an
|
||||
// error. The intention is that the function would modify the spec in-place.
|
||||
type SpecModifier func(*oci.Spec) error
|
||||
|
||||
// Spec defines the operations to be performed on an OCI specification
|
||||
type Spec interface {
|
||||
Load() error
|
||||
Flush() error
|
||||
Modify(SpecModifier) error
|
||||
}
|
||||
|
||||
type fileSpec struct {
|
||||
*oci.Spec
|
||||
path string
|
||||
}
|
||||
|
||||
var _ Spec = (*fileSpec)(nil)
|
||||
|
||||
// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec.
|
||||
// This can be used to read from the file, modify the spec, and write to the
|
||||
// same file.
|
||||
func NewSpecFromFile(filepath string) Spec {
|
||||
oci := fileSpec{
|
||||
path: filepath,
|
||||
}
|
||||
|
||||
return &oci
|
||||
}
|
||||
|
||||
// Load reads the contents of an OCI spec from file to be referenced internally.
|
||||
// The file is opened "read-only"
|
||||
func (s *fileSpec) Load() error {
|
||||
specFile, err := os.Open(s.path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening OCI specification file: %v", err)
|
||||
}
|
||||
defer specFile.Close()
|
||||
|
||||
decoder := json.NewDecoder(specFile)
|
||||
|
||||
var spec oci.Spec
|
||||
err = decoder.Decode(&spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading OCI specification from file: %v", err)
|
||||
}
|
||||
|
||||
s.Spec = &spec
|
||||
return nil
|
||||
}
|
||||
|
||||
// Modify applies the specified SpecModifier to the stored OCI specification.
|
||||
func (s *fileSpec) Modify(f SpecModifier) error {
|
||||
if s.Spec == nil {
|
||||
return fmt.Errorf("no spec loaded for modification")
|
||||
}
|
||||
return f(s.Spec)
|
||||
}
|
||||
|
||||
// Flush writes the stored OCI specification to the filepath specifed by the path member.
|
||||
// The file is truncated upon opening, overwriting any existing contents.
|
||||
func (s fileSpec) Flush() error {
|
||||
specFile, err := os.Create(s.path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening OCI specification file: %v", err)
|
||||
}
|
||||
defer specFile.Close()
|
||||
|
||||
encoder := json.NewEncoder(specFile)
|
||||
|
||||
err = encoder.Encode(s.Spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing OCI specification to file: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
70
internal/oci/spec_mock.go
Normal file
70
internal/oci/spec_mock.go
Normal file
@ -0,0 +1,70 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
oci "github.com/opencontainers/runtime-spec/specs-go"
|
||||
)
|
||||
|
||||
// MockSpec provides a simple mock for an OCI spec to be used in testing.
|
||||
// It also implements the SpecModifier interface.
|
||||
type MockSpec struct {
|
||||
*oci.Spec
|
||||
MockLoad mockFunc
|
||||
MockFlush mockFunc
|
||||
MockModify mockFunc
|
||||
}
|
||||
|
||||
var _ Spec = (*MockSpec)(nil)
|
||||
|
||||
// NewMockSpec constructs a MockSpec to be used in testing as a Spec
|
||||
func NewMockSpec(spec *oci.Spec, flushResult error, modifyResult error) *MockSpec {
|
||||
s := MockSpec{
|
||||
Spec: spec,
|
||||
MockFlush: mockFunc{result: flushResult},
|
||||
MockModify: mockFunc{result: modifyResult},
|
||||
}
|
||||
|
||||
return &s
|
||||
}
|
||||
|
||||
// Load invokes the mocked Load function to return the predefined error / result
|
||||
func (s *MockSpec) Load() error {
|
||||
return s.MockLoad.call()
|
||||
}
|
||||
|
||||
// Flush invokes the mocked Load function to return the predefined error / result
|
||||
func (s *MockSpec) Flush() error {
|
||||
return s.MockFlush.call()
|
||||
}
|
||||
|
||||
// Modify applies the specified SpecModifier to the spec and invokes the
|
||||
// mocked modify function to return the predefined error / result.
|
||||
func (s *MockSpec) Modify(f SpecModifier) error {
|
||||
f(s.Spec)
|
||||
return s.MockModify.call()
|
||||
}
|
||||
|
||||
type mockFunc struct {
|
||||
Callcount int
|
||||
result error
|
||||
}
|
||||
|
||||
func (m *mockFunc) call() error {
|
||||
m.Callcount++
|
||||
return m.result
|
||||
}
|
2
test/bin/nvidia-container-runtime-hook
Executable file
2
test/bin/nvidia-container-runtime-hook
Executable file
@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
echo mock hook
|
2
test/bin/runc
Executable file
2
test/bin/runc
Executable file
@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
echo mock runc
|
0
test/input/nvidia-container-runtime/config.toml
Normal file
0
test/input/nvidia-container-runtime/config.toml
Normal file
178
test/input/test_spec.json
Normal file
178
test/input/test_spec.json
Normal file
@ -0,0 +1,178 @@
|
||||
{
|
||||
"ociVersion": "1.0.1-dev",
|
||||
"process": {
|
||||
"terminal": true,
|
||||
"user": {
|
||||
"uid": 0,
|
||||
"gid": 0
|
||||
},
|
||||
"args": [
|
||||
"sh"
|
||||
],
|
||||
"env": [
|
||||
"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
|
||||
"TERM=xterm"
|
||||
],
|
||||
"cwd": "/",
|
||||
"capabilities": {
|
||||
"bounding": [
|
||||
"CAP_AUDIT_WRITE",
|
||||
"CAP_KILL",
|
||||
"CAP_NET_BIND_SERVICE"
|
||||
],
|
||||
"effective": [
|
||||
"CAP_AUDIT_WRITE",
|
||||
"CAP_KILL",
|
||||
"CAP_NET_BIND_SERVICE"
|
||||
],
|
||||
"inheritable": [
|
||||
"CAP_AUDIT_WRITE",
|
||||
"CAP_KILL",
|
||||
"CAP_NET_BIND_SERVICE"
|
||||
],
|
||||
"permitted": [
|
||||
"CAP_AUDIT_WRITE",
|
||||
"CAP_KILL",
|
||||
"CAP_NET_BIND_SERVICE"
|
||||
],
|
||||
"ambient": [
|
||||
"CAP_AUDIT_WRITE",
|
||||
"CAP_KILL",
|
||||
"CAP_NET_BIND_SERVICE"
|
||||
]
|
||||
},
|
||||
"rlimits": [
|
||||
{
|
||||
"type": "RLIMIT_NOFILE",
|
||||
"hard": 1024,
|
||||
"soft": 1024
|
||||
}
|
||||
],
|
||||
"noNewPrivileges": true
|
||||
},
|
||||
"root": {
|
||||
"path": "rootfs",
|
||||
"readonly": true
|
||||
},
|
||||
"hostname": "runc",
|
||||
"mounts": [
|
||||
{
|
||||
"destination": "/proc",
|
||||
"type": "proc",
|
||||
"source": "proc"
|
||||
},
|
||||
{
|
||||
"destination": "/dev",
|
||||
"type": "tmpfs",
|
||||
"source": "tmpfs",
|
||||
"options": [
|
||||
"nosuid",
|
||||
"strictatime",
|
||||
"mode=755",
|
||||
"size=65536k"
|
||||
]
|
||||
},
|
||||
{
|
||||
"destination": "/dev/pts",
|
||||
"type": "devpts",
|
||||
"source": "devpts",
|
||||
"options": [
|
||||
"nosuid",
|
||||
"noexec",
|
||||
"newinstance",
|
||||
"ptmxmode=0666",
|
||||
"mode=0620",
|
||||
"gid=5"
|
||||
]
|
||||
},
|
||||
{
|
||||
"destination": "/dev/shm",
|
||||
"type": "tmpfs",
|
||||
"source": "shm",
|
||||
"options": [
|
||||
"nosuid",
|
||||
"noexec",
|
||||
"nodev",
|
||||
"mode=1777",
|
||||
"size=65536k"
|
||||
]
|
||||
},
|
||||
{
|
||||
"destination": "/dev/mqueue",
|
||||
"type": "mqueue",
|
||||
"source": "mqueue",
|
||||
"options": [
|
||||
"nosuid",
|
||||
"noexec",
|
||||
"nodev"
|
||||
]
|
||||
},
|
||||
{
|
||||
"destination": "/sys",
|
||||
"type": "sysfs",
|
||||
"source": "sysfs",
|
||||
"options": [
|
||||
"nosuid",
|
||||
"noexec",
|
||||
"nodev",
|
||||
"ro"
|
||||
]
|
||||
},
|
||||
{
|
||||
"destination": "/sys/fs/cgroup",
|
||||
"type": "cgroup",
|
||||
"source": "cgroup",
|
||||
"options": [
|
||||
"nosuid",
|
||||
"noexec",
|
||||
"nodev",
|
||||
"relatime",
|
||||
"ro"
|
||||
]
|
||||
}
|
||||
],
|
||||
"linux": {
|
||||
"resources": {
|
||||
"devices": [
|
||||
{
|
||||
"allow": false,
|
||||
"access": "rwm"
|
||||
}
|
||||
]
|
||||
},
|
||||
"namespaces": [
|
||||
{
|
||||
"type": "pid"
|
||||
},
|
||||
{
|
||||
"type": "network"
|
||||
},
|
||||
{
|
||||
"type": "ipc"
|
||||
},
|
||||
{
|
||||
"type": "uts"
|
||||
},
|
||||
{
|
||||
"type": "mount"
|
||||
}
|
||||
],
|
||||
"maskedPaths": [
|
||||
"/proc/kcore",
|
||||
"/proc/latency_stats",
|
||||
"/proc/timer_list",
|
||||
"/proc/timer_stats",
|
||||
"/proc/sched_debug",
|
||||
"/sys/firmware",
|
||||
"/proc/scsi"
|
||||
],
|
||||
"readonlyPaths": [
|
||||
"/proc/asound",
|
||||
"/proc/bus",
|
||||
"/proc/fs",
|
||||
"/proc/irq",
|
||||
"/proc/sys",
|
||||
"/proc/sysrq-trigger"
|
||||
]
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user