diff --git a/cmd/nvidia-container-runtime/nvcr.go b/cmd/nvidia-container-runtime/nvcr.go index 1eb23a26..6c4185b3 100644 --- a/cmd/nvidia-container-runtime/nvcr.go +++ b/cmd/nvidia-container-runtime/nvcr.go @@ -66,22 +66,9 @@ func (r nvidiaContainerRuntime) Exec(args []string) error { // modificationRequired checks the intput arguments to determine whether a modification // to the OCI spec is required. func (r nvidiaContainerRuntime) modificationRequired(args []string) bool { - var previousWasBundle bool - for _, a := range args { - // We check for '--bundle create' explicitly to ensure that we - // don't inadvertently trigger a modification if the bundle directory - // is specified as `create` - if !previousWasBundle && isBundleFlag(a) { - previousWasBundle = true - continue - } - - if !previousWasBundle && a == "create" { - r.logger.Infof("'create' command detected; modification required") - return true - } - - previousWasBundle = false + if oci.HasCreateSubcommand(args) { + r.logger.Infof("'create' command detected; modification required") + return true } r.logger.Infof("No modification required") diff --git a/cmd/nvidia-container-runtime/nvcr_test.go b/cmd/nvidia-container-runtime/nvcr_test.go index 5650e635..16e1b447 100644 --- a/cmd/nvidia-container-runtime/nvcr_test.go +++ b/cmd/nvidia-container-runtime/nvcr_test.go @@ -27,32 +27,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestArgsGetConfigFilePath(t *testing.T) { - testCases := []struct { - bundleDir string - ociSpecPath string - }{ - { - ociSpecPath: "config.json", - }, - { - bundleDir: "/foo/bar", - ociSpecPath: "/foo/bar/config.json", - }, - { - bundleDir: "/foo/bar/", - ociSpecPath: "/foo/bar/config.json", - }, - } - - for i, tc := range testCases { - cp, err := getOCISpecFilePath(tc.bundleDir) - - require.NoErrorf(t, err, "%d: %v", i, tc) - require.Equalf(t, tc.ociSpecPath, cp, "%d: %v", i, tc) - } -} - func TestAddNvidiaHook(t *testing.T) { logger, logHook := testlog.NewNullLogger() shim := nvidiaContainerRuntime{ @@ -181,9 +155,14 @@ func TestNvidiaContainerRuntime(t *testing.T) { tc.shim.logger = logger hook.Reset() - spec := &specs.Spec{} - ociMock := oci.NewMockSpec(spec, tc.writeError, tc.modifyError) - + ociMock := &oci.SpecMock{ + ModifyFunc: func(specModifier oci.SpecModifier) error { + return tc.modifyError + }, + FlushFunc: func() error { + return tc.writeError + }, + } require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc) tc.shim.ociSpec = ociMock @@ -197,18 +176,16 @@ func TestNvidiaContainerRuntime(t *testing.T) { } if tc.shouldModify { - require.Equal(t, 1, ociMock.MockModify.Callcount, "%d: %v", i, tc) - require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "%d: %v", i, tc) + require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc) } else { - require.Equal(t, 0, ociMock.MockModify.Callcount, "%d: %v", i, tc) - require.Nil(t, spec.Hooks, "%d: %v", i, tc) + require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc) } writeExpected := tc.shouldModify && tc.modifyError == nil if writeExpected { - require.Equal(t, 1, ociMock.MockFlush.Callcount, "%d: %v", i, tc) + require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc) } else { - require.Equal(t, 0, ociMock.MockFlush.Callcount, "%d: %v", i, tc) + require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc) } } } diff --git a/cmd/nvidia-container-runtime/runtime_factory.go b/cmd/nvidia-container-runtime/runtime_factory.go index 151de77a..0af37f6e 100644 --- a/cmd/nvidia-container-runtime/runtime_factory.go +++ b/cmd/nvidia-container-runtime/runtime_factory.go @@ -18,9 +18,6 @@ package main import ( "fmt" - "os/exec" - "path/filepath" - "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) @@ -53,15 +50,15 @@ func newRuntime(argv []string) (oci.Runtime, error) { // newOCISpec constructs an OCI spec for the provided arguments func newOCISpec(argv []string) (oci.Spec, error) { - bundlePath, err := getBundlePath(argv) + bundleDir, err := oci.GetBundleDir(argv) if err != nil { return nil, fmt.Errorf("error parsing command line arguments: %v", err) } + logger.Infof("Using bundle directory: %v", bundleDir) + + ociSpecPath := oci.GetSpecFilePath(bundleDir) + logger.Infof("Using OCI specification file path: %v", ociSpecPath) - ociSpecPath, err := getOCISpecFilePath(bundlePath) - if err != nil { - return nil, fmt.Errorf("error getting OCI specification file path: %v", err) - } ociSpec := oci.NewSpecFromFile(ociSpecPath) return ociSpec, nil @@ -69,98 +66,9 @@ func newOCISpec(argv []string) (oci.Spec, error) { // newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime func newRuncRuntime() (oci.Runtime, error) { - runtimePath, err := findRunc() - if err != nil { - return nil, fmt.Errorf("error locating runtime: %v", err) - } - - runc, err := oci.NewSyscallExecRuntimeWithLogger(logger.Logger, runtimePath) - if err != nil { - return nil, fmt.Errorf("error constructing runtime: %v", err) - } - - return runc, nil -} - -// getBundlePath checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc. -// The following are supported: -// --bundle{{SEP}}BUNDLE_PATH -// -bundle{{SEP}}BUNDLE_PATH -// -b{{SEP}}BUNDLE_PATH -// where {{SEP}} is either ' ' or '=' -func getBundlePath(argv []string) (string, error) { - var bundlePath string - - for i := 0; i < len(argv); i++ { - param := argv[i] - - parts := strings.SplitN(param, "=", 2) - if !isBundleFlag(parts[0]) { - continue - } - - // The flag has the format --bundle=/path - if len(parts) == 2 { - bundlePath = parts[1] - continue - } - - // The flag has the format --bundle /path - if i+1 < len(argv) { - bundlePath = argv[i+1] - i++ - continue - } - - // --bundle / -b was the last element of argv - return "", fmt.Errorf("bundle option requires an argument") - } - - return bundlePath, nil -} - -// findRunc locates runc in the path, returning the full path to the -// binary or an error. -func findRunc() (string, error) { - runtimeCandidates := []string{ + return oci.NewLowLevelRuntimeWithLogger( + logger.Logger, dockerRuncExecutableName, runcExecutableName, - } - - return findRuntime(runtimeCandidates) -} - -func findRuntime(runtimeCandidates []string) (string, error) { - for _, candidate := range runtimeCandidates { - logger.Infof("Looking for runtime binary '%v'", candidate) - runcPath, err := exec.LookPath(candidate) - if err == nil { - logger.Infof("Found runtime binary '%v'", runcPath) - return runcPath, nil - } - logger.Warnf("Runtime binary '%v' not found: %v", candidate, err) - } - - return "", fmt.Errorf("no runtime binary found from candidate list: %v", runtimeCandidates) -} - -func isBundleFlag(arg string) bool { - if !strings.HasPrefix(arg, "-") { - return false - } - - trimmed := strings.TrimLeft(arg, "-") - return trimmed == "b" || trimmed == "bundle" -} - -// getOCISpecFilePath returns the expected path to the OCI specification file for the given -// bundle directory. If the bundle directory is empty, only `config.json` is returned. -func getOCISpecFilePath(bundleDir string) (string, error) { - logger.Infof("Using bundle directory: %v", bundleDir) - - OCISpecFilePath := filepath.Join(bundleDir, ociSpecFileName) - - logger.Infof("Using OCI specification file path: %v", OCISpecFilePath) - - return OCISpecFilePath, nil + ) } diff --git a/cmd/nvidia-container-runtime/runtime_factory_test.go b/cmd/nvidia-container-runtime/runtime_factory_test.go index b5a6d461..07d02232 100644 --- a/cmd/nvidia-container-runtime/runtime_factory_test.go +++ b/cmd/nvidia-container-runtime/runtime_factory_test.go @@ -17,10 +17,8 @@ package main import ( - "path/filepath" "testing" - testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" ) @@ -30,163 +28,3 @@ func TestConstructor(t *testing.T) { require.NoError(t, err) require.NotNil(t, shim) } - -func TestGetBundlePath(t *testing.T) { - type expected struct { - bundle string - isError bool - } - testCases := []struct { - argv []string - expected expected - }{ - { - argv: []string{}, - }, - { - argv: []string{"create"}, - }, - { - argv: []string{"--bundle"}, - expected: expected{ - isError: true, - }, - }, - { - argv: []string{"-b"}, - expected: expected{ - isError: true, - }, - }, - { - argv: []string{"--bundle", "/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"--not-bundle", "/foo/bar"}, - }, - { - argv: []string{"--"}, - }, - { - argv: []string{"-bundle", "/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"--bundle=/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"-b=/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"-b=/foo/=bar"}, - expected: expected{ - bundle: "/foo/=bar", - }, - }, - { - argv: []string{"-b", "/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"create", "-b", "/foo/bar"}, - expected: expected{ - bundle: "/foo/bar", - }, - }, - { - argv: []string{"-b", "create", "create"}, - expected: expected{ - bundle: "create", - }, - }, - { - argv: []string{"-b=create", "create"}, - expected: expected{ - bundle: "create", - }, - }, - { - argv: []string{"-b", "create"}, - expected: expected{ - bundle: "create", - }, - }, - } - - for i, tc := range testCases { - bundle, err := getBundlePath(tc.argv) - - if tc.expected.isError { - require.Errorf(t, err, "%d: %v", i, tc) - } else { - require.NoErrorf(t, err, "%d: %v", i, tc) - } - - require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc) - } -} - -func TestFindRunc(t *testing.T) { - testLogger, _ := testlog.NewNullLogger() - logger.Logger = testLogger - - runcPath, err := findRunc() - require.NoError(t, err) - require.Equal(t, filepath.Join(cfg.binPath, runcExecutableName), runcPath) -} - -func TestFindRuntime(t *testing.T) { - testLogger, _ := testlog.NewNullLogger() - logger.Logger = testLogger - - testCases := []struct { - candidates []string - expectedPath string - }{ - { - candidates: []string{}, - }, - { - candidates: []string{"not-runc"}, - }, - { - candidates: []string{"not-runc", "also-not-runc"}, - }, - { - candidates: []string{runcExecutableName}, - expectedPath: filepath.Join(cfg.binPath, runcExecutableName), - }, - { - candidates: []string{runcExecutableName, "not-runc"}, - expectedPath: filepath.Join(cfg.binPath, runcExecutableName), - }, - { - candidates: []string{"not-runc", runcExecutableName}, - expectedPath: filepath.Join(cfg.binPath, runcExecutableName), - }, - } - - for i, tc := range testCases { - runcPath, err := findRuntime(tc.candidates) - if tc.expectedPath == "" { - require.Error(t, err, "%d: %v", i, tc) - } else { - require.NoError(t, err, "%d: %v", i, tc) - } - require.Equal(t, tc.expectedPath, runcPath, "%d: %v", i, tc) - } - -} diff --git a/internal/oci/args.go b/internal/oci/args.go new file mode 100644 index 00000000..de85d9cc --- /dev/null +++ b/internal/oci/args.go @@ -0,0 +1,115 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "fmt" + "path/filepath" + "strings" +) + +const ( + specFileName = "config.json" +) + +// GetBundleDir returns the bundle directory or default depending on the +// supplied command line arguments. +func GetBundleDir(args []string) (string, error) { + bundleDir, err := GetBundleDirFromArgs(args) + if err != nil { + return "", fmt.Errorf("error getting bundle dir from args: %v", err) + } + + return bundleDir, nil +} + +// GetBundleDirFromArgs checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc. +// The following are supported: +// --bundle{{SEP}}BUNDLE_PATH +// -bundle{{SEP}}BUNDLE_PATH +// -b{{SEP}}BUNDLE_PATH +// where {{SEP}} is either ' ' or '=' +func GetBundleDirFromArgs(args []string) (string, error) { + var bundleDir string + + for i := 0; i < len(args); i++ { + param := args[i] + + parts := strings.SplitN(param, "=", 2) + if !IsBundleFlag(parts[0]) { + continue + } + + // The flag has the format --bundle=/path + if len(parts) == 2 { + bundleDir = parts[1] + continue + } + + // The flag has the format --bundle /path + if i+1 < len(args) { + bundleDir = args[i+1] + i++ + continue + } + + // --bundle / -b was the last element of args + return "", fmt.Errorf("bundle option requires an argument") + } + + return bundleDir, nil +} + +// GetSpecFilePath returns the expected path to the OCI specification file for the given +// bundle directory. +func GetSpecFilePath(bundleDir string) string { + specFilePath := filepath.Join(bundleDir, specFileName) + return specFilePath +} + +// IsBundleFlag is a helper function that checks wither the specified argument represents +// a bundle flag (--bundle or -b) +func IsBundleFlag(arg string) bool { + if !strings.HasPrefix(arg, "-") { + return false + } + + trimmed := strings.TrimLeft(arg, "-") + return trimmed == "b" || trimmed == "bundle" +} + +// HasCreateSubcommand checks the supplied arguments for a 'create' subcommand +func HasCreateSubcommand(args []string) bool { + var previousWasBundle bool + for _, a := range args { + // We check for '--bundle create' explicitly to ensure that we + // don't inadvertently trigger a modification if the bundle directory + // is specified as `create` + if !previousWasBundle && IsBundleFlag(a) { + previousWasBundle = true + continue + } + + if !previousWasBundle && a == "create" { + return true + } + + previousWasBundle = false + } + + return false +} diff --git a/internal/oci/args_test.go b/internal/oci/args_test.go new file mode 100644 index 00000000..562dd0a9 --- /dev/null +++ b/internal/oci/args_test.go @@ -0,0 +1,184 @@ +package oci + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetBundleDir(t *testing.T) { + type expected struct { + bundle string + isError bool + } + testCases := []struct { + argv []string + expected expected + }{ + { + argv: []string{}, + expected: expected{ + bundle: "", + }, + }, + { + argv: []string{"create"}, + expected: expected{ + bundle: "", + }, + }, + { + argv: []string{"--bundle"}, + expected: expected{ + isError: true, + }, + }, + { + argv: []string{"-b"}, + expected: expected{ + isError: true, + }, + }, + { + argv: []string{"--bundle", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"--not-bundle", "/foo/bar"}, + expected: expected{ + bundle: "", + }, + }, + { + argv: []string{"--"}, + expected: expected{ + bundle: "", + }, + }, + { + argv: []string{"-bundle", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"--bundle=/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"-b=/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"-b=/foo/=bar"}, + expected: expected{ + bundle: "/foo/=bar", + }, + }, + { + argv: []string{"-b", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"create", "-b", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"-b", "create", "create"}, + expected: expected{ + bundle: "create", + }, + }, + { + argv: []string{"-b=create", "create"}, + expected: expected{ + bundle: "create", + }, + }, + { + argv: []string{"-b", "create"}, + expected: expected{ + bundle: "create", + }, + }, + } + + for i, tc := range testCases { + bundle, err := GetBundleDir(tc.argv) + + if tc.expected.isError { + require.Errorf(t, err, "%d: %v", i, tc) + } else { + require.NoErrorf(t, err, "%d: %v", i, tc) + } + + require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc) + } +} + +func TestGetSpecFilePathAppendsFilename(t *testing.T) { + testCases := []struct { + bundleDir string + expected string + }{ + { + bundleDir: "", + expected: "config.json", + }, + { + bundleDir: "/not/empty/", + expected: "/not/empty/config.json", + }, + { + bundleDir: "not/absolute", + expected: "not/absolute/config.json", + }, + } + + for i, tc := range testCases { + specPath := GetSpecFilePath(tc.bundleDir) + + require.Equalf(t, tc.expected, specPath, "%d: %v", i, tc) + } +} + +func TestHasCreateSubcommand(t *testing.T) { + testCases := []struct { + args []string + shouldModify bool + }{ + { + shouldModify: false, + }, + { + args: []string{"create"}, + shouldModify: true, + }, + { + args: []string{"--bundle=create"}, + shouldModify: false, + }, + { + args: []string{"--bundle", "create"}, + shouldModify: false, + }, + { + args: []string{"create"}, + shouldModify: true, + }, + } + + for i, tc := range testCases { + require.Equal(t, tc.shouldModify, HasCreateSubcommand(tc.args), "%d: %v", i, tc) + } +} diff --git a/internal/oci/runtime.go b/internal/oci/runtime.go index 89df5aa1..438fc5d2 100644 --- a/internal/oci/runtime.go +++ b/internal/oci/runtime.go @@ -16,6 +16,8 @@ package oci +//go:generate moq -stub -out runtime_mock.go . Runtime + // Runtime is an interface for a runtime shim. The Exec method accepts a list // of command line arguments, and returns an error / nil. type Runtime interface { diff --git a/internal/oci/runtime_exec.go b/internal/oci/runtime_exec.go deleted file mode 100644 index 98415747..00000000 --- a/internal/oci/runtime_exec.go +++ /dev/null @@ -1,79 +0,0 @@ -/* -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -*/ - -package oci - -import ( - "fmt" - "os" - "syscall" - - log "github.com/sirupsen/logrus" -) - -// SyscallExecRuntime wraps the path that a binary and defines the semanitcs for how to exec into it. -// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the -// Runtime internface. -type SyscallExecRuntime struct { - logger *log.Logger - path string - // exec is used for testing. This defaults to syscall.Exec - exec func(argv0 string, argv []string, envv []string) error -} - -var _ Runtime = (*SyscallExecRuntime)(nil) - -// NewSyscallExecRuntime creates a SyscallExecRuntime for the specified path with the standard logger -func NewSyscallExecRuntime(path string) (Runtime, error) { - return NewSyscallExecRuntimeWithLogger(log.StandardLogger(), path) -} - -// NewSyscallExecRuntimeWithLogger creates a SyscallExecRuntime for the specified logger and path -func NewSyscallExecRuntimeWithLogger(logger *log.Logger, path string) (Runtime, error) { - info, err := os.Stat(path) - if err != nil { - return nil, fmt.Errorf("invalid path '%v': %v", path, err) - } - if info.IsDir() || info.Mode()&0111 == 0 { - return nil, fmt.Errorf("specified path '%v' is not an executable file", path) - } - - shim := SyscallExecRuntime{ - logger: logger, - path: path, - exec: syscall.Exec, - } - - return &shim, nil -} - -// Exec exces into the binary at the path from the SyscallExecRuntime struct, passing it the supplied arguments -// after ensuring that the first argument is the path of the target binary. -func (s SyscallExecRuntime) Exec(args []string) error { - runtimeArgs := []string{s.path} - if len(args) > 1 { - runtimeArgs = append(runtimeArgs, args[1:]...) - } - - err := s.exec(s.path, runtimeArgs, os.Environ()) - if err != nil { - return fmt.Errorf("could not exec '%v': %v", s.path, err) - } - - // syscall.Exec is not expected to return. This is an error state regardless of whether - // err is nil or not. - return fmt.Errorf("unexpected return from exec '%v'", s.path) -} diff --git a/internal/oci/runtime_exec_test.go b/internal/oci/runtime_exec_test.go deleted file mode 100644 index 83ac64a2..00000000 --- a/internal/oci/runtime_exec_test.go +++ /dev/null @@ -1,100 +0,0 @@ -/* -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -*/ -package oci - -import ( - "fmt" - "strings" - "testing" - - testlog "github.com/sirupsen/logrus/hooks/test" - "github.com/stretchr/testify/require" -) - -func TestSyscallExecConstructor(t *testing.T) { - r, err := NewSyscallExecRuntime("////an/invalid/path") - require.Error(t, err) - require.Nil(t, r) - - r, err = NewSyscallExecRuntime("/tmp") - require.Error(t, err) - require.Nil(t, r) - - r, err = NewSyscallExecRuntime("/dev/null") - require.Error(t, err) - require.Nil(t, r) - - r, err = NewSyscallExecRuntime("/bin/sh") - require.NoError(t, err) - - f, ok := r.(*SyscallExecRuntime) - require.True(t, ok) - - require.Equal(t, "/bin/sh", f.path) -} - -func TestSyscallExecForwardsArgs(t *testing.T) { - logger, _ := testlog.NewNullLogger() - f := SyscallExecRuntime{ - logger: logger, - path: "runtime", - } - - testCases := []struct { - returnError error - args []string - errorPrefix string - }{ - { - returnError: nil, - errorPrefix: "unexpected return from exec", - }, - { - returnError: fmt.Errorf("error from exec"), - errorPrefix: "could not exec", - }, - { - returnError: nil, - args: []string{"otherargv0"}, - errorPrefix: "unexpected return from exec", - }, - { - returnError: nil, - args: []string{"otherargv0", "arg1", "arg2", "arg3"}, - errorPrefix: "unexpected return from exec", - }, - } - - for i, tc := range testCases { - execMock := WithMockExec(f, tc.returnError) - - err := execMock.Exec(tc.args) - - require.Errorf(t, err, "%d: %v", i, tc) - require.Truef(t, strings.HasPrefix(err.Error(), tc.errorPrefix), "%d: %v", i, tc) - if tc.returnError != nil { - require.Truef(t, strings.HasSuffix(err.Error(), tc.returnError.Error()), "%d: %v", i, tc) - } - - require.Equalf(t, f.path, execMock.argv0, "%d: %v", i, tc) - require.Equalf(t, f.path, execMock.argv[0], "%d: %v", i, tc) - - require.LessOrEqualf(t, len(tc.args), len(execMock.argv), "%d: %v", i, tc) - if len(tc.args) > 1 { - require.Equalf(t, tc.args[1:], execMock.argv[1:], "%d: %v", i, tc) - } - } -} diff --git a/internal/oci/runtime_low_level.go b/internal/oci/runtime_low_level.go new file mode 100644 index 00000000..56c9e367 --- /dev/null +++ b/internal/oci/runtime_low_level.go @@ -0,0 +1,61 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "fmt" + "os/exec" + + log "github.com/sirupsen/logrus" +) + +// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable. +// The executable specified is taken from the list of supplied candidates, with the first match +// present in the PATH being selected. +func NewLowLevelRuntime(candidates ...string) (Runtime, error) { + return NewLowLevelRuntimeWithLogger(log.StandardLogger(), candidates...) +} + +// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger. +func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) { + runtimePath, err := findRuntime(candidates) + if err != nil { + return nil, fmt.Errorf("error locating runtime: %v", err) + } + + return NewRuntimeForPathWithLogger(logger, runtimePath) +} + +// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH. +// The absolute path to the first match is returned. +func findRuntime(candidates []string) (string, error) { + if len(candidates) == 0 { + return "", fmt.Errorf("at least one runtime candidate must be specified") + } + + for _, candidate := range candidates { + log.Infof("Looking for runtime binary '%v'", candidate) + runcPath, err := exec.LookPath(candidate) + if err == nil { + log.Infof("Found runtime binary '%v'", runcPath) + return runcPath, nil + } + log.Warnf("Runtime binary '%v' not found: %v", candidate, err) + } + + return "", fmt.Errorf("no runtime binary found from candidate list: %v", candidates) +} diff --git a/internal/oci/runtime_mock.go b/internal/oci/runtime_mock.go index e09cfb79..2887fb22 100644 --- a/internal/oci/runtime_mock.go +++ b/internal/oci/runtime_mock.go @@ -1,49 +1,76 @@ -/* -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -*/ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq package oci -// MockExecRuntime wraps a SyscallExecRuntime, intercepting the exec call for testing -type MockExecRuntime struct { - SyscallExecRuntime - execMock -} +import ( + "sync" +) -// WithMockExec wraps a specified SyscallExecRuntime with a mocked exec function for testing -func WithMockExec(e SyscallExecRuntime, execResult error) *MockExecRuntime { - m := MockExecRuntime{ - SyscallExecRuntime: e, - execMock: execMock{result: execResult}, +// Ensure, that RuntimeMock does implement Runtime. +// If this is not the case, regenerate this file with moq. +var _ Runtime = &RuntimeMock{} + +// RuntimeMock is a mock implementation of Runtime. +// +// func TestSomethingThatUsesRuntime(t *testing.T) { +// +// // make and configure a mocked Runtime +// mockedRuntime := &RuntimeMock{ +// ExecFunc: func(strings []string) error { +// panic("mock out the Exec method") +// }, +// } +// +// // use mockedRuntime in code that requires Runtime +// // and then make assertions. +// +// } +type RuntimeMock struct { + // ExecFunc mocks the Exec method. + ExecFunc func(strings []string) error + + // calls tracks calls to the methods. + calls struct { + // Exec holds details about calls to the Exec method. + Exec []struct { + // Strings is the strings argument value. + Strings []string + } } - // overrdie the exec function to the mocked exec function. - m.SyscallExecRuntime.exec = m.execMock.exec - return &m + lockExec sync.RWMutex } -type execMock struct { - argv0 string - argv []string - envv []string - result error +// Exec calls ExecFunc. +func (mock *RuntimeMock) Exec(strings []string) error { + callInfo := struct { + Strings []string + }{ + Strings: strings, + } + mock.lockExec.Lock() + mock.calls.Exec = append(mock.calls.Exec, callInfo) + mock.lockExec.Unlock() + if mock.ExecFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ExecFunc(strings) } -func (m *execMock) exec(argv0 string, argv []string, envv []string) error { - m.argv0 = argv0 - m.argv = argv - m.envv = envv - - return m.result +// ExecCalls gets all the calls that were made to Exec. +// Check the length with: +// len(mockedRuntime.ExecCalls()) +func (mock *RuntimeMock) ExecCalls() []struct { + Strings []string +} { + var calls []struct { + Strings []string + } + mock.lockExec.RLock() + calls = mock.calls.Exec + mock.lockExec.RUnlock() + return calls } diff --git a/internal/oci/runtime_path.go b/internal/oci/runtime_path.go new file mode 100644 index 00000000..abf225b2 --- /dev/null +++ b/internal/oci/runtime_path.go @@ -0,0 +1,70 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "fmt" + "os" + + log "github.com/sirupsen/logrus" +) + +// pathRuntime wraps the path that a binary and defines the semanitcs for how to exec into it. +// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the +// Runtime internface. +type pathRuntime struct { + logger *log.Logger + path string + execRuntime Runtime +} + +var _ Runtime = (*pathRuntime)(nil) + +// NewRuntimeForPath creates a Runtime for the specified path with the standard logger +func NewRuntimeForPath(path string) (Runtime, error) { + return NewRuntimeForPathWithLogger(log.StandardLogger(), path) +} + +// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path +func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) { + info, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("invalid path '%v': %v", path, err) + } + if info.IsDir() || info.Mode()&0111 == 0 { + return nil, fmt.Errorf("specified path '%v' is not an executable file", path) + } + + shim := pathRuntime{ + logger: logger, + path: path, + execRuntime: syscallExec{}, + } + + return &shim, nil +} + +// Exec exces into the binary at the path from the pathRuntime struct, passing it the supplied arguments +// after ensuring that the first argument is the path of the target binary. +func (s pathRuntime) Exec(args []string) error { + runtimeArgs := []string{s.path} + if len(args) > 1 { + runtimeArgs = append(runtimeArgs, args[1:]...) + } + + return s.execRuntime.Exec(runtimeArgs) +} diff --git a/internal/oci/runtime_path_test.go b/internal/oci/runtime_path_test.go new file mode 100644 index 00000000..0d936a3f --- /dev/null +++ b/internal/oci/runtime_path_test.go @@ -0,0 +1,97 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ +package oci + +import ( + "fmt" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestPathRuntimeConstructor(t *testing.T) { + r, err := NewRuntimeForPath("////an/invalid/path") + require.Error(t, err) + require.Nil(t, r) + + r, err = NewRuntimeForPath("/tmp") + require.Error(t, err) + require.Nil(t, r) + + r, err = NewRuntimeForPath("/dev/null") + require.Error(t, err) + require.Nil(t, r) + + r, err = NewRuntimeForPath("/bin/sh") + require.NoError(t, err) + + f, ok := r.(*pathRuntime) + require.True(t, ok) + + require.Equal(t, "/bin/sh", f.path) +} + +func TestPathRuntimeForwardsArgs(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + testCases := []struct { + execRuntimeError error + args []string + }{ + {}, + { + args: []string{"shouldBeReplaced"}, + }, + { + args: []string{"shouldBeReplaced", "arg1"}, + }, + { + execRuntimeError: fmt.Errorf("exec error"), + }, + } + + for _, tc := range testCases { + mockedRuntime := &RuntimeMock{ + ExecFunc: func(strings []string) error { + return tc.execRuntimeError + }, + } + r := pathRuntime{ + logger: logger, + path: "runtime", + execRuntime: mockedRuntime, + } + err := r.Exec(tc.args) + + require.ErrorIs(t, err, tc.execRuntimeError) + + calls := mockedRuntime.ExecCalls() + require.Len(t, calls, 1) + + numArgs := len(tc.args) + if numArgs == 0 { + numArgs = 1 + } + + require.Len(t, calls[0].Strings, numArgs) + require.Equal(t, "runtime", calls[0].Strings[0]) + + if numArgs > 1 { + require.EqualValues(t, tc.args[1:], calls[0].Strings[1:]) + } + } +} diff --git a/internal/oci/runtime_syscall_exec.go b/internal/oci/runtime_syscall_exec.go new file mode 100644 index 00000000..d752776a --- /dev/null +++ b/internal/oci/runtime_syscall_exec.go @@ -0,0 +1,38 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "fmt" + "os" + "syscall" +) + +type syscallExec struct{} + +var _ Runtime = (*syscallExec)(nil) + +func (r syscallExec) Exec(args []string) error { + err := syscall.Exec(args[0], args, os.Environ()) + if err != nil { + return fmt.Errorf("could not exec '%v': %v", args[0], err) + } + + // syscall.Exec is not expected to return. This is an error state regardless of whether + // err is nil or not. + return fmt.Errorf("unexpected return from exec '%v'", args[0]) +} diff --git a/internal/oci/spec.go b/internal/oci/spec.go index 0c3aaf4d..259ce054 100644 --- a/internal/oci/spec.go +++ b/internal/oci/spec.go @@ -17,10 +17,6 @@ package oci import ( - "encoding/json" - "fmt" - "os" - oci "github.com/opencontainers/runtime-spec/specs-go" ) @@ -28,75 +24,12 @@ import ( // error. The intention is that the function would modify the spec in-place. type SpecModifier func(*oci.Spec) error +//go:generate moq -stub -out spec_mock.go . Spec + // Spec defines the operations to be performed on an OCI specification type Spec interface { Load() error Flush() error Modify(SpecModifier) error -} - -type fileSpec struct { - *oci.Spec - path string -} - -var _ Spec = (*fileSpec)(nil) - -// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec. -// This can be used to read from the file, modify the spec, and write to the -// same file. -func NewSpecFromFile(filepath string) Spec { - oci := fileSpec{ - path: filepath, - } - - return &oci -} - -// Load reads the contents of an OCI spec from file to be referenced internally. -// The file is opened "read-only" -func (s *fileSpec) Load() error { - specFile, err := os.Open(s.path) - if err != nil { - return fmt.Errorf("error opening OCI specification file: %v", err) - } - defer specFile.Close() - - decoder := json.NewDecoder(specFile) - - var spec oci.Spec - err = decoder.Decode(&spec) - if err != nil { - return fmt.Errorf("error reading OCI specification from file: %v", err) - } - - s.Spec = &spec - return nil -} - -// Modify applies the specified SpecModifier to the stored OCI specification. -func (s *fileSpec) Modify(f SpecModifier) error { - if s.Spec == nil { - return fmt.Errorf("no spec loaded for modification") - } - return f(s.Spec) -} - -// Flush writes the stored OCI specification to the filepath specifed by the path member. -// The file is truncated upon opening, overwriting any existing contents. -func (s fileSpec) Flush() error { - specFile, err := os.Create(s.path) - if err != nil { - return fmt.Errorf("error opening OCI specification file: %v", err) - } - defer specFile.Close() - - encoder := json.NewEncoder(specFile) - - err = encoder.Encode(s.Spec) - if err != nil { - return fmt.Errorf("error writing OCI specification to file: %v", err) - } - - return nil + LookupEnv(string) (string, bool) } diff --git a/internal/oci/spec_file.go b/internal/oci/spec_file.go new file mode 100644 index 00000000..886e2cb3 --- /dev/null +++ b/internal/oci/spec_file.go @@ -0,0 +1,153 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strings" + + oci "github.com/opencontainers/runtime-spec/specs-go" +) + +type fileSpec struct { + *oci.Spec + path string +} + +var _ Spec = (*fileSpec)(nil) + +// NewSpecFromArgs creates fileSpec based on the command line arguments passed to the +// application +func NewSpecFromArgs(args []string) (Spec, string, error) { + bundleDir, err := GetBundleDir(args) + if err != nil { + return nil, "", fmt.Errorf("error getting bundle directory: %v", err) + } + + ociSpecPath := GetSpecFilePath(bundleDir) + + ociSpec := NewSpecFromFile(ociSpecPath) + + return ociSpec, bundleDir, nil +} + +// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec. +// This can be used to read from the file, modify the spec, and write to the +// same file. +func NewSpecFromFile(filepath string) Spec { + oci := fileSpec{ + path: filepath, + } + + return &oci +} + +// Load reads the contents of an OCI spec from file to be referenced internally. +// The file is opened "read-only" +func (s *fileSpec) Load() error { + specFile, err := os.Open(s.path) + if err != nil { + return fmt.Errorf("error opening OCI specification file: %v", err) + } + defer specFile.Close() + + return s.loadFrom(specFile) +} + +// loadFrom reads the contents of the OCI spec from the specified io.Reader. +func (s *fileSpec) loadFrom(reader io.Reader) error { + decoder := json.NewDecoder(reader) + + var spec oci.Spec + err := decoder.Decode(&spec) + if err != nil { + return fmt.Errorf("error reading OCI specification: %v", err) + } + + s.Spec = &spec + return nil +} + +// Modify applies the specified SpecModifier to the stored OCI specification. +func (s *fileSpec) Modify(f SpecModifier) error { + if s.Spec == nil { + return fmt.Errorf("no spec loaded for modification") + } + return f(s.Spec) +} + +// Flush writes the stored OCI specification to the filepath specifed by the path member. +// The file is truncated upon opening, overwriting any existing contents. +func (s fileSpec) Flush() error { + if s.Spec == nil { + return fmt.Errorf("no OCI specification loaded") + } + + specFile, err := os.Create(s.path) + if err != nil { + return fmt.Errorf("error opening OCI specification file: %v", err) + } + defer specFile.Close() + + return s.flushTo(specFile) +} + +// flushTo writes the stored OCI specification to the specified io.Writer. +func (s fileSpec) flushTo(writer io.Writer) error { + if s.Spec == nil { + return nil + } + encoder := json.NewEncoder(writer) + + err := encoder.Encode(s.Spec) + if err != nil { + return fmt.Errorf("error writing OCI specification: %v", err) + } + + return nil +} + +// LookupEnv mirrors os.LookupEnv for the OCI specification. It +// retrieves the value of the environment variable named +// by the key. If the variable is present in the environment the +// value (which may be empty) is returned and the boolean is true. +// Otherwise the returned value will be empty and the boolean will +// be false. +func (s fileSpec) LookupEnv(key string) (string, bool) { + if s.Spec == nil || s.Spec.Process == nil { + return "", false + } + + for _, env := range s.Spec.Process.Env { + if !strings.HasPrefix(env, key) { + continue + } + + parts := strings.SplitN(env, "=", 2) + if parts[0] == key { + if len(parts) < 2 { + return "", true + } + return parts[1], true + } + } + + return "", false +} diff --git a/internal/oci/spec_file_test.go b/internal/oci/spec_file_test.go new file mode 100644 index 00000000..fe273824 --- /dev/null +++ b/internal/oci/spec_file_test.go @@ -0,0 +1,252 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "bytes" + "fmt" + "testing" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/stretchr/testify/require" +) + +func TestLookupEnv(t *testing.T) { + const envName = "TEST_ENV" + testCases := []struct { + spec *specs.Spec + expectedValue string + expectedExits bool + }{ + { + // nil spec + spec: nil, + expectedValue: "", + expectedExits: false, + }, + { + // nil process + spec: &specs.Spec{}, + expectedValue: "", + expectedExits: false, + }, + { + // nil env + spec: &specs.Spec{ + Process: &specs.Process{}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // empty env + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // different env set + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // same prefix + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // same suffix + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // set blank + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV="}}, + }, + expectedValue: "", + expectedExits: true, + }, + { + // set no-equals + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV"}}, + }, + expectedValue: "", + expectedExits: true, + }, + { + // set value + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV=something"}}, + }, + expectedValue: "something", + expectedExits: true, + }, + { + // set with equals + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}}, + }, + expectedValue: "something=somethingelse", + expectedExits: true, + }, + } + + for i, tc := range testCases { + spec := fileSpec{ + Spec: tc.spec, + } + + value, exists := spec.LookupEnv(envName) + + require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc) + require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc) + } +} + +func TestLoadFrom(t *testing.T) { + testCases := []struct { + contents []byte + isError bool + spec *specs.Spec + }{ + { + contents: []byte{}, + isError: true, + }, + { + contents: []byte("{}"), + isError: false, + spec: &specs.Spec{}, + }, + } + + for i, tc := range testCases { + spec := fileSpec{} + err := spec.loadFrom(bytes.NewReader(tc.contents)) + + if tc.isError { + require.Error(t, err, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + } + + if tc.spec == nil { + require.Nil(t, spec.Spec, "%d: %v", i, tc) + } else { + require.EqualValues(t, tc.spec, spec.Spec, "%d: %v", i, tc) + } + } +} + +func TestFlushTo(t *testing.T) { + testCases := []struct { + isError bool + spec *specs.Spec + contents string + }{ + { + spec: nil, + }, + { + spec: &specs.Spec{}, + contents: "{\"ociVersion\":\"\"}\n", + }, + } + + for i, tc := range testCases { + buffer := bytes.Buffer{} + + spec := fileSpec{Spec: tc.spec} + err := spec.flushTo(&buffer) + + if tc.isError { + require.Error(t, err, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + } + + require.EqualValues(t, tc.contents, buffer.String(), "%d: %v", i, tc) + } + + // Add a simple test for a writer that returns an error when writing + spec := fileSpec{Spec: &specs.Spec{}} + err := spec.flushTo(errorWriter{}) + require.Error(t, err) +} + +func TestModify(t *testing.T) { + + testCases := []struct { + spec *specs.Spec + modifierError error + }{ + { + spec: nil, + }, + { + spec: &specs.Spec{}, + }, + { + spec: &specs.Spec{}, + modifierError: fmt.Errorf("error in modifier"), + }, + } + + for i, tc := range testCases { + spec := fileSpec{Spec: tc.spec} + + modifier := func(spec *specs.Spec) error { + if tc.modifierError == nil { + spec.Version = "updated" + } + return tc.modifierError + } + + err := spec.Modify(modifier) + + if tc.spec == nil { + require.Error(t, err, "%d: %v", i, tc) + } else if tc.modifierError != nil { + require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc) + require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc) + } + } +} + +// errorWriter implements the io.Writer interface, always returning an error when +// writing. +type errorWriter struct{} + +func (e errorWriter) Write([]byte) (int, error) { + return 0, fmt.Errorf("error writing") +} diff --git a/internal/oci/spec_mock.go b/internal/oci/spec_mock.go index 1247adaf..1656552c 100644 --- a/internal/oci/spec_mock.go +++ b/internal/oci/spec_mock.go @@ -1,70 +1,201 @@ -/* -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -*/ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq package oci import ( - oci "github.com/opencontainers/runtime-spec/specs-go" + "sync" ) -// MockSpec provides a simple mock for an OCI spec to be used in testing. -// It also implements the SpecModifier interface. -type MockSpec struct { - *oci.Spec - MockLoad mockFunc - MockFlush mockFunc - MockModify mockFunc -} +// Ensure, that SpecMock does implement Spec. +// If this is not the case, regenerate this file with moq. +var _ Spec = &SpecMock{} -var _ Spec = (*MockSpec)(nil) +// SpecMock is a mock implementation of Spec. +// +// func TestSomethingThatUsesSpec(t *testing.T) { +// +// // make and configure a mocked Spec +// mockedSpec := &SpecMock{ +// FlushFunc: func() error { +// panic("mock out the Flush method") +// }, +// LoadFunc: func() error { +// panic("mock out the Load method") +// }, +// LookupEnvFunc: func(s string) (string, bool) { +// panic("mock out the LookupEnv method") +// }, +// ModifyFunc: func(specModifier SpecModifier) error { +// panic("mock out the Modify method") +// }, +// } +// +// // use mockedSpec in code that requires Spec +// // and then make assertions. +// +// } +type SpecMock struct { + // FlushFunc mocks the Flush method. + FlushFunc func() error -// NewMockSpec constructs a MockSpec to be used in testing as a Spec -func NewMockSpec(spec *oci.Spec, flushResult error, modifyResult error) *MockSpec { - s := MockSpec{ - Spec: spec, - MockFlush: mockFunc{result: flushResult}, - MockModify: mockFunc{result: modifyResult}, + // LoadFunc mocks the Load method. + LoadFunc func() error + + // LookupEnvFunc mocks the LookupEnv method. + LookupEnvFunc func(s string) (string, bool) + + // ModifyFunc mocks the Modify method. + ModifyFunc func(specModifier SpecModifier) error + + // calls tracks calls to the methods. + calls struct { + // Flush holds details about calls to the Flush method. + Flush []struct { + } + // Load holds details about calls to the Load method. + Load []struct { + } + // LookupEnv holds details about calls to the LookupEnv method. + LookupEnv []struct { + // S is the s argument value. + S string + } + // Modify holds details about calls to the Modify method. + Modify []struct { + // SpecModifier is the specModifier argument value. + SpecModifier SpecModifier + } } - - return &s + lockFlush sync.RWMutex + lockLoad sync.RWMutex + lockLookupEnv sync.RWMutex + lockModify sync.RWMutex } -// Load invokes the mocked Load function to return the predefined error / result -func (s *MockSpec) Load() error { - return s.MockLoad.call() +// Flush calls FlushFunc. +func (mock *SpecMock) Flush() error { + callInfo := struct { + }{} + mock.lockFlush.Lock() + mock.calls.Flush = append(mock.calls.Flush, callInfo) + mock.lockFlush.Unlock() + if mock.FlushFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.FlushFunc() } -// Flush invokes the mocked Load function to return the predefined error / result -func (s *MockSpec) Flush() error { - return s.MockFlush.call() +// FlushCalls gets all the calls that were made to Flush. +// Check the length with: +// len(mockedSpec.FlushCalls()) +func (mock *SpecMock) FlushCalls() []struct { +} { + var calls []struct { + } + mock.lockFlush.RLock() + calls = mock.calls.Flush + mock.lockFlush.RUnlock() + return calls } -// Modify applies the specified SpecModifier to the spec and invokes the -// mocked modify function to return the predefined error / result. -func (s *MockSpec) Modify(f SpecModifier) error { - f(s.Spec) - return s.MockModify.call() +// Load calls LoadFunc. +func (mock *SpecMock) Load() error { + callInfo := struct { + }{} + mock.lockLoad.Lock() + mock.calls.Load = append(mock.calls.Load, callInfo) + mock.lockLoad.Unlock() + if mock.LoadFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.LoadFunc() } -type mockFunc struct { - Callcount int - result error +// LoadCalls gets all the calls that were made to Load. +// Check the length with: +// len(mockedSpec.LoadCalls()) +func (mock *SpecMock) LoadCalls() []struct { +} { + var calls []struct { + } + mock.lockLoad.RLock() + calls = mock.calls.Load + mock.lockLoad.RUnlock() + return calls } -func (m *mockFunc) call() error { - m.Callcount++ - return m.result +// LookupEnv calls LookupEnvFunc. +func (mock *SpecMock) LookupEnv(s string) (string, bool) { + callInfo := struct { + S string + }{ + S: s, + } + mock.lockLookupEnv.Lock() + mock.calls.LookupEnv = append(mock.calls.LookupEnv, callInfo) + mock.lockLookupEnv.Unlock() + if mock.LookupEnvFunc == nil { + var ( + sOut string + bOut bool + ) + return sOut, bOut + } + return mock.LookupEnvFunc(s) +} + +// LookupEnvCalls gets all the calls that were made to LookupEnv. +// Check the length with: +// len(mockedSpec.LookupEnvCalls()) +func (mock *SpecMock) LookupEnvCalls() []struct { + S string +} { + var calls []struct { + S string + } + mock.lockLookupEnv.RLock() + calls = mock.calls.LookupEnv + mock.lockLookupEnv.RUnlock() + return calls +} + +// Modify calls ModifyFunc. +func (mock *SpecMock) Modify(specModifier SpecModifier) error { + callInfo := struct { + SpecModifier SpecModifier + }{ + SpecModifier: specModifier, + } + mock.lockModify.Lock() + mock.calls.Modify = append(mock.calls.Modify, callInfo) + mock.lockModify.Unlock() + if mock.ModifyFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ModifyFunc(specModifier) +} + +// ModifyCalls gets all the calls that were made to Modify. +// Check the length with: +// len(mockedSpec.ModifyCalls()) +func (mock *SpecMock) ModifyCalls() []struct { + SpecModifier SpecModifier +} { + var calls []struct { + SpecModifier SpecModifier + } + mock.lockModify.RLock() + calls = mock.calls.Modify + mock.lockModify.RUnlock() + return calls }