From 4177fddcc4dad5b3c97ffaee4ba3ce90cf79d518 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 31 Jan 2022 15:11:36 +0100 Subject: [PATCH] Import modifying runtime abstraction from experimental runtime This change imports the modifying runtime abstraction from the experimental branch. This encapsulates the checks for whether modification is required, and forwards the loaded spec to the specified modifier. This allows for the same code to be reused when performing more complex modifications. Signed-off-by: Evan Lezar --- cmd/nvidia-container-runtime/main_test.go | 6 +- cmd/nvidia-container-runtime/nvcr.go | 95 +++------- cmd/nvidia-container-runtime/nvcr_test.go | 57 +++--- .../runtime_factory.go | 35 +--- internal/oci/runtime_low_level.go | 11 +- internal/oci/runtime_path.go | 9 +- internal/oci/runtime_path_test.go | 10 +- internal/oci/spec.go | 32 +++- internal/oci/spec_file.go | 96 +++------- internal/oci/spec_file_test.go | 161 +--------------- internal/oci/spec_memory.go | 83 ++++++++ internal/oci/spec_memory_test.go | 178 ++++++++++++++++++ internal/oci/spec_test.go | 2 +- internal/runtime/runtime_modifier.go | 81 ++++++++ internal/runtime/runtime_modifier_test.go | 116 ++++++++++++ 15 files changed, 584 insertions(+), 388 deletions(-) create mode 100644 internal/oci/spec_memory.go create mode 100644 internal/oci/spec_memory_test.go create mode 100644 internal/runtime/runtime_modifier.go create mode 100644 internal/runtime/runtime_modifier_test.go diff --git a/cmd/nvidia-container-runtime/main_test.go b/cmd/nvidia-container-runtime/main_test.go index 0aac45ab..847840a0 100644 --- a/cmd/nvidia-container-runtime/main_test.go +++ b/cmd/nvidia-container-runtime/main_test.go @@ -167,11 +167,11 @@ func TestDuplicateHook(t *testing.T) { require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json") } -// addNVIDIAHook is a basic wrapper for nvidiaContainerRunime.addNVIDIAHook that is used for +// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for // testing. func addNVIDIAHook(spec *specs.Spec) error { - r := nvidiaContainerRuntime{logger: logger.Logger} - return r.addNVIDIAHook(spec) + m := addHookModifier{logger: logger.Logger} + return m.Modify(spec) } func (c testConfig) getRuntimeSpec() (specs.Spec, error) { diff --git a/cmd/nvidia-container-runtime/nvcr.go b/cmd/nvidia-container-runtime/nvcr.go index 6c4185b3..3fa90bae 100644 --- a/cmd/nvidia-container-runtime/nvcr.go +++ b/cmd/nvidia-container-runtime/nvcr.go @@ -17,88 +17,41 @@ package main import ( - "fmt" "os" "os/exec" "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/NVIDIA/nvidia-container-toolkit/internal/runtime" "github.com/opencontainers/runtime-spec/specs-go" log "github.com/sirupsen/logrus" ) -// nvidiaContainerRuntime encapsulates the NVIDIA Container Runtime. It wraps the specified runtime, conditionally -// modifying the specified OCI specification before invoking the runtime. -type nvidiaContainerRuntime struct { - logger *log.Logger - runtime oci.Runtime - ociSpec oci.Spec +// newNvidiaContainerRuntime is a constructor for a standard runtime shim. This uses +// a ModifyingRuntimeWrapper to apply the required modifications before execing to the +// specified low-level runtime +func newNvidiaContainerRuntime(logger *log.Logger, lowlevelRuntime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) { + modifier := addHookModifier{logger: logger} + + r := runtime.NewModifyingRuntimeWrapper( + logger, + lowlevelRuntime, + ociSpec, + modifier, + ) + + return r, nil } -var _ oci.Runtime = (*nvidiaContainerRuntime)(nil) - -// newNvidiaContainerRuntime is a constructor for a standard runtime shim. -func newNvidiaContainerRuntimeWithLogger(logger *log.Logger, runtime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) { - r := nvidiaContainerRuntime{ - logger: logger, - runtime: runtime, - ociSpec: ociSpec, - } - - return &r, nil +// addHookModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a +// prestart hook. If the hook is already present, no modification is made. +type addHookModifier struct { + logger *log.Logger } -// Exec defines the entrypoint for the NVIDIA Container Runtime. A check is performed to see whether modifications -// to the OCI spec are required -- and applicable modifcations applied. The supplied arguments are then -// forwarded to the underlying runtime's Exec method. -func (r nvidiaContainerRuntime) Exec(args []string) error { - if r.modificationRequired(args) { - err := r.modifyOCISpec() - if err != nil { - return fmt.Errorf("error modifying OCI spec: %v", err) - } - } - - r.logger.Println("Forwarding command to runtime") - return r.runtime.Exec(args) -} - -// modificationRequired checks the intput arguments to determine whether a modification -// to the OCI spec is required. -func (r nvidiaContainerRuntime) modificationRequired(args []string) bool { - if oci.HasCreateSubcommand(args) { - r.logger.Infof("'create' command detected; modification required") - return true - } - - r.logger.Infof("No modification required") - return false -} - -// modifyOCISpec loads and modifies the OCI spec specified in the nvidiaContainerRuntime -// struct. The spec is modified in-place and written to the same file as the input after -// modifcationas are applied. -func (r nvidiaContainerRuntime) modifyOCISpec() error { - err := r.ociSpec.Load() - if err != nil { - return fmt.Errorf("error loading OCI specification for modification: %v", err) - } - - err = r.ociSpec.Modify(r.addNVIDIAHook) - if err != nil { - return fmt.Errorf("error injecting NVIDIA Container Runtime hook: %v", err) - } - - err = r.ociSpec.Flush() - if err != nil { - return fmt.Errorf("error writing modified OCI specification: %v", err) - } - return nil -} - -// addNVIDIAHook modifies the specified OCI specification in-place, inserting a -// prestart hook. -func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error { +// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook +// as a prestart hook. +func (m addHookModifier) Modify(spec *specs.Spec) error { path, err := exec.LookPath("nvidia-container-runtime-hook") if err != nil { path = hookDefaultFilePath @@ -108,7 +61,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error { } } - r.logger.Printf("prestart hook path: %s\n", path) + m.logger.Printf("prestart hook path: %s\n", path) args := []string{path} if spec.Hooks == nil { @@ -118,7 +71,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error { if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") { continue } - r.logger.Println("existing nvidia prestart hook in OCI spec file") + m.logger.Println("existing nvidia prestart hook in OCI spec file") return nil } } diff --git a/cmd/nvidia-container-runtime/nvcr_test.go b/cmd/nvidia-container-runtime/nvcr_test.go index 16e1b447..0fc1348d 100644 --- a/cmd/nvidia-container-runtime/nvcr_test.go +++ b/cmd/nvidia-container-runtime/nvcr_test.go @@ -29,9 +29,8 @@ import ( func TestAddNvidiaHook(t *testing.T) { logger, logHook := testlog.NewNullLogger() - shim := nvidiaContainerRuntime{ - logger: logger, - } + + mockRuntime := &oci.RuntimeMock{} testCases := []struct { spec *specs.Spec @@ -75,7 +74,16 @@ func TestAddNvidiaHook(t *testing.T) { numPrestartHooks = len(tc.spec.Hooks.Prestart) } - err := shim.addNVIDIAHook(tc.spec) + shim, err := newNvidiaContainerRuntime( + logger, + mockRuntime, + oci.NewMemorySpec(tc.spec), + ) + + require.NoError(t, err) + + err = shim.Exec([]string{"runtime", "create"}) + require.NoError(t, err) if tc.errorPrefix == "" { require.NoErrorf(t, err, "%d: %v", i, tc) @@ -106,45 +114,39 @@ func TestAddNvidiaHook(t *testing.T) { func TestNvidiaContainerRuntime(t *testing.T) { logger, hook := testlog.NewNullLogger() + mockRuntime := &oci.RuntimeMock{} + testCases := []struct { - shim nvidiaContainerRuntime shouldModify bool args []string modifyError error writeError error }{ { - shim: nvidiaContainerRuntime{}, shouldModify: false, }, { - shim: nvidiaContainerRuntime{}, args: []string{"create"}, shouldModify: true, }, { - shim: nvidiaContainerRuntime{}, args: []string{"--bundle=create"}, shouldModify: false, }, { - shim: nvidiaContainerRuntime{}, args: []string{"--bundle", "create"}, shouldModify: false, }, { - shim: nvidiaContainerRuntime{}, args: []string{"create"}, shouldModify: true, }, { - shim: nvidiaContainerRuntime{}, args: []string{"create"}, modifyError: fmt.Errorf("error modifying"), shouldModify: true, }, { - shim: nvidiaContainerRuntime{}, args: []string{"create"}, writeError: fmt.Errorf("error writing"), shouldModify: true, @@ -152,10 +154,8 @@ func TestNvidiaContainerRuntime(t *testing.T) { } for i, tc := range testCases { - tc.shim.logger = logger hook.Reset() - - ociMock := &oci.SpecMock{ + specMock := &oci.SpecMock{ ModifyFunc: func(specModifier oci.SpecModifier) error { return tc.modifyError }, @@ -163,12 +163,11 @@ func TestNvidiaContainerRuntime(t *testing.T) { return tc.writeError }, } - require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc) - tc.shim.ociSpec = ociMock - tc.shim.runtime = &MockShim{} + shim, err := newNvidiaContainerRuntime(logger, mockRuntime, specMock) + require.NoError(t, err) - err := tc.shim.Exec(tc.args) + err = shim.Exec(tc.args) if tc.modifyError != nil || tc.writeError != nil { require.Error(t, err, "%d: %v", i, tc) } else { @@ -176,28 +175,16 @@ func TestNvidiaContainerRuntime(t *testing.T) { } if tc.shouldModify { - require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc) + require.Equal(t, 1, len(specMock.ModifyCalls()), "%d: %v", i, tc) } else { - require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc) + require.Equal(t, 0, len(specMock.ModifyCalls()), "%d: %v", i, tc) } writeExpected := tc.shouldModify && tc.modifyError == nil if writeExpected { - require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc) + require.Equal(t, 1, len(specMock.FlushCalls()), "%d: %v", i, tc) } else { - require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc) + require.Equal(t, 0, len(specMock.FlushCalls()), "%d: %v", i, tc) } } } - -type MockShim struct { - called bool - args []string - returnError error -} - -func (m *MockShim) Exec(args []string) error { - m.called = true - m.args = args - return m.returnError -} diff --git a/cmd/nvidia-container-runtime/runtime_factory.go b/cmd/nvidia-container-runtime/runtime_factory.go index 0af37f6e..3e3cd353 100644 --- a/cmd/nvidia-container-runtime/runtime_factory.go +++ b/cmd/nvidia-container-runtime/runtime_factory.go @@ -23,52 +23,27 @@ import ( ) const ( - ociSpecFileName = "config.json" dockerRuncExecutableName = "docker-runc" runcExecutableName = "runc" ) // newRuntime is a factory method that constructs a runtime based on the selected configuration. func newRuntime(argv []string) (oci.Runtime, error) { - ociSpec, err := newOCISpec(argv) + ociSpec, err := oci.NewSpec(logger.Logger, argv) if err != nil { return nil, fmt.Errorf("error constructing OCI specification: %v", err) } - runc, err := newRuncRuntime() + lowLevelRuntimeCandidates := []string{dockerRuncExecutableName, runcExecutableName} + lowLevelRuntime, err := oci.NewLowLevelRuntime(logger.Logger, lowLevelRuntimeCandidates) if err != nil { - return nil, fmt.Errorf("error constructing runc runtime: %v", err) + return nil, fmt.Errorf("error constructing low-level runtime: %v", err) } - r, err := newNvidiaContainerRuntimeWithLogger(logger.Logger, runc, ociSpec) + r, err := newNvidiaContainerRuntime(logger.Logger, lowLevelRuntime, ociSpec) if err != nil { return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err) } return r, nil } - -// newOCISpec constructs an OCI spec for the provided arguments -func newOCISpec(argv []string) (oci.Spec, error) { - bundleDir, err := oci.GetBundleDir(argv) - if err != nil { - return nil, fmt.Errorf("error parsing command line arguments: %v", err) - } - logger.Infof("Using bundle directory: %v", bundleDir) - - ociSpecPath := oci.GetSpecFilePath(bundleDir) - logger.Infof("Using OCI specification file path: %v", ociSpecPath) - - ociSpec := oci.NewSpecFromFile(ociSpecPath) - - return ociSpec, nil -} - -// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime -func newRuncRuntime() (oci.Runtime, error) { - return oci.NewLowLevelRuntimeWithLogger( - logger.Logger, - dockerRuncExecutableName, - runcExecutableName, - ) -} diff --git a/internal/oci/runtime_low_level.go b/internal/oci/runtime_low_level.go index 6363d2ec..ac5aea59 100644 --- a/internal/oci/runtime_low_level.go +++ b/internal/oci/runtime_low_level.go @@ -25,19 +25,14 @@ import ( // NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable. // The executable specified is taken from the list of supplied candidates, with the first match -// present in the PATH being selected. -func NewLowLevelRuntime(candidates ...string) (Runtime, error) { - return NewLowLevelRuntimeWithLogger(log.StandardLogger(), candidates...) -} - -// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger. -func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) { +// present in the PATH being selected. A logger is also specified. +func NewLowLevelRuntime(logger *log.Logger, candidates []string) (Runtime, error) { runtimePath, err := findRuntime(logger, candidates) if err != nil { return nil, fmt.Errorf("error locating runtime: %v", err) } - return NewRuntimeForPathWithLogger(logger, runtimePath) + return NewRuntimeForPath(logger, runtimePath) } // findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH. diff --git a/internal/oci/runtime_path.go b/internal/oci/runtime_path.go index abf225b2..102e84f6 100644 --- a/internal/oci/runtime_path.go +++ b/internal/oci/runtime_path.go @@ -34,13 +34,8 @@ type pathRuntime struct { var _ Runtime = (*pathRuntime)(nil) -// NewRuntimeForPath creates a Runtime for the specified path with the standard logger -func NewRuntimeForPath(path string) (Runtime, error) { - return NewRuntimeForPathWithLogger(log.StandardLogger(), path) -} - -// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path -func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) { +// NewRuntimeForPath creates a Runtime for the specified logger and path +func NewRuntimeForPath(logger *log.Logger, path string) (Runtime, error) { info, err := os.Stat(path) if err != nil { return nil, fmt.Errorf("invalid path '%v': %v", path, err) diff --git a/internal/oci/runtime_path_test.go b/internal/oci/runtime_path_test.go index 0d936a3f..ac247a1c 100644 --- a/internal/oci/runtime_path_test.go +++ b/internal/oci/runtime_path_test.go @@ -24,19 +24,21 @@ import ( ) func TestPathRuntimeConstructor(t *testing.T) { - r, err := NewRuntimeForPath("////an/invalid/path") + logger, _ := testlog.NewNullLogger() + + r, err := NewRuntimeForPath(logger, "////an/invalid/path") require.Error(t, err) require.Nil(t, r) - r, err = NewRuntimeForPath("/tmp") + r, err = NewRuntimeForPath(logger, "/tmp") require.Error(t, err) require.Nil(t, r) - r, err = NewRuntimeForPath("/dev/null") + r, err = NewRuntimeForPath(logger, "/dev/null") require.Error(t, err) require.Nil(t, r) - r, err = NewRuntimeForPath("/bin/sh") + r, err = NewRuntimeForPath(logger, "/bin/sh") require.NoError(t, err) f, ok := r.(*pathRuntime) diff --git a/internal/oci/spec.go b/internal/oci/spec.go index 259ce054..ba7e7cee 100644 --- a/internal/oci/spec.go +++ b/internal/oci/spec.go @@ -17,15 +17,20 @@ package oci import ( - oci "github.com/opencontainers/runtime-spec/specs-go" + "fmt" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/sirupsen/logrus" ) -// SpecModifier is a function that accepts a pointer to an OCI Srec and returns an -// error. The intention is that the function would modify the spec in-place. -type SpecModifier func(*oci.Spec) error +// SpecModifier defines an interace for modifying a (raw) OCI spec +type SpecModifier interface { + // Modify is a method that accepts a pointer to an OCI Srec and returns an + // error. The intention is that the function would modify the spec in-place. + Modify(*specs.Spec) error +} //go:generate moq -stub -out spec_mock.go . Spec - // Spec defines the operations to be performed on an OCI specification type Spec interface { Load() error @@ -33,3 +38,20 @@ type Spec interface { Modify(SpecModifier) error LookupEnv(string) (string, bool) } + +// NewSpec creates fileSpec based on the command line arguments passed to the +// application using the specified logger. +func NewSpec(logger *logrus.Logger, args []string) (Spec, error) { + bundleDir, err := GetBundleDir(args) + if err != nil { + return nil, fmt.Errorf("error getting bundle directory: %v", err) + } + logger.Infof("Using bundle directory: %v", bundleDir) + + ociSpecPath := GetSpecFilePath(bundleDir) + logger.Infof("Using OCI specification file path: %v", ociSpecPath) + + ociSpec := NewFileSpec(ociSpecPath) + + return ociSpec, nil +} diff --git a/internal/oci/spec_file.go b/internal/oci/spec_file.go index 886e2cb3..ff0cbb46 100644 --- a/internal/oci/spec_file.go +++ b/internal/oci/spec_file.go @@ -21,37 +21,21 @@ import ( "fmt" "io" "os" - "strings" - oci "github.com/opencontainers/runtime-spec/specs-go" + "github.com/opencontainers/runtime-spec/specs-go" ) type fileSpec struct { - *oci.Spec + memorySpec path string } var _ Spec = (*fileSpec)(nil) -// NewSpecFromArgs creates fileSpec based on the command line arguments passed to the -// application -func NewSpecFromArgs(args []string) (Spec, string, error) { - bundleDir, err := GetBundleDir(args) - if err != nil { - return nil, "", fmt.Errorf("error getting bundle directory: %v", err) - } - - ociSpecPath := GetSpecFilePath(bundleDir) - - ociSpec := NewSpecFromFile(ociSpecPath) - - return ociSpec, bundleDir, nil -} - -// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec. +// NewFileSpec creates an object that encapsulates a file-backed OCI spec. // This can be used to read from the file, modify the spec, and write to the // same file. -func NewSpecFromFile(filepath string) Spec { +func NewFileSpec(filepath string) Spec { oci := fileSpec{ path: filepath, } @@ -68,29 +52,31 @@ func (s *fileSpec) Load() error { } defer specFile.Close() - return s.loadFrom(specFile) -} - -// loadFrom reads the contents of the OCI spec from the specified io.Reader. -func (s *fileSpec) loadFrom(reader io.Reader) error { - decoder := json.NewDecoder(reader) - - var spec oci.Spec - err := decoder.Decode(&spec) + spec, err := loadFrom(specFile) if err != nil { - return fmt.Errorf("error reading OCI specification: %v", err) + return fmt.Errorf("error loading OCI specification from file: %v", err) } - - s.Spec = &spec + s.Spec = spec return nil } -// Modify applies the specified SpecModifier to the stored OCI specification. -func (s *fileSpec) Modify(f SpecModifier) error { - if s.Spec == nil { - return fmt.Errorf("no spec loaded for modification") +// loadFrom reads the contents of the OCI spec from the specified io.Reader. +func loadFrom(reader io.Reader) (*specs.Spec, error) { + decoder := json.NewDecoder(reader) + + var spec specs.Spec + + err := decoder.Decode(&spec) + if err != nil { + return nil, fmt.Errorf("error reading OCI specification: %v", err) } - return f(s.Spec) + + return &spec, nil +} + +// Modify applies the specified SpecModifier to the stored OCI specification. +func (s *fileSpec) Modify(m SpecModifier) error { + return s.memorySpec.Modify(m) } // Flush writes the stored OCI specification to the filepath specifed by the path member. @@ -106,48 +92,20 @@ func (s fileSpec) Flush() error { } defer specFile.Close() - return s.flushTo(specFile) + return flushTo(s.Spec, specFile) } // flushTo writes the stored OCI specification to the specified io.Writer. -func (s fileSpec) flushTo(writer io.Writer) error { - if s.Spec == nil { +func flushTo(spec *specs.Spec, writer io.Writer) error { + if spec == nil { return nil } encoder := json.NewEncoder(writer) - err := encoder.Encode(s.Spec) + err := encoder.Encode(spec) if err != nil { return fmt.Errorf("error writing OCI specification: %v", err) } return nil } - -// LookupEnv mirrors os.LookupEnv for the OCI specification. It -// retrieves the value of the environment variable named -// by the key. If the variable is present in the environment the -// value (which may be empty) is returned and the boolean is true. -// Otherwise the returned value will be empty and the boolean will -// be false. -func (s fileSpec) LookupEnv(key string) (string, bool) { - if s.Spec == nil || s.Spec.Process == nil { - return "", false - } - - for _, env := range s.Spec.Process.Env { - if !strings.HasPrefix(env, key) { - continue - } - - parts := strings.SplitN(env, "=", 2) - if parts[0] == key { - if len(parts) < 2 { - return "", true - } - return parts[1], true - } - } - - return "", false -} diff --git a/internal/oci/spec_file_test.go b/internal/oci/spec_file_test.go index fe273824..94dfb3b3 100644 --- a/internal/oci/spec_file_test.go +++ b/internal/oci/spec_file_test.go @@ -25,111 +25,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestLookupEnv(t *testing.T) { - const envName = "TEST_ENV" - testCases := []struct { - spec *specs.Spec - expectedValue string - expectedExits bool - }{ - { - // nil spec - spec: nil, - expectedValue: "", - expectedExits: false, - }, - { - // nil process - spec: &specs.Spec{}, - expectedValue: "", - expectedExits: false, - }, - { - // nil env - spec: &specs.Spec{ - Process: &specs.Process{}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // empty env - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{}}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // different env set - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // same prefix - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // same suffix - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // set blank - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV="}}, - }, - expectedValue: "", - expectedExits: true, - }, - { - // set no-equals - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV"}}, - }, - expectedValue: "", - expectedExits: true, - }, - { - // set value - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV=something"}}, - }, - expectedValue: "something", - expectedExits: true, - }, - { - // set with equals - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}}, - }, - expectedValue: "something=somethingelse", - expectedExits: true, - }, - } - - for i, tc := range testCases { - spec := fileSpec{ - Spec: tc.spec, - } - - value, exists := spec.LookupEnv(envName) - - require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc) - require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc) - } -} - func TestLoadFrom(t *testing.T) { testCases := []struct { contents []byte @@ -148,8 +43,8 @@ func TestLoadFrom(t *testing.T) { } for i, tc := range testCases { - spec := fileSpec{} - err := spec.loadFrom(bytes.NewReader(tc.contents)) + var spec *specs.Spec + spec, err := loadFrom(bytes.NewReader(tc.contents)) if tc.isError { require.Error(t, err, "%d: %v", i, tc) @@ -158,9 +53,9 @@ func TestLoadFrom(t *testing.T) { } if tc.spec == nil { - require.Nil(t, spec.Spec, "%d: %v", i, tc) + require.Nil(t, spec, "%d: %v", i, tc) } else { - require.EqualValues(t, tc.spec, spec.Spec, "%d: %v", i, tc) + require.EqualValues(t, tc.spec, spec, "%d: %v", i, tc) } } } @@ -183,8 +78,7 @@ func TestFlushTo(t *testing.T) { for i, tc := range testCases { buffer := bytes.Buffer{} - spec := fileSpec{Spec: tc.spec} - err := spec.flushTo(&buffer) + err := flushTo(tc.spec, &buffer) if tc.isError { require.Error(t, err, "%d: %v", i, tc) @@ -196,53 +90,10 @@ func TestFlushTo(t *testing.T) { } // Add a simple test for a writer that returns an error when writing - spec := fileSpec{Spec: &specs.Spec{}} - err := spec.flushTo(errorWriter{}) + err := flushTo(&specs.Spec{}, errorWriter{}) require.Error(t, err) } -func TestModify(t *testing.T) { - - testCases := []struct { - spec *specs.Spec - modifierError error - }{ - { - spec: nil, - }, - { - spec: &specs.Spec{}, - }, - { - spec: &specs.Spec{}, - modifierError: fmt.Errorf("error in modifier"), - }, - } - - for i, tc := range testCases { - spec := fileSpec{Spec: tc.spec} - - modifier := func(spec *specs.Spec) error { - if tc.modifierError == nil { - spec.Version = "updated" - } - return tc.modifierError - } - - err := spec.Modify(modifier) - - if tc.spec == nil { - require.Error(t, err, "%d: %v", i, tc) - } else if tc.modifierError != nil { - require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc) - require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc) - } else { - require.NoError(t, err, "%d: %v", i, tc) - require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc) - } - } -} - // errorWriter implements the io.Writer interface, always returning an error when // writing. type errorWriter struct{} diff --git a/internal/oci/spec_memory.go b/internal/oci/spec_memory.go new file mode 100644 index 00000000..ce94447e --- /dev/null +++ b/internal/oci/spec_memory.go @@ -0,0 +1,83 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package oci + +import ( + "fmt" + "strings" + + "github.com/opencontainers/runtime-spec/specs-go" +) + +type memorySpec struct { + *specs.Spec +} + +// NewMemorySpec creates a Spec instance from the specified OCI spec +func NewMemorySpec(spec *specs.Spec) Spec { + s := memorySpec{ + Spec: spec, + } + + return &s +} + +// Load is a no-op for the memorySpec spec +func (s *memorySpec) Load() error { + return nil +} + +// Flush is a no-op for the memorySpec spec +func (s *memorySpec) Flush() error { + return nil +} + +// Modify applies the specified SpecModifier to the stored OCI specification. +func (s *memorySpec) Modify(m SpecModifier) error { + if s.Spec == nil { + return fmt.Errorf("cannot modify nil spec") + } + return m.Modify(s.Spec) +} + +// LookupEnv mirrors os.LookupEnv for the OCI specification. It +// retrieves the value of the environment variable named +// by the key. If the variable is present in the environment the +// value (which may be empty) is returned and the boolean is true. +// Otherwise the returned value will be empty and the boolean will +// be false. +func (s memorySpec) LookupEnv(key string) (string, bool) { + if s.Spec == nil || s.Spec.Process == nil { + return "", false + } + + for _, env := range s.Spec.Process.Env { + if !strings.HasPrefix(env, key) { + continue + } + + parts := strings.SplitN(env, "=", 2) + if parts[0] == key { + if len(parts) < 2 { + return "", true + } + return parts[1], true + } + } + + return "", false +} diff --git a/internal/oci/spec_memory_test.go b/internal/oci/spec_memory_test.go new file mode 100644 index 00000000..a43944a4 --- /dev/null +++ b/internal/oci/spec_memory_test.go @@ -0,0 +1,178 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package oci + +import ( + "fmt" + "testing" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/stretchr/testify/require" +) + +func TestLookupEnv(t *testing.T) { + const envName = "TEST_ENV" + testCases := []struct { + spec *specs.Spec + expectedValue string + expectedExits bool + }{ + { + // nil spec + spec: nil, + expectedValue: "", + expectedExits: false, + }, + { + // nil process + spec: &specs.Spec{}, + expectedValue: "", + expectedExits: false, + }, + { + // nil env + spec: &specs.Spec{ + Process: &specs.Process{}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // empty env + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // different env set + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // same prefix + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // same suffix + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // set blank + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV="}}, + }, + expectedValue: "", + expectedExits: true, + }, + { + // set no-equals + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV"}}, + }, + expectedValue: "", + expectedExits: true, + }, + { + // set value + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV=something"}}, + }, + expectedValue: "something", + expectedExits: true, + }, + { + // set with equals + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}}, + }, + expectedValue: "something=somethingelse", + expectedExits: true, + }, + } + + for i, tc := range testCases { + spec := memorySpec{ + Spec: tc.spec, + } + + value, exists := spec.LookupEnv(envName) + + require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc) + require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc) + } +} + +func TestModify(t *testing.T) { + + testCases := []struct { + spec *specs.Spec + modifierError error + }{ + { + spec: nil, + }, + { + spec: &specs.Spec{}, + }, + { + spec: &specs.Spec{}, + modifierError: fmt.Errorf("error in modifier"), + }, + } + + for i, tc := range testCases { + spec := NewMemorySpec(tc.spec).(*memorySpec) + + err := spec.Modify(modifier{tc.modifierError}) + + if tc.spec == nil { + require.Error(t, err, "%d: %v", i, tc) + } else if tc.modifierError != nil { + require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc) + require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc) + } + } +} + +// TODO: Ideally we would generated a mock for the SpecModifer too. This causes +// an import cycle and we define a local type as a work-around. +type modifier struct { + modifierError error +} + +func (m modifier) Modify(spec *specs.Spec) error { + if m.modifierError == nil { + spec.Version = "updated" + } + return m.modifierError +} diff --git a/internal/oci/spec_test.go b/internal/oci/spec_test.go index 03d2d301..f17f3d3c 100644 --- a/internal/oci/spec_test.go +++ b/internal/oci/spec_test.go @@ -20,7 +20,7 @@ func TestMaintainSpec(t *testing.T) { for _, f := range files { inputSpecPath := filepath.Join(moduleRoot, "test/input", f) - spec := NewSpecFromFile(inputSpecPath).(*fileSpec) + spec := NewFileSpec(inputSpecPath).(*fileSpec) spec.Load() diff --git a/internal/runtime/runtime_modifier.go b/internal/runtime/runtime_modifier.go new file mode 100644 index 00000000..86accba0 --- /dev/null +++ b/internal/runtime/runtime_modifier.go @@ -0,0 +1,81 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package runtime + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + log "github.com/sirupsen/logrus" +) + +type modifyingRuntimeWrapper struct { + logger *log.Logger + runtime oci.Runtime + ociSpec oci.Spec + modifier oci.SpecModifier +} + +var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil) + +// NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification +// before invoking the wrapped runtime. +func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime { + rt := modifyingRuntimeWrapper{ + logger: logger, + runtime: runtime, + ociSpec: spec, + modifier: modifier, + } + return &rt +} + +// Exec checks whether a modification of the OCI specification is required and modifies it accordingly before exec-ing +// into the wrapped runtime. +func (r *modifyingRuntimeWrapper) Exec(args []string) error { + if oci.HasCreateSubcommand(args) { + err := r.modify() + if err != nil { + return fmt.Errorf("could not apply required modification to OCI specification: %v", err) + } + r.logger.Infof("Applied required modification to OCI specification") + } else { + r.logger.Infof("No modification of OCI specification required") + } + + r.logger.Infof("Forwarding command to runtime") + return r.runtime.Exec(args) +} + +// modify loads, modifies, and flushes the OCI specification using the defined Modifier +func (r *modifyingRuntimeWrapper) modify() error { + err := r.ociSpec.Load() + if err != nil { + return fmt.Errorf("error loading OCI specification for modification: %v", err) + } + + err = r.ociSpec.Modify(r.modifier) + if err != nil { + return fmt.Errorf("error modifying OCI spec: %v", err) + } + + err = r.ociSpec.Flush() + if err != nil { + return fmt.Errorf("error writing modified OCI specification: %v", err) + } + return nil +} diff --git a/internal/runtime/runtime_modifier_test.go b/internal/runtime/runtime_modifier_test.go new file mode 100644 index 00000000..5ab77a04 --- /dev/null +++ b/internal/runtime/runtime_modifier_test.go @@ -0,0 +1,116 @@ +/* +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package runtime + +import ( + "fmt" + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestExec(t *testing.T) { + logger, hook := testlog.NewNullLogger() + + testCases := []struct { + description string + shouldModify bool + shouldFlush bool + shouldForward bool + args []string + modifyError error + writeError error + }{ + { + description: "no args forwards", + shouldModify: false, + shouldFlush: false, + shouldForward: true, + }, + { + description: "create modifies", + args: []string{"create"}, + shouldModify: true, + shouldFlush: true, + shouldForward: true, + }, + { + description: "modify error does not write or forward", + args: []string{"create"}, + modifyError: fmt.Errorf("error modifying"), + shouldModify: true, + shouldFlush: false, + shouldForward: false, + }, + { + description: "write error does not forward", + args: []string{"create"}, + writeError: fmt.Errorf("error writing"), + shouldModify: true, + shouldFlush: true, + shouldForward: false, + }, + } + + for _, tc := range testCases { + hook.Reset() + + t.Run(tc.description, func(t *testing.T) { + runtimeMock := &oci.RuntimeMock{} + specMock := &oci.SpecMock{ + ModifyFunc: func(specModifier oci.SpecModifier) error { + return tc.modifyError + }, + FlushFunc: func() error { + return tc.writeError + }, + } + + shim := NewModifyingRuntimeWrapper( + logger, + runtimeMock, + specMock, + // TODO: We should test the interactions with the SpecModifier too + nil) + + err := shim.Exec(tc.args) + if tc.modifyError != nil || tc.writeError != nil { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + if tc.shouldModify { + require.Equal(t, 1, len(specMock.ModifyCalls())) + } else { + require.Equal(t, 0, len(specMock.ModifyCalls())) + } + if tc.shouldFlush { + require.Equal(t, 1, len(specMock.FlushCalls())) + } else { + require.Equal(t, 0, len(specMock.FlushCalls())) + } + if tc.shouldForward { + require.Equal(t, 1, len(runtimeMock.ExecCalls())) + } else { + require.Equal(t, 0, len(runtimeMock.ExecCalls())) + } + }) + } +}