diff --git a/internal/oci/args.go b/internal/oci/args.go new file mode 100644 index 00000000..de85d9cc --- /dev/null +++ b/internal/oci/args.go @@ -0,0 +1,115 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "fmt" + "path/filepath" + "strings" +) + +const ( + specFileName = "config.json" +) + +// GetBundleDir returns the bundle directory or default depending on the +// supplied command line arguments. +func GetBundleDir(args []string) (string, error) { + bundleDir, err := GetBundleDirFromArgs(args) + if err != nil { + return "", fmt.Errorf("error getting bundle dir from args: %v", err) + } + + return bundleDir, nil +} + +// GetBundleDirFromArgs checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc. +// The following are supported: +// --bundle{{SEP}}BUNDLE_PATH +// -bundle{{SEP}}BUNDLE_PATH +// -b{{SEP}}BUNDLE_PATH +// where {{SEP}} is either ' ' or '=' +func GetBundleDirFromArgs(args []string) (string, error) { + var bundleDir string + + for i := 0; i < len(args); i++ { + param := args[i] + + parts := strings.SplitN(param, "=", 2) + if !IsBundleFlag(parts[0]) { + continue + } + + // The flag has the format --bundle=/path + if len(parts) == 2 { + bundleDir = parts[1] + continue + } + + // The flag has the format --bundle /path + if i+1 < len(args) { + bundleDir = args[i+1] + i++ + continue + } + + // --bundle / -b was the last element of args + return "", fmt.Errorf("bundle option requires an argument") + } + + return bundleDir, nil +} + +// GetSpecFilePath returns the expected path to the OCI specification file for the given +// bundle directory. +func GetSpecFilePath(bundleDir string) string { + specFilePath := filepath.Join(bundleDir, specFileName) + return specFilePath +} + +// IsBundleFlag is a helper function that checks wither the specified argument represents +// a bundle flag (--bundle or -b) +func IsBundleFlag(arg string) bool { + if !strings.HasPrefix(arg, "-") { + return false + } + + trimmed := strings.TrimLeft(arg, "-") + return trimmed == "b" || trimmed == "bundle" +} + +// HasCreateSubcommand checks the supplied arguments for a 'create' subcommand +func HasCreateSubcommand(args []string) bool { + var previousWasBundle bool + for _, a := range args { + // We check for '--bundle create' explicitly to ensure that we + // don't inadvertently trigger a modification if the bundle directory + // is specified as `create` + if !previousWasBundle && IsBundleFlag(a) { + previousWasBundle = true + continue + } + + if !previousWasBundle && a == "create" { + return true + } + + previousWasBundle = false + } + + return false +} diff --git a/internal/oci/args_test.go b/internal/oci/args_test.go new file mode 100644 index 00000000..562dd0a9 --- /dev/null +++ b/internal/oci/args_test.go @@ -0,0 +1,184 @@ +package oci + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetBundleDir(t *testing.T) { + type expected struct { + bundle string + isError bool + } + testCases := []struct { + argv []string + expected expected + }{ + { + argv: []string{}, + expected: expected{ + bundle: "", + }, + }, + { + argv: []string{"create"}, + expected: expected{ + bundle: "", + }, + }, + { + argv: []string{"--bundle"}, + expected: expected{ + isError: true, + }, + }, + { + argv: []string{"-b"}, + expected: expected{ + isError: true, + }, + }, + { + argv: []string{"--bundle", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"--not-bundle", "/foo/bar"}, + expected: expected{ + bundle: "", + }, + }, + { + argv: []string{"--"}, + expected: expected{ + bundle: "", + }, + }, + { + argv: []string{"-bundle", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"--bundle=/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"-b=/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"-b=/foo/=bar"}, + expected: expected{ + bundle: "/foo/=bar", + }, + }, + { + argv: []string{"-b", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"create", "-b", "/foo/bar"}, + expected: expected{ + bundle: "/foo/bar", + }, + }, + { + argv: []string{"-b", "create", "create"}, + expected: expected{ + bundle: "create", + }, + }, + { + argv: []string{"-b=create", "create"}, + expected: expected{ + bundle: "create", + }, + }, + { + argv: []string{"-b", "create"}, + expected: expected{ + bundle: "create", + }, + }, + } + + for i, tc := range testCases { + bundle, err := GetBundleDir(tc.argv) + + if tc.expected.isError { + require.Errorf(t, err, "%d: %v", i, tc) + } else { + require.NoErrorf(t, err, "%d: %v", i, tc) + } + + require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc) + } +} + +func TestGetSpecFilePathAppendsFilename(t *testing.T) { + testCases := []struct { + bundleDir string + expected string + }{ + { + bundleDir: "", + expected: "config.json", + }, + { + bundleDir: "/not/empty/", + expected: "/not/empty/config.json", + }, + { + bundleDir: "not/absolute", + expected: "not/absolute/config.json", + }, + } + + for i, tc := range testCases { + specPath := GetSpecFilePath(tc.bundleDir) + + require.Equalf(t, tc.expected, specPath, "%d: %v", i, tc) + } +} + +func TestHasCreateSubcommand(t *testing.T) { + testCases := []struct { + args []string + shouldModify bool + }{ + { + shouldModify: false, + }, + { + args: []string{"create"}, + shouldModify: true, + }, + { + args: []string{"--bundle=create"}, + shouldModify: false, + }, + { + args: []string{"--bundle", "create"}, + shouldModify: false, + }, + { + args: []string{"create"}, + shouldModify: true, + }, + } + + for i, tc := range testCases { + require.Equal(t, tc.shouldModify, HasCreateSubcommand(tc.args), "%d: %v", i, tc) + } +} diff --git a/internal/oci/runtime.go b/internal/oci/runtime.go index 89df5aa1..438fc5d2 100644 --- a/internal/oci/runtime.go +++ b/internal/oci/runtime.go @@ -16,6 +16,8 @@ package oci +//go:generate moq -stub -out runtime_mock.go . Runtime + // Runtime is an interface for a runtime shim. The Exec method accepts a list // of command line arguments, and returns an error / nil. type Runtime interface { diff --git a/internal/oci/runtime_low_level.go b/internal/oci/runtime_low_level.go new file mode 100644 index 00000000..56c9e367 --- /dev/null +++ b/internal/oci/runtime_low_level.go @@ -0,0 +1,61 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "fmt" + "os/exec" + + log "github.com/sirupsen/logrus" +) + +// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable. +// The executable specified is taken from the list of supplied candidates, with the first match +// present in the PATH being selected. +func NewLowLevelRuntime(candidates ...string) (Runtime, error) { + return NewLowLevelRuntimeWithLogger(log.StandardLogger(), candidates...) +} + +// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger. +func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) { + runtimePath, err := findRuntime(candidates) + if err != nil { + return nil, fmt.Errorf("error locating runtime: %v", err) + } + + return NewRuntimeForPathWithLogger(logger, runtimePath) +} + +// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH. +// The absolute path to the first match is returned. +func findRuntime(candidates []string) (string, error) { + if len(candidates) == 0 { + return "", fmt.Errorf("at least one runtime candidate must be specified") + } + + for _, candidate := range candidates { + log.Infof("Looking for runtime binary '%v'", candidate) + runcPath, err := exec.LookPath(candidate) + if err == nil { + log.Infof("Found runtime binary '%v'", runcPath) + return runcPath, nil + } + log.Warnf("Runtime binary '%v' not found: %v", candidate, err) + } + + return "", fmt.Errorf("no runtime binary found from candidate list: %v", candidates) +} diff --git a/internal/oci/runtime_mock.go b/internal/oci/runtime_mock.go index e09cfb79..2887fb22 100644 --- a/internal/oci/runtime_mock.go +++ b/internal/oci/runtime_mock.go @@ -1,49 +1,76 @@ -/* -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -*/ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq package oci -// MockExecRuntime wraps a SyscallExecRuntime, intercepting the exec call for testing -type MockExecRuntime struct { - SyscallExecRuntime - execMock -} +import ( + "sync" +) -// WithMockExec wraps a specified SyscallExecRuntime with a mocked exec function for testing -func WithMockExec(e SyscallExecRuntime, execResult error) *MockExecRuntime { - m := MockExecRuntime{ - SyscallExecRuntime: e, - execMock: execMock{result: execResult}, +// Ensure, that RuntimeMock does implement Runtime. +// If this is not the case, regenerate this file with moq. +var _ Runtime = &RuntimeMock{} + +// RuntimeMock is a mock implementation of Runtime. +// +// func TestSomethingThatUsesRuntime(t *testing.T) { +// +// // make and configure a mocked Runtime +// mockedRuntime := &RuntimeMock{ +// ExecFunc: func(strings []string) error { +// panic("mock out the Exec method") +// }, +// } +// +// // use mockedRuntime in code that requires Runtime +// // and then make assertions. +// +// } +type RuntimeMock struct { + // ExecFunc mocks the Exec method. + ExecFunc func(strings []string) error + + // calls tracks calls to the methods. + calls struct { + // Exec holds details about calls to the Exec method. + Exec []struct { + // Strings is the strings argument value. + Strings []string + } } - // overrdie the exec function to the mocked exec function. - m.SyscallExecRuntime.exec = m.execMock.exec - return &m + lockExec sync.RWMutex } -type execMock struct { - argv0 string - argv []string - envv []string - result error +// Exec calls ExecFunc. +func (mock *RuntimeMock) Exec(strings []string) error { + callInfo := struct { + Strings []string + }{ + Strings: strings, + } + mock.lockExec.Lock() + mock.calls.Exec = append(mock.calls.Exec, callInfo) + mock.lockExec.Unlock() + if mock.ExecFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ExecFunc(strings) } -func (m *execMock) exec(argv0 string, argv []string, envv []string) error { - m.argv0 = argv0 - m.argv = argv - m.envv = envv - - return m.result +// ExecCalls gets all the calls that were made to Exec. +// Check the length with: +// len(mockedRuntime.ExecCalls()) +func (mock *RuntimeMock) ExecCalls() []struct { + Strings []string +} { + var calls []struct { + Strings []string + } + mock.lockExec.RLock() + calls = mock.calls.Exec + mock.lockExec.RUnlock() + return calls } diff --git a/internal/oci/runtime_path.go b/internal/oci/runtime_path.go new file mode 100644 index 00000000..abf225b2 --- /dev/null +++ b/internal/oci/runtime_path.go @@ -0,0 +1,70 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "fmt" + "os" + + log "github.com/sirupsen/logrus" +) + +// pathRuntime wraps the path that a binary and defines the semanitcs for how to exec into it. +// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the +// Runtime internface. +type pathRuntime struct { + logger *log.Logger + path string + execRuntime Runtime +} + +var _ Runtime = (*pathRuntime)(nil) + +// NewRuntimeForPath creates a Runtime for the specified path with the standard logger +func NewRuntimeForPath(path string) (Runtime, error) { + return NewRuntimeForPathWithLogger(log.StandardLogger(), path) +} + +// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path +func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) { + info, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("invalid path '%v': %v", path, err) + } + if info.IsDir() || info.Mode()&0111 == 0 { + return nil, fmt.Errorf("specified path '%v' is not an executable file", path) + } + + shim := pathRuntime{ + logger: logger, + path: path, + execRuntime: syscallExec{}, + } + + return &shim, nil +} + +// Exec exces into the binary at the path from the pathRuntime struct, passing it the supplied arguments +// after ensuring that the first argument is the path of the target binary. +func (s pathRuntime) Exec(args []string) error { + runtimeArgs := []string{s.path} + if len(args) > 1 { + runtimeArgs = append(runtimeArgs, args[1:]...) + } + + return s.execRuntime.Exec(runtimeArgs) +} diff --git a/internal/oci/runtime_path_test.go b/internal/oci/runtime_path_test.go new file mode 100644 index 00000000..0d936a3f --- /dev/null +++ b/internal/oci/runtime_path_test.go @@ -0,0 +1,97 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ +package oci + +import ( + "fmt" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestPathRuntimeConstructor(t *testing.T) { + r, err := NewRuntimeForPath("////an/invalid/path") + require.Error(t, err) + require.Nil(t, r) + + r, err = NewRuntimeForPath("/tmp") + require.Error(t, err) + require.Nil(t, r) + + r, err = NewRuntimeForPath("/dev/null") + require.Error(t, err) + require.Nil(t, r) + + r, err = NewRuntimeForPath("/bin/sh") + require.NoError(t, err) + + f, ok := r.(*pathRuntime) + require.True(t, ok) + + require.Equal(t, "/bin/sh", f.path) +} + +func TestPathRuntimeForwardsArgs(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + testCases := []struct { + execRuntimeError error + args []string + }{ + {}, + { + args: []string{"shouldBeReplaced"}, + }, + { + args: []string{"shouldBeReplaced", "arg1"}, + }, + { + execRuntimeError: fmt.Errorf("exec error"), + }, + } + + for _, tc := range testCases { + mockedRuntime := &RuntimeMock{ + ExecFunc: func(strings []string) error { + return tc.execRuntimeError + }, + } + r := pathRuntime{ + logger: logger, + path: "runtime", + execRuntime: mockedRuntime, + } + err := r.Exec(tc.args) + + require.ErrorIs(t, err, tc.execRuntimeError) + + calls := mockedRuntime.ExecCalls() + require.Len(t, calls, 1) + + numArgs := len(tc.args) + if numArgs == 0 { + numArgs = 1 + } + + require.Len(t, calls[0].Strings, numArgs) + require.Equal(t, "runtime", calls[0].Strings[0]) + + if numArgs > 1 { + require.EqualValues(t, tc.args[1:], calls[0].Strings[1:]) + } + } +} diff --git a/internal/oci/runtime_syscall_exec.go b/internal/oci/runtime_syscall_exec.go new file mode 100644 index 00000000..d752776a --- /dev/null +++ b/internal/oci/runtime_syscall_exec.go @@ -0,0 +1,38 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "fmt" + "os" + "syscall" +) + +type syscallExec struct{} + +var _ Runtime = (*syscallExec)(nil) + +func (r syscallExec) Exec(args []string) error { + err := syscall.Exec(args[0], args, os.Environ()) + if err != nil { + return fmt.Errorf("could not exec '%v': %v", args[0], err) + } + + // syscall.Exec is not expected to return. This is an error state regardless of whether + // err is nil or not. + return fmt.Errorf("unexpected return from exec '%v'", args[0]) +} diff --git a/internal/oci/spec.go b/internal/oci/spec.go index 0c3aaf4d..259ce054 100644 --- a/internal/oci/spec.go +++ b/internal/oci/spec.go @@ -17,10 +17,6 @@ package oci import ( - "encoding/json" - "fmt" - "os" - oci "github.com/opencontainers/runtime-spec/specs-go" ) @@ -28,75 +24,12 @@ import ( // error. The intention is that the function would modify the spec in-place. type SpecModifier func(*oci.Spec) error +//go:generate moq -stub -out spec_mock.go . Spec + // Spec defines the operations to be performed on an OCI specification type Spec interface { Load() error Flush() error Modify(SpecModifier) error -} - -type fileSpec struct { - *oci.Spec - path string -} - -var _ Spec = (*fileSpec)(nil) - -// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec. -// This can be used to read from the file, modify the spec, and write to the -// same file. -func NewSpecFromFile(filepath string) Spec { - oci := fileSpec{ - path: filepath, - } - - return &oci -} - -// Load reads the contents of an OCI spec from file to be referenced internally. -// The file is opened "read-only" -func (s *fileSpec) Load() error { - specFile, err := os.Open(s.path) - if err != nil { - return fmt.Errorf("error opening OCI specification file: %v", err) - } - defer specFile.Close() - - decoder := json.NewDecoder(specFile) - - var spec oci.Spec - err = decoder.Decode(&spec) - if err != nil { - return fmt.Errorf("error reading OCI specification from file: %v", err) - } - - s.Spec = &spec - return nil -} - -// Modify applies the specified SpecModifier to the stored OCI specification. -func (s *fileSpec) Modify(f SpecModifier) error { - if s.Spec == nil { - return fmt.Errorf("no spec loaded for modification") - } - return f(s.Spec) -} - -// Flush writes the stored OCI specification to the filepath specifed by the path member. -// The file is truncated upon opening, overwriting any existing contents. -func (s fileSpec) Flush() error { - specFile, err := os.Create(s.path) - if err != nil { - return fmt.Errorf("error opening OCI specification file: %v", err) - } - defer specFile.Close() - - encoder := json.NewEncoder(specFile) - - err = encoder.Encode(s.Spec) - if err != nil { - return fmt.Errorf("error writing OCI specification to file: %v", err) - } - - return nil + LookupEnv(string) (string, bool) } diff --git a/internal/oci/spec_file.go b/internal/oci/spec_file.go new file mode 100644 index 00000000..886e2cb3 --- /dev/null +++ b/internal/oci/spec_file.go @@ -0,0 +1,153 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strings" + + oci "github.com/opencontainers/runtime-spec/specs-go" +) + +type fileSpec struct { + *oci.Spec + path string +} + +var _ Spec = (*fileSpec)(nil) + +// NewSpecFromArgs creates fileSpec based on the command line arguments passed to the +// application +func NewSpecFromArgs(args []string) (Spec, string, error) { + bundleDir, err := GetBundleDir(args) + if err != nil { + return nil, "", fmt.Errorf("error getting bundle directory: %v", err) + } + + ociSpecPath := GetSpecFilePath(bundleDir) + + ociSpec := NewSpecFromFile(ociSpecPath) + + return ociSpec, bundleDir, nil +} + +// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec. +// This can be used to read from the file, modify the spec, and write to the +// same file. +func NewSpecFromFile(filepath string) Spec { + oci := fileSpec{ + path: filepath, + } + + return &oci +} + +// Load reads the contents of an OCI spec from file to be referenced internally. +// The file is opened "read-only" +func (s *fileSpec) Load() error { + specFile, err := os.Open(s.path) + if err != nil { + return fmt.Errorf("error opening OCI specification file: %v", err) + } + defer specFile.Close() + + return s.loadFrom(specFile) +} + +// loadFrom reads the contents of the OCI spec from the specified io.Reader. +func (s *fileSpec) loadFrom(reader io.Reader) error { + decoder := json.NewDecoder(reader) + + var spec oci.Spec + err := decoder.Decode(&spec) + if err != nil { + return fmt.Errorf("error reading OCI specification: %v", err) + } + + s.Spec = &spec + return nil +} + +// Modify applies the specified SpecModifier to the stored OCI specification. +func (s *fileSpec) Modify(f SpecModifier) error { + if s.Spec == nil { + return fmt.Errorf("no spec loaded for modification") + } + return f(s.Spec) +} + +// Flush writes the stored OCI specification to the filepath specifed by the path member. +// The file is truncated upon opening, overwriting any existing contents. +func (s fileSpec) Flush() error { + if s.Spec == nil { + return fmt.Errorf("no OCI specification loaded") + } + + specFile, err := os.Create(s.path) + if err != nil { + return fmt.Errorf("error opening OCI specification file: %v", err) + } + defer specFile.Close() + + return s.flushTo(specFile) +} + +// flushTo writes the stored OCI specification to the specified io.Writer. +func (s fileSpec) flushTo(writer io.Writer) error { + if s.Spec == nil { + return nil + } + encoder := json.NewEncoder(writer) + + err := encoder.Encode(s.Spec) + if err != nil { + return fmt.Errorf("error writing OCI specification: %v", err) + } + + return nil +} + +// LookupEnv mirrors os.LookupEnv for the OCI specification. It +// retrieves the value of the environment variable named +// by the key. If the variable is present in the environment the +// value (which may be empty) is returned and the boolean is true. +// Otherwise the returned value will be empty and the boolean will +// be false. +func (s fileSpec) LookupEnv(key string) (string, bool) { + if s.Spec == nil || s.Spec.Process == nil { + return "", false + } + + for _, env := range s.Spec.Process.Env { + if !strings.HasPrefix(env, key) { + continue + } + + parts := strings.SplitN(env, "=", 2) + if parts[0] == key { + if len(parts) < 2 { + return "", true + } + return parts[1], true + } + } + + return "", false +} diff --git a/internal/oci/spec_file_test.go b/internal/oci/spec_file_test.go new file mode 100644 index 00000000..fe273824 --- /dev/null +++ b/internal/oci/spec_file_test.go @@ -0,0 +1,252 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package oci + +import ( + "bytes" + "fmt" + "testing" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/stretchr/testify/require" +) + +func TestLookupEnv(t *testing.T) { + const envName = "TEST_ENV" + testCases := []struct { + spec *specs.Spec + expectedValue string + expectedExits bool + }{ + { + // nil spec + spec: nil, + expectedValue: "", + expectedExits: false, + }, + { + // nil process + spec: &specs.Spec{}, + expectedValue: "", + expectedExits: false, + }, + { + // nil env + spec: &specs.Spec{ + Process: &specs.Process{}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // empty env + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // different env set + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // same prefix + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // same suffix + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // set blank + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV="}}, + }, + expectedValue: "", + expectedExits: true, + }, + { + // set no-equals + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV"}}, + }, + expectedValue: "", + expectedExits: true, + }, + { + // set value + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV=something"}}, + }, + expectedValue: "something", + expectedExits: true, + }, + { + // set with equals + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}}, + }, + expectedValue: "something=somethingelse", + expectedExits: true, + }, + } + + for i, tc := range testCases { + spec := fileSpec{ + Spec: tc.spec, + } + + value, exists := spec.LookupEnv(envName) + + require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc) + require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc) + } +} + +func TestLoadFrom(t *testing.T) { + testCases := []struct { + contents []byte + isError bool + spec *specs.Spec + }{ + { + contents: []byte{}, + isError: true, + }, + { + contents: []byte("{}"), + isError: false, + spec: &specs.Spec{}, + }, + } + + for i, tc := range testCases { + spec := fileSpec{} + err := spec.loadFrom(bytes.NewReader(tc.contents)) + + if tc.isError { + require.Error(t, err, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + } + + if tc.spec == nil { + require.Nil(t, spec.Spec, "%d: %v", i, tc) + } else { + require.EqualValues(t, tc.spec, spec.Spec, "%d: %v", i, tc) + } + } +} + +func TestFlushTo(t *testing.T) { + testCases := []struct { + isError bool + spec *specs.Spec + contents string + }{ + { + spec: nil, + }, + { + spec: &specs.Spec{}, + contents: "{\"ociVersion\":\"\"}\n", + }, + } + + for i, tc := range testCases { + buffer := bytes.Buffer{} + + spec := fileSpec{Spec: tc.spec} + err := spec.flushTo(&buffer) + + if tc.isError { + require.Error(t, err, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + } + + require.EqualValues(t, tc.contents, buffer.String(), "%d: %v", i, tc) + } + + // Add a simple test for a writer that returns an error when writing + spec := fileSpec{Spec: &specs.Spec{}} + err := spec.flushTo(errorWriter{}) + require.Error(t, err) +} + +func TestModify(t *testing.T) { + + testCases := []struct { + spec *specs.Spec + modifierError error + }{ + { + spec: nil, + }, + { + spec: &specs.Spec{}, + }, + { + spec: &specs.Spec{}, + modifierError: fmt.Errorf("error in modifier"), + }, + } + + for i, tc := range testCases { + spec := fileSpec{Spec: tc.spec} + + modifier := func(spec *specs.Spec) error { + if tc.modifierError == nil { + spec.Version = "updated" + } + return tc.modifierError + } + + err := spec.Modify(modifier) + + if tc.spec == nil { + require.Error(t, err, "%d: %v", i, tc) + } else if tc.modifierError != nil { + require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc) + require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc) + } + } +} + +// errorWriter implements the io.Writer interface, always returning an error when +// writing. +type errorWriter struct{} + +func (e errorWriter) Write([]byte) (int, error) { + return 0, fmt.Errorf("error writing") +} diff --git a/internal/oci/spec_mock.go b/internal/oci/spec_mock.go index 1247adaf..1656552c 100644 --- a/internal/oci/spec_mock.go +++ b/internal/oci/spec_mock.go @@ -1,70 +1,201 @@ -/* -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -*/ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq package oci import ( - oci "github.com/opencontainers/runtime-spec/specs-go" + "sync" ) -// MockSpec provides a simple mock for an OCI spec to be used in testing. -// It also implements the SpecModifier interface. -type MockSpec struct { - *oci.Spec - MockLoad mockFunc - MockFlush mockFunc - MockModify mockFunc -} +// Ensure, that SpecMock does implement Spec. +// If this is not the case, regenerate this file with moq. +var _ Spec = &SpecMock{} -var _ Spec = (*MockSpec)(nil) +// SpecMock is a mock implementation of Spec. +// +// func TestSomethingThatUsesSpec(t *testing.T) { +// +// // make and configure a mocked Spec +// mockedSpec := &SpecMock{ +// FlushFunc: func() error { +// panic("mock out the Flush method") +// }, +// LoadFunc: func() error { +// panic("mock out the Load method") +// }, +// LookupEnvFunc: func(s string) (string, bool) { +// panic("mock out the LookupEnv method") +// }, +// ModifyFunc: func(specModifier SpecModifier) error { +// panic("mock out the Modify method") +// }, +// } +// +// // use mockedSpec in code that requires Spec +// // and then make assertions. +// +// } +type SpecMock struct { + // FlushFunc mocks the Flush method. + FlushFunc func() error -// NewMockSpec constructs a MockSpec to be used in testing as a Spec -func NewMockSpec(spec *oci.Spec, flushResult error, modifyResult error) *MockSpec { - s := MockSpec{ - Spec: spec, - MockFlush: mockFunc{result: flushResult}, - MockModify: mockFunc{result: modifyResult}, + // LoadFunc mocks the Load method. + LoadFunc func() error + + // LookupEnvFunc mocks the LookupEnv method. + LookupEnvFunc func(s string) (string, bool) + + // ModifyFunc mocks the Modify method. + ModifyFunc func(specModifier SpecModifier) error + + // calls tracks calls to the methods. + calls struct { + // Flush holds details about calls to the Flush method. + Flush []struct { + } + // Load holds details about calls to the Load method. + Load []struct { + } + // LookupEnv holds details about calls to the LookupEnv method. + LookupEnv []struct { + // S is the s argument value. + S string + } + // Modify holds details about calls to the Modify method. + Modify []struct { + // SpecModifier is the specModifier argument value. + SpecModifier SpecModifier + } } - - return &s + lockFlush sync.RWMutex + lockLoad sync.RWMutex + lockLookupEnv sync.RWMutex + lockModify sync.RWMutex } -// Load invokes the mocked Load function to return the predefined error / result -func (s *MockSpec) Load() error { - return s.MockLoad.call() +// Flush calls FlushFunc. +func (mock *SpecMock) Flush() error { + callInfo := struct { + }{} + mock.lockFlush.Lock() + mock.calls.Flush = append(mock.calls.Flush, callInfo) + mock.lockFlush.Unlock() + if mock.FlushFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.FlushFunc() } -// Flush invokes the mocked Load function to return the predefined error / result -func (s *MockSpec) Flush() error { - return s.MockFlush.call() +// FlushCalls gets all the calls that were made to Flush. +// Check the length with: +// len(mockedSpec.FlushCalls()) +func (mock *SpecMock) FlushCalls() []struct { +} { + var calls []struct { + } + mock.lockFlush.RLock() + calls = mock.calls.Flush + mock.lockFlush.RUnlock() + return calls } -// Modify applies the specified SpecModifier to the spec and invokes the -// mocked modify function to return the predefined error / result. -func (s *MockSpec) Modify(f SpecModifier) error { - f(s.Spec) - return s.MockModify.call() +// Load calls LoadFunc. +func (mock *SpecMock) Load() error { + callInfo := struct { + }{} + mock.lockLoad.Lock() + mock.calls.Load = append(mock.calls.Load, callInfo) + mock.lockLoad.Unlock() + if mock.LoadFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.LoadFunc() } -type mockFunc struct { - Callcount int - result error +// LoadCalls gets all the calls that were made to Load. +// Check the length with: +// len(mockedSpec.LoadCalls()) +func (mock *SpecMock) LoadCalls() []struct { +} { + var calls []struct { + } + mock.lockLoad.RLock() + calls = mock.calls.Load + mock.lockLoad.RUnlock() + return calls } -func (m *mockFunc) call() error { - m.Callcount++ - return m.result +// LookupEnv calls LookupEnvFunc. +func (mock *SpecMock) LookupEnv(s string) (string, bool) { + callInfo := struct { + S string + }{ + S: s, + } + mock.lockLookupEnv.Lock() + mock.calls.LookupEnv = append(mock.calls.LookupEnv, callInfo) + mock.lockLookupEnv.Unlock() + if mock.LookupEnvFunc == nil { + var ( + sOut string + bOut bool + ) + return sOut, bOut + } + return mock.LookupEnvFunc(s) +} + +// LookupEnvCalls gets all the calls that were made to LookupEnv. +// Check the length with: +// len(mockedSpec.LookupEnvCalls()) +func (mock *SpecMock) LookupEnvCalls() []struct { + S string +} { + var calls []struct { + S string + } + mock.lockLookupEnv.RLock() + calls = mock.calls.LookupEnv + mock.lockLookupEnv.RUnlock() + return calls +} + +// Modify calls ModifyFunc. +func (mock *SpecMock) Modify(specModifier SpecModifier) error { + callInfo := struct { + SpecModifier SpecModifier + }{ + SpecModifier: specModifier, + } + mock.lockModify.Lock() + mock.calls.Modify = append(mock.calls.Modify, callInfo) + mock.lockModify.Unlock() + if mock.ModifyFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ModifyFunc(specModifier) +} + +// ModifyCalls gets all the calls that were made to Modify. +// Check the length with: +// len(mockedSpec.ModifyCalls()) +func (mock *SpecMock) ModifyCalls() []struct { + SpecModifier SpecModifier +} { + var calls []struct { + SpecModifier SpecModifier + } + mock.lockModify.RLock() + calls = mock.calls.Modify + mock.lockModify.RUnlock() + return calls }