From b6a585c77dcc7c60e939b611c493c8f0e774f68c Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 6 Sep 2021 13:26:48 +0200 Subject: [PATCH] Copy code from nvidia-container-runtime This change copies the cmd/nvidia-container-runtime, internal, and test folders from github.com/NVIDIA/nvidia-container-runtime@8a63b4b34f3ce3b4167f0516aa3f7207ca280dfb Signed-off-by: Evan Lezar --- .gitignore | 3 +- cmd/nvidia-container-runtime/logger.go | 79 +++++ cmd/nvidia-container-runtime/main.go | 89 ++++++ cmd/nvidia-container-runtime/main_test.go | 293 ++++++++++++++++++ cmd/nvidia-container-runtime/nvcr.go | 145 +++++++++ cmd/nvidia-container-runtime/nvcr_test.go | 230 ++++++++++++++ .../runtime_factory.go | 176 +++++++++++ .../runtime_factory_test.go | 192 ++++++++++++ internal/oci/runtime.go | 23 ++ internal/oci/runtime_exec.go | 79 +++++ internal/oci/runtime_exec_test.go | 100 ++++++ internal/oci/runtime_mock.go | 49 +++ internal/oci/spec.go | 102 ++++++ internal/oci/spec_mock.go | 70 +++++ test/bin/nvidia-container-runtime-hook | 2 + test/bin/runc | 2 + .../nvidia-container-runtime/config.toml | 0 test/input/test_spec.json | 178 +++++++++++ 18 files changed, 1811 insertions(+), 1 deletion(-) create mode 100644 cmd/nvidia-container-runtime/logger.go create mode 100644 cmd/nvidia-container-runtime/main.go create mode 100644 cmd/nvidia-container-runtime/main_test.go create mode 100644 cmd/nvidia-container-runtime/nvcr.go create mode 100644 cmd/nvidia-container-runtime/nvcr_test.go create mode 100644 cmd/nvidia-container-runtime/runtime_factory.go create mode 100644 cmd/nvidia-container-runtime/runtime_factory_test.go create mode 100644 internal/oci/runtime.go create mode 100644 internal/oci/runtime_exec.go create mode 100644 internal/oci/runtime_exec_test.go create mode 100644 internal/oci/runtime_mock.go create mode 100644 internal/oci/spec.go create mode 100644 internal/oci/spec_mock.go create mode 100755 test/bin/nvidia-container-runtime-hook create mode 100755 test/bin/runc create mode 100644 test/input/nvidia-container-runtime/config.toml create mode 100644 test/input/test_spec.json diff --git a/.gitignore b/.gitignore index f766658d..a38c2b0c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ dist *.swp *.swo -/coverage.out \ No newline at end of file +/coverage.out +/test/output/ diff --git a/cmd/nvidia-container-runtime/logger.go b/cmd/nvidia-container-runtime/logger.go new file mode 100644 index 00000000..803a1d35 --- /dev/null +++ b/cmd/nvidia-container-runtime/logger.go @@ -0,0 +1,79 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package main + +import ( + "fmt" + "io" + "os" + + "github.com/sirupsen/logrus" + "github.com/tsaikd/KDGoLib/logrusutil" +) + +// Logger adds a way to manage output to a log file to a logrus.Logger +type Logger struct { + *logrus.Logger + previousOutput io.Writer + logFile *os.File +} + +// NewLogger constructs a Logger with a preddefined formatter +func NewLogger() *Logger { + logrusLogger := logrus.New() + + formatter := &logrusutil.ConsoleLogFormatter{ + TimestampFormat: "2006/01/02 15:04:07", + Flag: logrusutil.Ltime, + } + + logger := &Logger{ + Logger: logrusLogger, + } + logger.SetFormatter(formatter) + + return logger +} + +// LogToFile opens the specified file for appending and sets the logger to +// output to the opened file. A reference to the file pointer is stored to +// allow this to be closed. +func (l *Logger) LogToFile(filename string) error { + logFile, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("error opening debug log file: %v", err) + } + + l.logFile = logFile + l.previousOutput = l.Out + l.SetOutput(logFile) + + return nil +} + +// CloseFile closes the log file (if any) and resets the logger output to what it +// was before LogToFile was called. +func (l *Logger) CloseFile() error { + if l.logFile == nil { + return nil + } + logFile := l.logFile + l.SetOutput(l.previousOutput) + l.logFile = nil + + return logFile.Close() +} diff --git a/cmd/nvidia-container-runtime/main.go b/cmd/nvidia-container-runtime/main.go new file mode 100644 index 00000000..cbb52b71 --- /dev/null +++ b/cmd/nvidia-container-runtime/main.go @@ -0,0 +1,89 @@ +package main + +import ( + "fmt" + "os" + "path" + + "github.com/pelletier/go-toml" +) + +const ( + configOverride = "XDG_CONFIG_HOME" + configFilePath = "nvidia-container-runtime/config.toml" + + hookDefaultFilePath = "/usr/bin/nvidia-container-runtime-hook" +) + +var ( + configDir = "/etc/" +) + +var logger = NewLogger() + +func main() { + err := run(os.Args) + if err != nil { + logger.Errorf("Error running %v: %v", os.Args, err) + os.Exit(1) + } +} + +// run is an entry point that allows for idiomatic handling of errors +// when calling from the main function. +func run(argv []string) (err error) { + cfg, err := getConfig() + if err != nil { + return fmt.Errorf("error loading config: %v", err) + } + + err = logger.LogToFile(cfg.debugFilePath) + if err != nil { + return fmt.Errorf("error opening debug log file: %v", err) + } + defer func() { + // We capture and log a returning error before closing the log file. + if err != nil { + logger.Errorf("Error running %v: %v", argv, err) + } + logger.CloseFile() + }() + + r, err := newRuntime(argv) + if err != nil { + return fmt.Errorf("error creating runtime: %v", err) + } + + logger.Printf("Running %s\n", argv[0]) + return r.Exec(argv) +} + +type config struct { + debugFilePath string +} + +// getConfig sets up the config struct. Values are read from a toml file +// or set via the environment. +func getConfig() (*config, error) { + cfg := &config{} + + if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 { + configDir = XDGConfigDir + } + + configFilePath := path.Join(configDir, configFilePath) + + tomlContent, err := os.ReadFile(configFilePath) + if err != nil { + return nil, err + } + + toml, err := toml.Load(string(tomlContent)) + if err != nil { + return nil, err + } + + cfg.debugFilePath = toml.GetDefault("nvidia-container-runtime.debug", "/dev/null").(string) + + return cfg, nil +} diff --git a/cmd/nvidia-container-runtime/main_test.go b/cmd/nvidia-container-runtime/main_test.go new file mode 100644 index 00000000..9ef83b9f --- /dev/null +++ b/cmd/nvidia-container-runtime/main_test.go @@ -0,0 +1,293 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/stretchr/testify/require" +) + +const ( + nvidiaRuntime = "nvidia-container-runtime" + nvidiaHook = "nvidia-container-runtime-hook" + bundlePathSuffix = "test/output/bundle/" + specFile = "config.json" + unmodifiedSpecFileSuffix = "test/input/test_spec.json" +) + +type testConfig struct { + root string + binPath string +} + +var cfg *testConfig + +func TestMain(m *testing.M) { + // TEST SETUP + // Determine the module root and the test binary path + var err error + moduleRoot, err := getModuleRoot() + if err != nil { + logger.Fatalf("error in test setup: could not get module root: %v", err) + } + testBinPath := filepath.Join(moduleRoot, "test", "bin") + testInputPath := filepath.Join(moduleRoot, "test", "input") + + // Set the environment variables for the test + os.Setenv("PATH", prependToPath(testBinPath, moduleRoot)) + os.Setenv("XDG_CONFIG_HOME", testInputPath) + + // Confirm that the environment is configured correctly + runcPath, err := exec.LookPath(runcExecutableName) + if err != nil || filepath.Join(testBinPath, runcExecutableName) != runcPath { + logger.Fatalf("error in test setup: mock runc path set incorrectly in TestMain(): %v", err) + } + hookPath, err := exec.LookPath(nvidiaHook) + if err != nil || filepath.Join(testBinPath, nvidiaHook) != hookPath { + logger.Fatalf("error in test setup: mock hook path set incorrectly in TestMain(): %v", err) + } + + // Store the root and binary paths in the test Config + cfg = &testConfig{ + root: moduleRoot, + binPath: testBinPath, + } + + // RUN TESTS + exitCode := m.Run() + + // TEST CLEANUP + os.Remove(specFile) + + os.Exit(exitCode) +} + +func getModuleRoot() (string, error) { + _, filename, _, _ := runtime.Caller(0) + + return hasGoMod(filename) +} + +func hasGoMod(dir string) (string, error) { + if dir == "" || dir == "/" { + return "", fmt.Errorf("module root not found") + } + + _, err := os.Stat(filepath.Join(dir, "go.mod")) + if err != nil { + return hasGoMod(filepath.Dir(dir)) + } + return dir, nil +} + +func prependToPath(additionalPaths ...string) string { + paths := strings.Split(os.Getenv("PATH"), ":") + paths = append(additionalPaths, paths...) + + return strings.Join(paths, ":") +} + +// case 1) nvidia-container-runtime run --bundle +// case 2) nvidia-container-runtime create --bundle +// - Confirm the runtime handles bad input correctly +func TestBadInput(t *testing.T) { + err := cfg.generateNewRuntimeSpec() + if err != nil { + t.Fatal(err) + } + + cmdRun := exec.Command(nvidiaRuntime, "run", "--bundle") + t.Logf("executing: %s\n", strings.Join(cmdRun.Args, " ")) + output, err := cmdRun.CombinedOutput() + require.Errorf(t, err, "runtime should return an error", "output=%v", string(output)) + + cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle") + t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " ")) + err = cmdCreate.Run() + require.Error(t, err, "runtime should return an error") +} + +// case 1) nvidia-container-runtime run --bundle +// - Confirm the runtime runs with no errors +// case 2) nvidia-container-runtime create --bundle +// - Confirm the runtime inserts the NVIDIA prestart hook correctly +func TestGoodInput(t *testing.T) { + err := cfg.generateNewRuntimeSpec() + if err != nil { + t.Fatalf("error generating runtime spec: %v", err) + } + + cmdRun := exec.Command(nvidiaRuntime, "run", "--bundle", cfg.bundlePath(), "testcontainer") + t.Logf("executing: %s\n", strings.Join(cmdRun.Args, " ")) + output, err := cmdRun.CombinedOutput() + require.NoErrorf(t, err, "runtime should not return an error", "output=%v", string(output)) + + // Check config.json and confirm there are no hooks + spec, err := cfg.getRuntimeSpec() + require.NoError(t, err, "should be no errors when reading and parsing spec from config.json") + require.Empty(t, spec.Hooks, "there should be no hooks in config.json") + + cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle", cfg.bundlePath(), "testcontainer") + t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " ")) + err = cmdCreate.Run() + require.NoError(t, err, "runtime should not return an error") + + // Check config.json for NVIDIA prestart hook + spec, err = cfg.getRuntimeSpec() + require.NoError(t, err, "should be no errors when reading and parsing spec from config.json") + require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json") + require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json") +} + +// NVIDIA prestart hook already present in config file +func TestDuplicateHook(t *testing.T) { + err := cfg.generateNewRuntimeSpec() + if err != nil { + t.Fatal(err) + } + + var spec specs.Spec + spec, err = cfg.getRuntimeSpec() + if err != nil { + t.Fatal(err) + } + + t.Logf("inserting nvidia prestart hook to config.json") + if err = addNVIDIAHook(&spec); err != nil { + t.Fatal(err) + } + + jsonOutput, err := json.MarshalIndent(spec, "", "\t") + if err != nil { + t.Fatal(err) + } + + jsonFile, err := os.OpenFile(cfg.specFilePath(), os.O_RDWR, 0644) + if err != nil { + t.Fatal(err) + } + _, err = jsonFile.WriteAt(jsonOutput, 0) + if err != nil { + t.Fatal(err) + } + + // Test how runtime handles already existing prestart hook in config.json + cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle", cfg.bundlePath(), "testcontainer") + t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " ")) + output, err := cmdCreate.CombinedOutput() + require.NoErrorf(t, err, "runtime should not return an error", "output=%v", string(output)) + + // Check config.json for NVIDIA prestart hook + spec, err = cfg.getRuntimeSpec() + require.NoError(t, err, "should be no errors when reading and parsing spec from config.json") + require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json") + require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json") +} + +// addNVIDIAHook is a basic wrapper for nvidiaContainerRunime.addNVIDIAHook that is used for +// testing. +func addNVIDIAHook(spec *specs.Spec) error { + r := nvidiaContainerRuntime{logger: logger.Logger} + return r.addNVIDIAHook(spec) +} + +func (c testConfig) getRuntimeSpec() (specs.Spec, error) { + filePath := c.specFilePath() + + var spec specs.Spec + jsonFile, err := os.OpenFile(filePath, os.O_RDWR, 0644) + if err != nil { + return spec, err + } + defer jsonFile.Close() + + jsonContent, err := ioutil.ReadAll(jsonFile) + if err != nil { + return spec, err + } else if json.Valid(jsonContent) { + err = json.Unmarshal(jsonContent, &spec) + if err != nil { + return spec, err + } + } else { + err = json.NewDecoder(bytes.NewReader(jsonContent)).Decode(&spec) + if err != nil { + return spec, err + } + } + + return spec, err +} + +func (c testConfig) bundlePath() string { + return filepath.Join(c.root, bundlePathSuffix) +} + +func (c testConfig) specFilePath() string { + return filepath.Join(c.bundlePath(), specFile) +} + +func (c testConfig) unmodifiedSpecFile() string { + return filepath.Join(c.root, unmodifiedSpecFileSuffix) +} + +func (c testConfig) generateNewRuntimeSpec() error { + var err error + + err = os.MkdirAll(c.bundlePath(), 0755) + if err != nil { + return err + } + + cmd := exec.Command("cp", c.unmodifiedSpecFile(), c.specFilePath()) + err = cmd.Run() + if err != nil { + return err + } + return nil +} + +// Return number of valid NVIDIA prestart hooks in runtime spec +func nvidiaHookCount(hooks *specs.Hooks) int { + if hooks == nil { + return 0 + } + + count := 0 + for _, hook := range hooks.Prestart { + if strings.Contains(hook.Path, nvidiaHook) { + count++ + } + } + return count +} + +func TestGetConfigWithCustomConfig(t *testing.T) { + wd, err := os.Getwd() + require.NoError(t, err) + + // By default debug is disabled + contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"") + testDir := filepath.Join(wd, "test") + filename := filepath.Join(testDir, configFilePath) + + os.Setenv(configOverride, testDir) + + require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766)) + require.NoError(t, ioutil.WriteFile(filename, contents, 0766)) + + defer func() { require.NoError(t, os.RemoveAll(testDir)) }() + + cfg, err := getConfig() + require.NoError(t, err) + require.Equal(t, cfg.debugFilePath, "/nvidia-container-toolkit.log") +} diff --git a/cmd/nvidia-container-runtime/nvcr.go b/cmd/nvidia-container-runtime/nvcr.go new file mode 100644 index 00000000..91e0b3a2 --- /dev/null +++ b/cmd/nvidia-container-runtime/nvcr.go @@ -0,0 +1,145 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package main + +import ( + "fmt" + "os" + "os/exec" + "strings" + + "github.com/NVIDIA/nvidia-container-runtime/internal/oci" + "github.com/opencontainers/runtime-spec/specs-go" + log "github.com/sirupsen/logrus" +) + +// nvidiaContainerRuntime encapsulates the NVIDIA Container Runtime. It wraps the specified runtime, conditionally +// modifying the specified OCI specification before invoking the runtime. +type nvidiaContainerRuntime struct { + logger *log.Logger + runtime oci.Runtime + ociSpec oci.Spec +} + +var _ oci.Runtime = (*nvidiaContainerRuntime)(nil) + +// newNvidiaContainerRuntime is a constructor for a standard runtime shim. +func newNvidiaContainerRuntimeWithLogger(logger *log.Logger, runtime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) { + r := nvidiaContainerRuntime{ + logger: logger, + runtime: runtime, + ociSpec: ociSpec, + } + + return &r, nil +} + +// Exec defines the entrypoint for the NVIDIA Container Runtime. A check is performed to see whether modifications +// to the OCI spec are required -- and applicable modifcations applied. The supplied arguments are then +// forwarded to the underlying runtime's Exec method. +func (r nvidiaContainerRuntime) Exec(args []string) error { + if r.modificationRequired(args) { + err := r.modifyOCISpec() + if err != nil { + return fmt.Errorf("error modifying OCI spec: %v", err) + } + } + + r.logger.Println("Forwarding command to runtime") + return r.runtime.Exec(args) +} + +// modificationRequired checks the intput arguments to determine whether a modification +// to the OCI spec is required. +func (r nvidiaContainerRuntime) modificationRequired(args []string) bool { + var previousWasBundle bool + for _, a := range args { + // We check for '--bundle create' explicitly to ensure that we + // don't inadvertently trigger a modification if the bundle directory + // is specified as `create` + if !previousWasBundle && isBundleFlag(a) { + previousWasBundle = true + continue + } + + if !previousWasBundle && a == "create" { + r.logger.Infof("'create' command detected; modification required") + return true + } + + previousWasBundle = false + } + + r.logger.Infof("No modification required") + return false +} + +// modifyOCISpec loads and modifies the OCI spec specified in the nvidiaContainerRuntime +// struct. The spec is modified in-place and written to the same file as the input after +// modifcationas are applied. +func (r nvidiaContainerRuntime) modifyOCISpec() error { + err := r.ociSpec.Load() + if err != nil { + return fmt.Errorf("error loading OCI specification for modification: %v", err) + } + + err = r.ociSpec.Modify(r.addNVIDIAHook) + if err != nil { + return fmt.Errorf("error injecting NVIDIA Container Runtime hook: %v", err) + } + + err = r.ociSpec.Flush() + if err != nil { + return fmt.Errorf("error writing modified OCI specification: %v", err) + } + return nil +} + +// addNVIDIAHook modifies the specified OCI specification in-place, inserting a +// prestart hook. +func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error { + path, err := exec.LookPath("nvidia-container-runtime-hook") + if err != nil { + path = hookDefaultFilePath + _, err = os.Stat(path) + if err != nil { + return err + } + } + + r.logger.Printf("prestart hook path: %s\n", path) + + args := []string{path} + if spec.Hooks == nil { + spec.Hooks = &specs.Hooks{} + } else if len(spec.Hooks.Prestart) != 0 { + for _, hook := range spec.Hooks.Prestart { + if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") { + continue + } + r.logger.Println("existing nvidia prestart hook in OCI spec file") + return nil + } + } + + spec.Hooks.Prestart = append(spec.Hooks.Prestart, specs.Hook{ + Path: path, + Args: append(args, "prestart"), + }) + + return nil +} diff --git a/cmd/nvidia-container-runtime/nvcr_test.go b/cmd/nvidia-container-runtime/nvcr_test.go new file mode 100644 index 00000000..06a97630 --- /dev/null +++ b/cmd/nvidia-container-runtime/nvcr_test.go @@ -0,0 +1,230 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package main + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/NVIDIA/nvidia-container-runtime/internal/oci" + "github.com/opencontainers/runtime-spec/specs-go" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestArgsGetConfigFilePath(t *testing.T) { + wd, err := os.Getwd() + require.NoError(t, err) + + testCases := []struct { + bundleDir string + ociSpecPath string + }{ + { + ociSpecPath: fmt.Sprintf("%v/config.json", wd), + }, + { + bundleDir: "/foo/bar", + ociSpecPath: "/foo/bar/config.json", + }, + { + bundleDir: "/foo/bar/", + ociSpecPath: "/foo/bar/config.json", + }, + } + + for i, tc := range testCases { + cp, err := getOCISpecFilePath(tc.bundleDir) + + require.NoErrorf(t, err, "%d: %v", i, tc) + require.Equalf(t, tc.ociSpecPath, cp, "%d: %v", i, tc) + } +} + +func TestAddNvidiaHook(t *testing.T) { + logger, logHook := testlog.NewNullLogger() + shim := nvidiaContainerRuntime{ + logger: logger, + } + + testCases := []struct { + spec *specs.Spec + errorPrefix string + shouldNotAdd bool + }{ + { + spec: &specs.Spec{}, + }, + { + spec: &specs.Spec{ + Hooks: &specs.Hooks{}, + }, + }, + { + spec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{{ + Path: "some-hook", + }}, + }, + }, + }, + { + spec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{{ + Path: "nvidia-container-runtime-hook", + }}, + }, + }, + shouldNotAdd: true, + }, + } + + for i, tc := range testCases { + logHook.Reset() + + var numPrestartHooks int + if tc.spec.Hooks != nil { + numPrestartHooks = len(tc.spec.Hooks.Prestart) + } + + err := shim.addNVIDIAHook(tc.spec) + + if tc.errorPrefix == "" { + require.NoErrorf(t, err, "%d: %v", i, tc) + } else { + require.Truef(t, strings.HasPrefix(err.Error(), tc.errorPrefix), "%d: %v", i, tc) + + require.NotNilf(t, tc.spec.Hooks, "%d: %v", i, tc) + require.Equalf(t, 1, nvidiaHookCount(tc.spec.Hooks), "%d: %v", i, tc) + + if tc.shouldNotAdd { + require.Equal(t, numPrestartHooks+1, len(tc.spec.Hooks.Poststart), "%d: %v", i, tc) + } else { + require.Equal(t, numPrestartHooks+1, len(tc.spec.Hooks.Poststart), "%d: %v", i, tc) + + nvidiaHook := tc.spec.Hooks.Poststart[len(tc.spec.Hooks.Poststart)-1] + + // TODO: This assumes that the hook has been set up in the makefile + expectedPath := "/usr/bin/nvidia-container-runtime-hook" + require.Equalf(t, expectedPath, nvidiaHook.Path, "%d: %v", i, tc) + require.Equalf(t, []string{expectedPath, "prestart"}, nvidiaHook.Args, "%d: %v", i, tc) + require.Emptyf(t, nvidiaHook.Env, "%d: %v", i, tc) + require.Nilf(t, nvidiaHook.Timeout, "%d: %v", i, tc) + } + } + } +} + +func TestNvidiaContainerRuntime(t *testing.T) { + logger, hook := testlog.NewNullLogger() + + testCases := []struct { + shim nvidiaContainerRuntime + shouldModify bool + args []string + modifyError error + writeError error + }{ + { + shim: nvidiaContainerRuntime{}, + shouldModify: false, + }, + { + shim: nvidiaContainerRuntime{}, + args: []string{"create"}, + shouldModify: true, + }, + { + shim: nvidiaContainerRuntime{}, + args: []string{"--bundle=create"}, + shouldModify: false, + }, + { + shim: nvidiaContainerRuntime{}, + args: []string{"--bundle", "create"}, + shouldModify: false, + }, + { + shim: nvidiaContainerRuntime{}, + args: []string{"create"}, + shouldModify: true, + }, + { + shim: nvidiaContainerRuntime{}, + args: []string{"create"}, + modifyError: fmt.Errorf("error modifying"), + shouldModify: true, + }, + { + shim: nvidiaContainerRuntime{}, + args: []string{"create"}, + writeError: fmt.Errorf("error writing"), + shouldModify: true, + }, + } + + for i, tc := range testCases { + tc.shim.logger = logger + hook.Reset() + + spec := &specs.Spec{} + ociMock := oci.NewMockSpec(spec, tc.writeError, tc.modifyError) + + require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc) + + tc.shim.ociSpec = ociMock + tc.shim.runtime = &MockShim{} + + err := tc.shim.Exec(tc.args) + if tc.modifyError != nil || tc.writeError != nil { + require.Error(t, err, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + } + + if tc.shouldModify { + require.Equal(t, 1, ociMock.MockModify.Callcount, "%d: %v", i, tc) + require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "%d: %v", i, tc) + } else { + require.Equal(t, 0, ociMock.MockModify.Callcount, "%d: %v", i, tc) + require.Nil(t, spec.Hooks, "%d: %v", i, tc) + } + + writeExpected := tc.shouldModify && tc.modifyError == nil + if writeExpected { + require.Equal(t, 1, ociMock.MockFlush.Callcount, "%d: %v", i, tc) + } else { + require.Equal(t, 0, ociMock.MockFlush.Callcount, "%d: %v", i, tc) + } + } +} + +type MockShim struct { + called bool + args []string + returnError error +} + +func (m *MockShim) Exec(args []string) error { + m.called = true + m.args = args + return m.returnError +} diff --git a/cmd/nvidia-container-runtime/runtime_factory.go b/cmd/nvidia-container-runtime/runtime_factory.go new file mode 100644 index 00000000..ff86e3a8 --- /dev/null +++ b/cmd/nvidia-container-runtime/runtime_factory.go @@ -0,0 +1,176 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package main + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/NVIDIA/nvidia-container-runtime/internal/oci" +) + +const ( + ociSpecFileName = "config.json" + dockerRuncExecutableName = "docker-runc" + runcExecutableName = "runc" +) + +// newRuntime is a factory method that constructs a runtime based on the selected configuration. +func newRuntime(argv []string) (oci.Runtime, error) { + ociSpec, err := newOCISpec(argv) + if err != nil { + return nil, fmt.Errorf("error constructing OCI specification: %v", err) + } + + runc, err := newRuncRuntime() + if err != nil { + return nil, fmt.Errorf("error constructing runc runtime: %v", err) + } + + r, err := newNvidiaContainerRuntimeWithLogger(logger.Logger, runc, ociSpec) + if err != nil { + return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err) + } + + return r, nil +} + +// newOCISpec constructs an OCI spec for the provided arguments +func newOCISpec(argv []string) (oci.Spec, error) { + bundlePath, err := getBundlePath(argv) + if err != nil { + return nil, fmt.Errorf("error parsing command line arguments: %v", err) + } + + ociSpecPath, err := getOCISpecFilePath(bundlePath) + if err != nil { + return nil, fmt.Errorf("error getting OCI specification file path: %v", err) + } + ociSpec := oci.NewSpecFromFile(ociSpecPath) + + return ociSpec, nil +} + +// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime +func newRuncRuntime() (oci.Runtime, error) { + runtimePath, err := findRunc() + if err != nil { + return nil, fmt.Errorf("error locating runtime: %v", err) + } + + runc, err := oci.NewSyscallExecRuntimeWithLogger(logger.Logger, runtimePath) + if err != nil { + return nil, fmt.Errorf("error constructing runtime: %v", err) + } + + return runc, nil +} + +// getBundlePath checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc. +// The following are supported: +// --bundle{{SEP}}BUNDLE_PATH +// -bundle{{SEP}}BUNDLE_PATH +// -b{{SEP}}BUNDLE_PATH +// where {{SEP}} is either ' ' or '=' +func getBundlePath(argv []string) (string, error) { + var bundlePath string + + for i := 0; i < len(argv); i++ { + param := argv[i] + + parts := strings.SplitN(param, "=", 2) + if !isBundleFlag(parts[0]) { + continue + } + + // The flag has the format --bundle=/path + if len(parts) == 2 { + bundlePath = parts[1] + continue + } + + // The flag has the format --bundle /path + if i+1 < len(argv) { + bundlePath = argv[i+1] + i++ + continue + } + + // --bundle / -b was the last element of argv + return "", fmt.Errorf("bundle option requires an argument") + } + + return bundlePath, nil +} + +// findRunc locates runc in the path, returning the full path to the +// binary or an error. +func findRunc() (string, error) { + runtimeCandidates := []string{ + dockerRuncExecutableName, + runcExecutableName, + } + + return findRuntime(runtimeCandidates) +} + +func findRuntime(runtimeCandidates []string) (string, error) { + for _, candidate := range runtimeCandidates { + logger.Infof("Looking for runtime binary '%v'", candidate) + runcPath, err := exec.LookPath(candidate) + if err == nil { + logger.Infof("Found runtime binary '%v'", runcPath) + return runcPath, nil + } + logger.Warnf("Runtime binary '%v' not found: %v", candidate, err) + } + + return "", fmt.Errorf("no runtime binary found from candidate list: %v", runtimeCandidates) +} + +func isBundleFlag(arg string) bool { + if !strings.HasPrefix(arg, "-") { + return false + } + + trimmed := strings.TrimLeft(arg, "-") + return trimmed == "b" || trimmed == "bundle" +} + +// getOCISpecFilePath returns the expected path to the OCI specification file for the given +// bundle directory or the current working directory if not specified. +func getOCISpecFilePath(bundleDir string) (string, error) { + if bundleDir == "" { + logger.Infof("Bundle directory path is empty, using working directory.") + workingDirectory, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("error getting working directory: %v", err) + } + bundleDir = workingDirectory + } + + logger.Infof("Using bundle directory: %v", bundleDir) + + OCISpecFilePath := filepath.Join(bundleDir, ociSpecFileName) + + logger.Infof("Using OCI specification file path: %v", OCISpecFilePath) + + return OCISpecFilePath, nil +} diff --git a/cmd/nvidia-container-runtime/runtime_factory_test.go b/cmd/nvidia-container-runtime/runtime_factory_test.go new file mode 100644 index 00000000..b5a6d461 --- /dev/null +++ b/cmd/nvidia-container-runtime/runtime_factory_test.go @@ -0,0 +1,192 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package main + +import ( + "path/filepath" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestConstructor(t *testing.T) { + shim, err := newRuntime([]string{}) + + require.NoError(t, err) + require.NotNil(t, shim) +} + +func TestGetBundlePath(t *testing.T) { + type expected struct { + bundle string + isError bool + } + testCases := []struct { + argv []string + expected expected + }{ + { + argv: []string{}, + }, + { + argv: []string{"create"}, + }, + { + argv: []string{"--bundle"}, + expected: expected{ + isError: true, + }, + }, + { + argv: []string{"-b"}, + expected: expected{ + isError: true, + }, + }, + { + argv: []string{"--bundle", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"--not-bundle", "/foo/bar"}, + }, + { + argv: []string{"--"}, + }, + { + argv: []string{"-bundle", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"--bundle=/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"-b=/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"-b=/foo/=bar"}, + expected: expected{ + bundle: "/foo/=bar", + }, + }, + { + argv: []string{"-b", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"create", "-b", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"-b", "create", "create"}, + expected: expected{ + bundle: "create", + }, + }, + { + argv: []string{"-b=create", "create"}, + expected: expected{ + bundle: "create", + }, + }, + { + argv: []string{"-b", "create"}, + expected: expected{ + bundle: "create", + }, + }, + } + + for i, tc := range testCases { + bundle, err := getBundlePath(tc.argv) + + if tc.expected.isError { + require.Errorf(t, err, "%d: %v", i, tc) + } else { + require.NoErrorf(t, err, "%d: %v", i, tc) + } + + require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc) + } +} + +func TestFindRunc(t *testing.T) { + testLogger, _ := testlog.NewNullLogger() + logger.Logger = testLogger + + runcPath, err := findRunc() + require.NoError(t, err) + require.Equal(t, filepath.Join(cfg.binPath, runcExecutableName), runcPath) +} + +func TestFindRuntime(t *testing.T) { + testLogger, _ := testlog.NewNullLogger() + logger.Logger = testLogger + + testCases := []struct { + candidates []string + expectedPath string + }{ + { + candidates: []string{}, + }, + { + candidates: []string{"not-runc"}, + }, + { + candidates: []string{"not-runc", "also-not-runc"}, + }, + { + candidates: []string{runcExecutableName}, + expectedPath: filepath.Join(cfg.binPath, runcExecutableName), + }, + { + candidates: []string{runcExecutableName, "not-runc"}, + expectedPath: filepath.Join(cfg.binPath, runcExecutableName), + }, + { + candidates: []string{"not-runc", runcExecutableName}, + expectedPath: filepath.Join(cfg.binPath, runcExecutableName), + }, + } + + for i, tc := range testCases { + runcPath, err := findRuntime(tc.candidates) + if tc.expectedPath == "" { + require.Error(t, err, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + } + require.Equal(t, tc.expectedPath, runcPath, "%d: %v", i, tc) + } + +} diff --git a/internal/oci/runtime.go b/internal/oci/runtime.go new file mode 100644 index 00000000..89df5aa1 --- /dev/null +++ b/internal/oci/runtime.go @@ -0,0 +1,23 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +// Runtime is an interface for a runtime shim. The Exec method accepts a list +// of command line arguments, and returns an error / nil. +type Runtime interface { + Exec([]string) error +} diff --git a/internal/oci/runtime_exec.go b/internal/oci/runtime_exec.go new file mode 100644 index 00000000..98415747 --- /dev/null +++ b/internal/oci/runtime_exec.go @@ -0,0 +1,79 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "fmt" + "os" + "syscall" + + log "github.com/sirupsen/logrus" +) + +// SyscallExecRuntime wraps the path that a binary and defines the semanitcs for how to exec into it. +// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the +// Runtime internface. +type SyscallExecRuntime struct { + logger *log.Logger + path string + // exec is used for testing. This defaults to syscall.Exec + exec func(argv0 string, argv []string, envv []string) error +} + +var _ Runtime = (*SyscallExecRuntime)(nil) + +// NewSyscallExecRuntime creates a SyscallExecRuntime for the specified path with the standard logger +func NewSyscallExecRuntime(path string) (Runtime, error) { + return NewSyscallExecRuntimeWithLogger(log.StandardLogger(), path) +} + +// NewSyscallExecRuntimeWithLogger creates a SyscallExecRuntime for the specified logger and path +func NewSyscallExecRuntimeWithLogger(logger *log.Logger, path string) (Runtime, error) { + info, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("invalid path '%v': %v", path, err) + } + if info.IsDir() || info.Mode()&0111 == 0 { + return nil, fmt.Errorf("specified path '%v' is not an executable file", path) + } + + shim := SyscallExecRuntime{ + logger: logger, + path: path, + exec: syscall.Exec, + } + + return &shim, nil +} + +// Exec exces into the binary at the path from the SyscallExecRuntime struct, passing it the supplied arguments +// after ensuring that the first argument is the path of the target binary. +func (s SyscallExecRuntime) Exec(args []string) error { + runtimeArgs := []string{s.path} + if len(args) > 1 { + runtimeArgs = append(runtimeArgs, args[1:]...) + } + + err := s.exec(s.path, runtimeArgs, os.Environ()) + if err != nil { + return fmt.Errorf("could not exec '%v': %v", s.path, err) + } + + // syscall.Exec is not expected to return. This is an error state regardless of whether + // err is nil or not. + return fmt.Errorf("unexpected return from exec '%v'", s.path) +} diff --git a/internal/oci/runtime_exec_test.go b/internal/oci/runtime_exec_test.go new file mode 100644 index 00000000..83ac64a2 --- /dev/null +++ b/internal/oci/runtime_exec_test.go @@ -0,0 +1,100 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ +package oci + +import ( + "fmt" + "strings" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestSyscallExecConstructor(t *testing.T) { + r, err := NewSyscallExecRuntime("////an/invalid/path") + require.Error(t, err) + require.Nil(t, r) + + r, err = NewSyscallExecRuntime("/tmp") + require.Error(t, err) + require.Nil(t, r) + + r, err = NewSyscallExecRuntime("/dev/null") + require.Error(t, err) + require.Nil(t, r) + + r, err = NewSyscallExecRuntime("/bin/sh") + require.NoError(t, err) + + f, ok := r.(*SyscallExecRuntime) + require.True(t, ok) + + require.Equal(t, "/bin/sh", f.path) +} + +func TestSyscallExecForwardsArgs(t *testing.T) { + logger, _ := testlog.NewNullLogger() + f := SyscallExecRuntime{ + logger: logger, + path: "runtime", + } + + testCases := []struct { + returnError error + args []string + errorPrefix string + }{ + { + returnError: nil, + errorPrefix: "unexpected return from exec", + }, + { + returnError: fmt.Errorf("error from exec"), + errorPrefix: "could not exec", + }, + { + returnError: nil, + args: []string{"otherargv0"}, + errorPrefix: "unexpected return from exec", + }, + { + returnError: nil, + args: []string{"otherargv0", "arg1", "arg2", "arg3"}, + errorPrefix: "unexpected return from exec", + }, + } + + for i, tc := range testCases { + execMock := WithMockExec(f, tc.returnError) + + err := execMock.Exec(tc.args) + + require.Errorf(t, err, "%d: %v", i, tc) + require.Truef(t, strings.HasPrefix(err.Error(), tc.errorPrefix), "%d: %v", i, tc) + if tc.returnError != nil { + require.Truef(t, strings.HasSuffix(err.Error(), tc.returnError.Error()), "%d: %v", i, tc) + } + + require.Equalf(t, f.path, execMock.argv0, "%d: %v", i, tc) + require.Equalf(t, f.path, execMock.argv[0], "%d: %v", i, tc) + + require.LessOrEqualf(t, len(tc.args), len(execMock.argv), "%d: %v", i, tc) + if len(tc.args) > 1 { + require.Equalf(t, tc.args[1:], execMock.argv[1:], "%d: %v", i, tc) + } + } +} diff --git a/internal/oci/runtime_mock.go b/internal/oci/runtime_mock.go new file mode 100644 index 00000000..e09cfb79 --- /dev/null +++ b/internal/oci/runtime_mock.go @@ -0,0 +1,49 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +// MockExecRuntime wraps a SyscallExecRuntime, intercepting the exec call for testing +type MockExecRuntime struct { + SyscallExecRuntime + execMock +} + +// WithMockExec wraps a specified SyscallExecRuntime with a mocked exec function for testing +func WithMockExec(e SyscallExecRuntime, execResult error) *MockExecRuntime { + m := MockExecRuntime{ + SyscallExecRuntime: e, + execMock: execMock{result: execResult}, + } + // overrdie the exec function to the mocked exec function. + m.SyscallExecRuntime.exec = m.execMock.exec + return &m +} + +type execMock struct { + argv0 string + argv []string + envv []string + result error +} + +func (m *execMock) exec(argv0 string, argv []string, envv []string) error { + m.argv0 = argv0 + m.argv = argv + m.envv = envv + + return m.result +} diff --git a/internal/oci/spec.go b/internal/oci/spec.go new file mode 100644 index 00000000..0c3aaf4d --- /dev/null +++ b/internal/oci/spec.go @@ -0,0 +1,102 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "encoding/json" + "fmt" + "os" + + oci "github.com/opencontainers/runtime-spec/specs-go" +) + +// SpecModifier is a function that accepts a pointer to an OCI Srec and returns an +// error. The intention is that the function would modify the spec in-place. +type SpecModifier func(*oci.Spec) error + +// Spec defines the operations to be performed on an OCI specification +type Spec interface { + Load() error + Flush() error + Modify(SpecModifier) error +} + +type fileSpec struct { + *oci.Spec + path string +} + +var _ Spec = (*fileSpec)(nil) + +// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec. +// This can be used to read from the file, modify the spec, and write to the +// same file. +func NewSpecFromFile(filepath string) Spec { + oci := fileSpec{ + path: filepath, + } + + return &oci +} + +// Load reads the contents of an OCI spec from file to be referenced internally. +// The file is opened "read-only" +func (s *fileSpec) Load() error { + specFile, err := os.Open(s.path) + if err != nil { + return fmt.Errorf("error opening OCI specification file: %v", err) + } + defer specFile.Close() + + decoder := json.NewDecoder(specFile) + + var spec oci.Spec + err = decoder.Decode(&spec) + if err != nil { + return fmt.Errorf("error reading OCI specification from file: %v", err) + } + + s.Spec = &spec + return nil +} + +// Modify applies the specified SpecModifier to the stored OCI specification. +func (s *fileSpec) Modify(f SpecModifier) error { + if s.Spec == nil { + return fmt.Errorf("no spec loaded for modification") + } + return f(s.Spec) +} + +// Flush writes the stored OCI specification to the filepath specifed by the path member. +// The file is truncated upon opening, overwriting any existing contents. +func (s fileSpec) Flush() error { + specFile, err := os.Create(s.path) + if err != nil { + return fmt.Errorf("error opening OCI specification file: %v", err) + } + defer specFile.Close() + + encoder := json.NewEncoder(specFile) + + err = encoder.Encode(s.Spec) + if err != nil { + return fmt.Errorf("error writing OCI specification to file: %v", err) + } + + return nil +} diff --git a/internal/oci/spec_mock.go b/internal/oci/spec_mock.go new file mode 100644 index 00000000..1247adaf --- /dev/null +++ b/internal/oci/spec_mock.go @@ -0,0 +1,70 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + oci "github.com/opencontainers/runtime-spec/specs-go" +) + +// MockSpec provides a simple mock for an OCI spec to be used in testing. +// It also implements the SpecModifier interface. +type MockSpec struct { + *oci.Spec + MockLoad mockFunc + MockFlush mockFunc + MockModify mockFunc +} + +var _ Spec = (*MockSpec)(nil) + +// NewMockSpec constructs a MockSpec to be used in testing as a Spec +func NewMockSpec(spec *oci.Spec, flushResult error, modifyResult error) *MockSpec { + s := MockSpec{ + Spec: spec, + MockFlush: mockFunc{result: flushResult}, + MockModify: mockFunc{result: modifyResult}, + } + + return &s +} + +// Load invokes the mocked Load function to return the predefined error / result +func (s *MockSpec) Load() error { + return s.MockLoad.call() +} + +// Flush invokes the mocked Load function to return the predefined error / result +func (s *MockSpec) Flush() error { + return s.MockFlush.call() +} + +// Modify applies the specified SpecModifier to the spec and invokes the +// mocked modify function to return the predefined error / result. +func (s *MockSpec) Modify(f SpecModifier) error { + f(s.Spec) + return s.MockModify.call() +} + +type mockFunc struct { + Callcount int + result error +} + +func (m *mockFunc) call() error { + m.Callcount++ + return m.result +} diff --git a/test/bin/nvidia-container-runtime-hook b/test/bin/nvidia-container-runtime-hook new file mode 100755 index 00000000..5579c7d1 --- /dev/null +++ b/test/bin/nvidia-container-runtime-hook @@ -0,0 +1,2 @@ +#!/bin/bash +echo mock hook diff --git a/test/bin/runc b/test/bin/runc new file mode 100755 index 00000000..e605fc8d --- /dev/null +++ b/test/bin/runc @@ -0,0 +1,2 @@ +#!/bin/bash +echo mock runc diff --git a/test/input/nvidia-container-runtime/config.toml b/test/input/nvidia-container-runtime/config.toml new file mode 100644 index 00000000..e69de29b diff --git a/test/input/test_spec.json b/test/input/test_spec.json new file mode 100644 index 00000000..35a1a875 --- /dev/null +++ b/test/input/test_spec.json @@ -0,0 +1,178 @@ +{ + "ociVersion": "1.0.1-dev", + "process": { + "terminal": true, + "user": { + "uid": 0, + "gid": 0 + }, + "args": [ + "sh" + ], + "env": [ + "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", + "TERM=xterm" + ], + "cwd": "/", + "capabilities": { + "bounding": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ], + "effective": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ], + "inheritable": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ], + "permitted": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ], + "ambient": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ] + }, + "rlimits": [ + { + "type": "RLIMIT_NOFILE", + "hard": 1024, + "soft": 1024 + } + ], + "noNewPrivileges": true + }, + "root": { + "path": "rootfs", + "readonly": true + }, + "hostname": "runc", + "mounts": [ + { + "destination": "/proc", + "type": "proc", + "source": "proc" + }, + { + "destination": "/dev", + "type": "tmpfs", + "source": "tmpfs", + "options": [ + "nosuid", + "strictatime", + "mode=755", + "size=65536k" + ] + }, + { + "destination": "/dev/pts", + "type": "devpts", + "source": "devpts", + "options": [ + "nosuid", + "noexec", + "newinstance", + "ptmxmode=0666", + "mode=0620", + "gid=5" + ] + }, + { + "destination": "/dev/shm", + "type": "tmpfs", + "source": "shm", + "options": [ + "nosuid", + "noexec", + "nodev", + "mode=1777", + "size=65536k" + ] + }, + { + "destination": "/dev/mqueue", + "type": "mqueue", + "source": "mqueue", + "options": [ + "nosuid", + "noexec", + "nodev" + ] + }, + { + "destination": "/sys", + "type": "sysfs", + "source": "sysfs", + "options": [ + "nosuid", + "noexec", + "nodev", + "ro" + ] + }, + { + "destination": "/sys/fs/cgroup", + "type": "cgroup", + "source": "cgroup", + "options": [ + "nosuid", + "noexec", + "nodev", + "relatime", + "ro" + ] + } + ], + "linux": { + "resources": { + "devices": [ + { + "allow": false, + "access": "rwm" + } + ] + }, + "namespaces": [ + { + "type": "pid" + }, + { + "type": "network" + }, + { + "type": "ipc" + }, + { + "type": "uts" + }, + { + "type": "mount" + } + ], + "maskedPaths": [ + "/proc/kcore", + "/proc/latency_stats", + "/proc/timer_list", + "/proc/timer_stats", + "/proc/sched_debug", + "/sys/firmware", + "/proc/scsi" + ], + "readonlyPaths": [ + "/proc/asound", + "/proc/bus", + "/proc/fs", + "/proc/irq", + "/proc/sys", + "/proc/sysrq-trigger" + ] + } +} \ No newline at end of file