diff --git a/cmd/nvidia-container-runtime/main_test.go b/cmd/nvidia-container-runtime/main_test.go index 0aac45ab..847840a0 100644 --- a/cmd/nvidia-container-runtime/main_test.go +++ b/cmd/nvidia-container-runtime/main_test.go @@ -167,11 +167,11 @@ func TestDuplicateHook(t *testing.T) { require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json") } -// addNVIDIAHook is a basic wrapper for nvidiaContainerRunime.addNVIDIAHook that is used for +// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for // testing. func addNVIDIAHook(spec *specs.Spec) error { - r := nvidiaContainerRuntime{logger: logger.Logger} - return r.addNVIDIAHook(spec) + m := addHookModifier{logger: logger.Logger} + return m.Modify(spec) } func (c testConfig) getRuntimeSpec() (specs.Spec, error) { diff --git a/cmd/nvidia-container-runtime/nvcr.go b/cmd/nvidia-container-runtime/nvcr.go index 6c4185b3..3fa90bae 100644 --- a/cmd/nvidia-container-runtime/nvcr.go +++ b/cmd/nvidia-container-runtime/nvcr.go @@ -17,88 +17,41 @@ package main import ( - "fmt" "os" "os/exec" "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/NVIDIA/nvidia-container-toolkit/internal/runtime" "github.com/opencontainers/runtime-spec/specs-go" log "github.com/sirupsen/logrus" ) -// nvidiaContainerRuntime encapsulates the NVIDIA Container Runtime. It wraps the specified runtime, conditionally -// modifying the specified OCI specification before invoking the runtime. -type nvidiaContainerRuntime struct { - logger *log.Logger - runtime oci.Runtime - ociSpec oci.Spec +// newNvidiaContainerRuntime is a constructor for a standard runtime shim. This uses +// a ModifyingRuntimeWrapper to apply the required modifications before execing to the +// specified low-level runtime +func newNvidiaContainerRuntime(logger *log.Logger, lowlevelRuntime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) { + modifier := addHookModifier{logger: logger} + + r := runtime.NewModifyingRuntimeWrapper( + logger, + lowlevelRuntime, + ociSpec, + modifier, + ) + + return r, nil } -var _ oci.Runtime = (*nvidiaContainerRuntime)(nil) - -// newNvidiaContainerRuntime is a constructor for a standard runtime shim. -func newNvidiaContainerRuntimeWithLogger(logger *log.Logger, runtime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) { - r := nvidiaContainerRuntime{ - logger: logger, - runtime: runtime, - ociSpec: ociSpec, - } - - return &r, nil +// addHookModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a +// prestart hook. If the hook is already present, no modification is made. +type addHookModifier struct { + logger *log.Logger } -// Exec defines the entrypoint for the NVIDIA Container Runtime. A check is performed to see whether modifications -// to the OCI spec are required -- and applicable modifcations applied. The supplied arguments are then -// forwarded to the underlying runtime's Exec method. -func (r nvidiaContainerRuntime) Exec(args []string) error { - if r.modificationRequired(args) { - err := r.modifyOCISpec() - if err != nil { - return fmt.Errorf("error modifying OCI spec: %v", err) - } - } - - r.logger.Println("Forwarding command to runtime") - return r.runtime.Exec(args) -} - -// modificationRequired checks the intput arguments to determine whether a modification -// to the OCI spec is required. -func (r nvidiaContainerRuntime) modificationRequired(args []string) bool { - if oci.HasCreateSubcommand(args) { - r.logger.Infof("'create' command detected; modification required") - return true - } - - r.logger.Infof("No modification required") - return false -} - -// modifyOCISpec loads and modifies the OCI spec specified in the nvidiaContainerRuntime -// struct. The spec is modified in-place and written to the same file as the input after -// modifcationas are applied. -func (r nvidiaContainerRuntime) modifyOCISpec() error { - err := r.ociSpec.Load() - if err != nil { - return fmt.Errorf("error loading OCI specification for modification: %v", err) - } - - err = r.ociSpec.Modify(r.addNVIDIAHook) - if err != nil { - return fmt.Errorf("error injecting NVIDIA Container Runtime hook: %v", err) - } - - err = r.ociSpec.Flush() - if err != nil { - return fmt.Errorf("error writing modified OCI specification: %v", err) - } - return nil -} - -// addNVIDIAHook modifies the specified OCI specification in-place, inserting a -// prestart hook. -func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error { +// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook +// as a prestart hook. +func (m addHookModifier) Modify(spec *specs.Spec) error { path, err := exec.LookPath("nvidia-container-runtime-hook") if err != nil { path = hookDefaultFilePath @@ -108,7 +61,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error { } } - r.logger.Printf("prestart hook path: %s\n", path) + m.logger.Printf("prestart hook path: %s\n", path) args := []string{path} if spec.Hooks == nil { @@ -118,7 +71,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error { if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") { continue } - r.logger.Println("existing nvidia prestart hook in OCI spec file") + m.logger.Println("existing nvidia prestart hook in OCI spec file") return nil } } diff --git a/cmd/nvidia-container-runtime/nvcr_test.go b/cmd/nvidia-container-runtime/nvcr_test.go index 16e1b447..0fc1348d 100644 --- a/cmd/nvidia-container-runtime/nvcr_test.go +++ b/cmd/nvidia-container-runtime/nvcr_test.go @@ -29,9 +29,8 @@ import ( func TestAddNvidiaHook(t *testing.T) { logger, logHook := testlog.NewNullLogger() - shim := nvidiaContainerRuntime{ - logger: logger, - } + + mockRuntime := &oci.RuntimeMock{} testCases := []struct { spec *specs.Spec @@ -75,7 +74,16 @@ func TestAddNvidiaHook(t *testing.T) { numPrestartHooks = len(tc.spec.Hooks.Prestart) } - err := shim.addNVIDIAHook(tc.spec) + shim, err := newNvidiaContainerRuntime( + logger, + mockRuntime, + oci.NewMemorySpec(tc.spec), + ) + + require.NoError(t, err) + + err = shim.Exec([]string{"runtime", "create"}) + require.NoError(t, err) if tc.errorPrefix == "" { require.NoErrorf(t, err, "%d: %v", i, tc) @@ -106,45 +114,39 @@ func TestAddNvidiaHook(t *testing.T) { func TestNvidiaContainerRuntime(t *testing.T) { logger, hook := testlog.NewNullLogger() + mockRuntime := &oci.RuntimeMock{} + testCases := []struct { - shim nvidiaContainerRuntime shouldModify bool args []string modifyError error writeError error }{ { - shim: nvidiaContainerRuntime{}, shouldModify: false, }, { - shim: nvidiaContainerRuntime{}, args: []string{"create"}, shouldModify: true, }, { - shim: nvidiaContainerRuntime{}, args: []string{"--bundle=create"}, shouldModify: false, }, { - shim: nvidiaContainerRuntime{}, args: []string{"--bundle", "create"}, shouldModify: false, }, { - shim: nvidiaContainerRuntime{}, args: []string{"create"}, shouldModify: true, }, { - shim: nvidiaContainerRuntime{}, args: []string{"create"}, modifyError: fmt.Errorf("error modifying"), shouldModify: true, }, { - shim: nvidiaContainerRuntime{}, args: []string{"create"}, writeError: fmt.Errorf("error writing"), shouldModify: true, @@ -152,10 +154,8 @@ func TestNvidiaContainerRuntime(t *testing.T) { } for i, tc := range testCases { - tc.shim.logger = logger hook.Reset() - - ociMock := &oci.SpecMock{ + specMock := &oci.SpecMock{ ModifyFunc: func(specModifier oci.SpecModifier) error { return tc.modifyError }, @@ -163,12 +163,11 @@ func TestNvidiaContainerRuntime(t *testing.T) { return tc.writeError }, } - require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc) - tc.shim.ociSpec = ociMock - tc.shim.runtime = &MockShim{} + shim, err := newNvidiaContainerRuntime(logger, mockRuntime, specMock) + require.NoError(t, err) - err := tc.shim.Exec(tc.args) + err = shim.Exec(tc.args) if tc.modifyError != nil || tc.writeError != nil { require.Error(t, err, "%d: %v", i, tc) } else { @@ -176,28 +175,16 @@ func TestNvidiaContainerRuntime(t *testing.T) { } if tc.shouldModify { - require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc) + require.Equal(t, 1, len(specMock.ModifyCalls()), "%d: %v", i, tc) } else { - require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc) + require.Equal(t, 0, len(specMock.ModifyCalls()), "%d: %v", i, tc) } writeExpected := tc.shouldModify && tc.modifyError == nil if writeExpected { - require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc) + require.Equal(t, 1, len(specMock.FlushCalls()), "%d: %v", i, tc) } else { - require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc) + require.Equal(t, 0, len(specMock.FlushCalls()), "%d: %v", i, tc) } } } - -type MockShim struct { - called bool - args []string - returnError error -} - -func (m *MockShim) Exec(args []string) error { - m.called = true - m.args = args - return m.returnError -} diff --git a/cmd/nvidia-container-runtime/runtime_factory.go b/cmd/nvidia-container-runtime/runtime_factory.go index 0af37f6e..3e3cd353 100644 --- a/cmd/nvidia-container-runtime/runtime_factory.go +++ b/cmd/nvidia-container-runtime/runtime_factory.go @@ -23,52 +23,27 @@ import ( ) const ( - ociSpecFileName = "config.json" dockerRuncExecutableName = "docker-runc" runcExecutableName = "runc" ) // newRuntime is a factory method that constructs a runtime based on the selected configuration. func newRuntime(argv []string) (oci.Runtime, error) { - ociSpec, err := newOCISpec(argv) + ociSpec, err := oci.NewSpec(logger.Logger, argv) if err != nil { return nil, fmt.Errorf("error constructing OCI specification: %v", err) } - runc, err := newRuncRuntime() + lowLevelRuntimeCandidates := []string{dockerRuncExecutableName, runcExecutableName} + lowLevelRuntime, err := oci.NewLowLevelRuntime(logger.Logger, lowLevelRuntimeCandidates) if err != nil { - return nil, fmt.Errorf("error constructing runc runtime: %v", err) + return nil, fmt.Errorf("error constructing low-level runtime: %v", err) } - r, err := newNvidiaContainerRuntimeWithLogger(logger.Logger, runc, ociSpec) + r, err := newNvidiaContainerRuntime(logger.Logger, lowLevelRuntime, ociSpec) if err != nil { return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err) } return r, nil } - -// newOCISpec constructs an OCI spec for the provided arguments -func newOCISpec(argv []string) (oci.Spec, error) { - bundleDir, err := oci.GetBundleDir(argv) - if err != nil { - return nil, fmt.Errorf("error parsing command line arguments: %v", err) - } - logger.Infof("Using bundle directory: %v", bundleDir) - - ociSpecPath := oci.GetSpecFilePath(bundleDir) - logger.Infof("Using OCI specification file path: %v", ociSpecPath) - - ociSpec := oci.NewSpecFromFile(ociSpecPath) - - return ociSpec, nil -} - -// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime -func newRuncRuntime() (oci.Runtime, error) { - return oci.NewLowLevelRuntimeWithLogger( - logger.Logger, - dockerRuncExecutableName, - runcExecutableName, - ) -} diff --git a/internal/oci/runtime_low_level.go b/internal/oci/runtime_low_level.go index 6363d2ec..ac5aea59 100644 --- a/internal/oci/runtime_low_level.go +++ b/internal/oci/runtime_low_level.go @@ -25,19 +25,14 @@ import ( // NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable. // The executable specified is taken from the list of supplied candidates, with the first match -// present in the PATH being selected. -func NewLowLevelRuntime(candidates ...string) (Runtime, error) { - return NewLowLevelRuntimeWithLogger(log.StandardLogger(), candidates...) -} - -// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger. -func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) { +// present in the PATH being selected. A logger is also specified. +func NewLowLevelRuntime(logger *log.Logger, candidates []string) (Runtime, error) { runtimePath, err := findRuntime(logger, candidates) if err != nil { return nil, fmt.Errorf("error locating runtime: %v", err) } - return NewRuntimeForPathWithLogger(logger, runtimePath) + return NewRuntimeForPath(logger, runtimePath) } // findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH. diff --git a/internal/oci/runtime_path.go b/internal/oci/runtime_path.go index abf225b2..102e84f6 100644 --- a/internal/oci/runtime_path.go +++ b/internal/oci/runtime_path.go @@ -34,13 +34,8 @@ type pathRuntime struct { var _ Runtime = (*pathRuntime)(nil) -// NewRuntimeForPath creates a Runtime for the specified path with the standard logger -func NewRuntimeForPath(path string) (Runtime, error) { - return NewRuntimeForPathWithLogger(log.StandardLogger(), path) -} - -// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path -func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) { +// NewRuntimeForPath creates a Runtime for the specified logger and path +func NewRuntimeForPath(logger *log.Logger, path string) (Runtime, error) { info, err := os.Stat(path) if err != nil { return nil, fmt.Errorf("invalid path '%v': %v", path, err) diff --git a/internal/oci/runtime_path_test.go b/internal/oci/runtime_path_test.go index 0d936a3f..ac247a1c 100644 --- a/internal/oci/runtime_path_test.go +++ b/internal/oci/runtime_path_test.go @@ -24,19 +24,21 @@ import ( ) func TestPathRuntimeConstructor(t *testing.T) { - r, err := NewRuntimeForPath("////an/invalid/path") + logger, _ := testlog.NewNullLogger() + + r, err := NewRuntimeForPath(logger, "////an/invalid/path") require.Error(t, err) require.Nil(t, r) - r, err = NewRuntimeForPath("/tmp") + r, err = NewRuntimeForPath(logger, "/tmp") require.Error(t, err) require.Nil(t, r) - r, err = NewRuntimeForPath("/dev/null") + r, err = NewRuntimeForPath(logger, "/dev/null") require.Error(t, err) require.Nil(t, r) - r, err = NewRuntimeForPath("/bin/sh") + r, err = NewRuntimeForPath(logger, "/bin/sh") require.NoError(t, err) f, ok := r.(*pathRuntime) diff --git a/internal/oci/spec.go b/internal/oci/spec.go index 259ce054..ba7e7cee 100644 --- a/internal/oci/spec.go +++ b/internal/oci/spec.go @@ -17,15 +17,20 @@ package oci import ( - oci "github.com/opencontainers/runtime-spec/specs-go" + "fmt" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/sirupsen/logrus" ) -// SpecModifier is a function that accepts a pointer to an OCI Srec and returns an -// error. The intention is that the function would modify the spec in-place. -type SpecModifier func(*oci.Spec) error +// SpecModifier defines an interace for modifying a (raw) OCI spec +type SpecModifier interface { + // Modify is a method that accepts a pointer to an OCI Srec and returns an + // error. The intention is that the function would modify the spec in-place. + Modify(*specs.Spec) error +} //go:generate moq -stub -out spec_mock.go . Spec - // Spec defines the operations to be performed on an OCI specification type Spec interface { Load() error @@ -33,3 +38,20 @@ type Spec interface { Modify(SpecModifier) error LookupEnv(string) (string, bool) } + +// NewSpec creates fileSpec based on the command line arguments passed to the +// application using the specified logger. +func NewSpec(logger *logrus.Logger, args []string) (Spec, error) { + bundleDir, err := GetBundleDir(args) + if err != nil { + return nil, fmt.Errorf("error getting bundle directory: %v", err) + } + logger.Infof("Using bundle directory: %v", bundleDir) + + ociSpecPath := GetSpecFilePath(bundleDir) + logger.Infof("Using OCI specification file path: %v", ociSpecPath) + + ociSpec := NewFileSpec(ociSpecPath) + + return ociSpec, nil +} diff --git a/internal/oci/spec_file.go b/internal/oci/spec_file.go index 886e2cb3..ff0cbb46 100644 --- a/internal/oci/spec_file.go +++ b/internal/oci/spec_file.go @@ -21,37 +21,21 @@ import ( "fmt" "io" "os" - "strings" - oci "github.com/opencontainers/runtime-spec/specs-go" + "github.com/opencontainers/runtime-spec/specs-go" ) type fileSpec struct { - *oci.Spec + memorySpec path string } var _ Spec = (*fileSpec)(nil) -// NewSpecFromArgs creates fileSpec based on the command line arguments passed to the -// application -func NewSpecFromArgs(args []string) (Spec, string, error) { - bundleDir, err := GetBundleDir(args) - if err != nil { - return nil, "", fmt.Errorf("error getting bundle directory: %v", err) - } - - ociSpecPath := GetSpecFilePath(bundleDir) - - ociSpec := NewSpecFromFile(ociSpecPath) - - return ociSpec, bundleDir, nil -} - -// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec. +// NewFileSpec creates an object that encapsulates a file-backed OCI spec. // This can be used to read from the file, modify the spec, and write to the // same file. -func NewSpecFromFile(filepath string) Spec { +func NewFileSpec(filepath string) Spec { oci := fileSpec{ path: filepath, } @@ -68,29 +52,31 @@ func (s *fileSpec) Load() error { } defer specFile.Close() - return s.loadFrom(specFile) -} - -// loadFrom reads the contents of the OCI spec from the specified io.Reader. -func (s *fileSpec) loadFrom(reader io.Reader) error { - decoder := json.NewDecoder(reader) - - var spec oci.Spec - err := decoder.Decode(&spec) + spec, err := loadFrom(specFile) if err != nil { - return fmt.Errorf("error reading OCI specification: %v", err) + return fmt.Errorf("error loading OCI specification from file: %v", err) } - - s.Spec = &spec + s.Spec = spec return nil } -// Modify applies the specified SpecModifier to the stored OCI specification. -func (s *fileSpec) Modify(f SpecModifier) error { - if s.Spec == nil { - return fmt.Errorf("no spec loaded for modification") +// loadFrom reads the contents of the OCI spec from the specified io.Reader. +func loadFrom(reader io.Reader) (*specs.Spec, error) { + decoder := json.NewDecoder(reader) + + var spec specs.Spec + + err := decoder.Decode(&spec) + if err != nil { + return nil, fmt.Errorf("error reading OCI specification: %v", err) } - return f(s.Spec) + + return &spec, nil +} + +// Modify applies the specified SpecModifier to the stored OCI specification. +func (s *fileSpec) Modify(m SpecModifier) error { + return s.memorySpec.Modify(m) } // Flush writes the stored OCI specification to the filepath specifed by the path member. @@ -106,48 +92,20 @@ func (s fileSpec) Flush() error { } defer specFile.Close() - return s.flushTo(specFile) + return flushTo(s.Spec, specFile) } // flushTo writes the stored OCI specification to the specified io.Writer. -func (s fileSpec) flushTo(writer io.Writer) error { - if s.Spec == nil { +func flushTo(spec *specs.Spec, writer io.Writer) error { + if spec == nil { return nil } encoder := json.NewEncoder(writer) - err := encoder.Encode(s.Spec) + err := encoder.Encode(spec) if err != nil { return fmt.Errorf("error writing OCI specification: %v", err) } return nil } - -// LookupEnv mirrors os.LookupEnv for the OCI specification. It -// retrieves the value of the environment variable named -// by the key. If the variable is present in the environment the -// value (which may be empty) is returned and the boolean is true. -// Otherwise the returned value will be empty and the boolean will -// be false. -func (s fileSpec) LookupEnv(key string) (string, bool) { - if s.Spec == nil || s.Spec.Process == nil { - return "", false - } - - for _, env := range s.Spec.Process.Env { - if !strings.HasPrefix(env, key) { - continue - } - - parts := strings.SplitN(env, "=", 2) - if parts[0] == key { - if len(parts) < 2 { - return "", true - } - return parts[1], true - } - } - - return "", false -} diff --git a/internal/oci/spec_file_test.go b/internal/oci/spec_file_test.go index fe273824..94dfb3b3 100644 --- a/internal/oci/spec_file_test.go +++ b/internal/oci/spec_file_test.go @@ -25,111 +25,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestLookupEnv(t *testing.T) { - const envName = "TEST_ENV" - testCases := []struct { - spec *specs.Spec - expectedValue string - expectedExits bool - }{ - { - // nil spec - spec: nil, - expectedValue: "", - expectedExits: false, - }, - { - // nil process - spec: &specs.Spec{}, - expectedValue: "", - expectedExits: false, - }, - { - // nil env - spec: &specs.Spec{ - Process: &specs.Process{}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // empty env - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{}}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // different env set - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // same prefix - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // same suffix - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}}, - }, - expectedValue: "", - expectedExits: false, - }, - { - // set blank - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV="}}, - }, - expectedValue: "", - expectedExits: true, - }, - { - // set no-equals - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV"}}, - }, - expectedValue: "", - expectedExits: true, - }, - { - // set value - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV=something"}}, - }, - expectedValue: "something", - expectedExits: true, - }, - { - // set with equals - spec: &specs.Spec{ - Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}}, - }, - expectedValue: "something=somethingelse", - expectedExits: true, - }, - } - - for i, tc := range testCases { - spec := fileSpec{ - Spec: tc.spec, - } - - value, exists := spec.LookupEnv(envName) - - require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc) - require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc) - } -} - func TestLoadFrom(t *testing.T) { testCases := []struct { contents []byte @@ -148,8 +43,8 @@ func TestLoadFrom(t *testing.T) { } for i, tc := range testCases { - spec := fileSpec{} - err := spec.loadFrom(bytes.NewReader(tc.contents)) + var spec *specs.Spec + spec, err := loadFrom(bytes.NewReader(tc.contents)) if tc.isError { require.Error(t, err, "%d: %v", i, tc) @@ -158,9 +53,9 @@ func TestLoadFrom(t *testing.T) { } if tc.spec == nil { - require.Nil(t, spec.Spec, "%d: %v", i, tc) + require.Nil(t, spec, "%d: %v", i, tc) } else { - require.EqualValues(t, tc.spec, spec.Spec, "%d: %v", i, tc) + require.EqualValues(t, tc.spec, spec, "%d: %v", i, tc) } } } @@ -183,8 +78,7 @@ func TestFlushTo(t *testing.T) { for i, tc := range testCases { buffer := bytes.Buffer{} - spec := fileSpec{Spec: tc.spec} - err := spec.flushTo(&buffer) + err := flushTo(tc.spec, &buffer) if tc.isError { require.Error(t, err, "%d: %v", i, tc) @@ -196,53 +90,10 @@ func TestFlushTo(t *testing.T) { } // Add a simple test for a writer that returns an error when writing - spec := fileSpec{Spec: &specs.Spec{}} - err := spec.flushTo(errorWriter{}) + err := flushTo(&specs.Spec{}, errorWriter{}) require.Error(t, err) } -func TestModify(t *testing.T) { - - testCases := []struct { - spec *specs.Spec - modifierError error - }{ - { - spec: nil, - }, - { - spec: &specs.Spec{}, - }, - { - spec: &specs.Spec{}, - modifierError: fmt.Errorf("error in modifier"), - }, - } - - for i, tc := range testCases { - spec := fileSpec{Spec: tc.spec} - - modifier := func(spec *specs.Spec) error { - if tc.modifierError == nil { - spec.Version = "updated" - } - return tc.modifierError - } - - err := spec.Modify(modifier) - - if tc.spec == nil { - require.Error(t, err, "%d: %v", i, tc) - } else if tc.modifierError != nil { - require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc) - require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc) - } else { - require.NoError(t, err, "%d: %v", i, tc) - require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc) - } - } -} - // errorWriter implements the io.Writer interface, always returning an error when // writing. type errorWriter struct{} diff --git a/internal/oci/spec_memory.go b/internal/oci/spec_memory.go new file mode 100644 index 00000000..ce94447e --- /dev/null +++ b/internal/oci/spec_memory.go @@ -0,0 +1,83 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package oci + +import ( + "fmt" + "strings" + + "github.com/opencontainers/runtime-spec/specs-go" +) + +type memorySpec struct { + *specs.Spec +} + +// NewMemorySpec creates a Spec instance from the specified OCI spec +func NewMemorySpec(spec *specs.Spec) Spec { + s := memorySpec{ + Spec: spec, + } + + return &s +} + +// Load is a no-op for the memorySpec spec +func (s *memorySpec) Load() error { + return nil +} + +// Flush is a no-op for the memorySpec spec +func (s *memorySpec) Flush() error { + return nil +} + +// Modify applies the specified SpecModifier to the stored OCI specification. +func (s *memorySpec) Modify(m SpecModifier) error { + if s.Spec == nil { + return fmt.Errorf("cannot modify nil spec") + } + return m.Modify(s.Spec) +} + +// LookupEnv mirrors os.LookupEnv for the OCI specification. It +// retrieves the value of the environment variable named +// by the key. If the variable is present in the environment the +// value (which may be empty) is returned and the boolean is true. +// Otherwise the returned value will be empty and the boolean will +// be false. +func (s memorySpec) LookupEnv(key string) (string, bool) { + if s.Spec == nil || s.Spec.Process == nil { + return "", false + } + + for _, env := range s.Spec.Process.Env { + if !strings.HasPrefix(env, key) { + continue + } + + parts := strings.SplitN(env, "=", 2) + if parts[0] == key { + if len(parts) < 2 { + return "", true + } + return parts[1], true + } + } + + return "", false +} diff --git a/internal/oci/spec_memory_test.go b/internal/oci/spec_memory_test.go new file mode 100644 index 00000000..a43944a4 --- /dev/null +++ b/internal/oci/spec_memory_test.go @@ -0,0 +1,178 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package oci + +import ( + "fmt" + "testing" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/stretchr/testify/require" +) + +func TestLookupEnv(t *testing.T) { + const envName = "TEST_ENV" + testCases := []struct { + spec *specs.Spec + expectedValue string + expectedExits bool + }{ + { + // nil spec + spec: nil, + expectedValue: "", + expectedExits: false, + }, + { + // nil process + spec: &specs.Spec{}, + expectedValue: "", + expectedExits: false, + }, + { + // nil env + spec: &specs.Spec{ + Process: &specs.Process{}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // empty env + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // different env set + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // same prefix + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // same suffix + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}}, + }, + expectedValue: "", + expectedExits: false, + }, + { + // set blank + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV="}}, + }, + expectedValue: "", + expectedExits: true, + }, + { + // set no-equals + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV"}}, + }, + expectedValue: "", + expectedExits: true, + }, + { + // set value + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV=something"}}, + }, + expectedValue: "something", + expectedExits: true, + }, + { + // set with equals + spec: &specs.Spec{ + Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}}, + }, + expectedValue: "something=somethingelse", + expectedExits: true, + }, + } + + for i, tc := range testCases { + spec := memorySpec{ + Spec: tc.spec, + } + + value, exists := spec.LookupEnv(envName) + + require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc) + require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc) + } +} + +func TestModify(t *testing.T) { + + testCases := []struct { + spec *specs.Spec + modifierError error + }{ + { + spec: nil, + }, + { + spec: &specs.Spec{}, + }, + { + spec: &specs.Spec{}, + modifierError: fmt.Errorf("error in modifier"), + }, + } + + for i, tc := range testCases { + spec := NewMemorySpec(tc.spec).(*memorySpec) + + err := spec.Modify(modifier{tc.modifierError}) + + if tc.spec == nil { + require.Error(t, err, "%d: %v", i, tc) + } else if tc.modifierError != nil { + require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc) + require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc) + } else { + require.NoError(t, err, "%d: %v", i, tc) + require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc) + } + } +} + +// TODO: Ideally we would generated a mock for the SpecModifer too. This causes +// an import cycle and we define a local type as a work-around. +type modifier struct { + modifierError error +} + +func (m modifier) Modify(spec *specs.Spec) error { + if m.modifierError == nil { + spec.Version = "updated" + } + return m.modifierError +} diff --git a/internal/oci/spec_test.go b/internal/oci/spec_test.go index 03d2d301..f17f3d3c 100644 --- a/internal/oci/spec_test.go +++ b/internal/oci/spec_test.go @@ -20,7 +20,7 @@ func TestMaintainSpec(t *testing.T) { for _, f := range files { inputSpecPath := filepath.Join(moduleRoot, "test/input", f) - spec := NewSpecFromFile(inputSpecPath).(*fileSpec) + spec := NewFileSpec(inputSpecPath).(*fileSpec) spec.Load() diff --git a/internal/runtime/runtime_modifier.go b/internal/runtime/runtime_modifier.go new file mode 100644 index 00000000..86accba0 --- /dev/null +++ b/internal/runtime/runtime_modifier.go @@ -0,0 +1,81 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package runtime + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + log "github.com/sirupsen/logrus" +) + +type modifyingRuntimeWrapper struct { + logger *log.Logger + runtime oci.Runtime + ociSpec oci.Spec + modifier oci.SpecModifier +} + +var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil) + +// NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification +// before invoking the wrapped runtime. +func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime { + rt := modifyingRuntimeWrapper{ + logger: logger, + runtime: runtime, + ociSpec: spec, + modifier: modifier, + } + return &rt +} + +// Exec checks whether a modification of the OCI specification is required and modifies it accordingly before exec-ing +// into the wrapped runtime. +func (r *modifyingRuntimeWrapper) Exec(args []string) error { + if oci.HasCreateSubcommand(args) { + err := r.modify() + if err != nil { + return fmt.Errorf("could not apply required modification to OCI specification: %v", err) + } + r.logger.Infof("Applied required modification to OCI specification") + } else { + r.logger.Infof("No modification of OCI specification required") + } + + r.logger.Infof("Forwarding command to runtime") + return r.runtime.Exec(args) +} + +// modify loads, modifies, and flushes the OCI specification using the defined Modifier +func (r *modifyingRuntimeWrapper) modify() error { + err := r.ociSpec.Load() + if err != nil { + return fmt.Errorf("error loading OCI specification for modification: %v", err) + } + + err = r.ociSpec.Modify(r.modifier) + if err != nil { + return fmt.Errorf("error modifying OCI spec: %v", err) + } + + err = r.ociSpec.Flush() + if err != nil { + return fmt.Errorf("error writing modified OCI specification: %v", err) + } + return nil +} diff --git a/internal/runtime/runtime_modifier_test.go b/internal/runtime/runtime_modifier_test.go new file mode 100644 index 00000000..5ab77a04 --- /dev/null +++ b/internal/runtime/runtime_modifier_test.go @@ -0,0 +1,116 @@ +/* +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package runtime + +import ( + "fmt" + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestExec(t *testing.T) { + logger, hook := testlog.NewNullLogger() + + testCases := []struct { + description string + shouldModify bool + shouldFlush bool + shouldForward bool + args []string + modifyError error + writeError error + }{ + { + description: "no args forwards", + shouldModify: false, + shouldFlush: false, + shouldForward: true, + }, + { + description: "create modifies", + args: []string{"create"}, + shouldModify: true, + shouldFlush: true, + shouldForward: true, + }, + { + description: "modify error does not write or forward", + args: []string{"create"}, + modifyError: fmt.Errorf("error modifying"), + shouldModify: true, + shouldFlush: false, + shouldForward: false, + }, + { + description: "write error does not forward", + args: []string{"create"}, + writeError: fmt.Errorf("error writing"), + shouldModify: true, + shouldFlush: true, + shouldForward: false, + }, + } + + for _, tc := range testCases { + hook.Reset() + + t.Run(tc.description, func(t *testing.T) { + runtimeMock := &oci.RuntimeMock{} + specMock := &oci.SpecMock{ + ModifyFunc: func(specModifier oci.SpecModifier) error { + return tc.modifyError + }, + FlushFunc: func() error { + return tc.writeError + }, + } + + shim := NewModifyingRuntimeWrapper( + logger, + runtimeMock, + specMock, + // TODO: We should test the interactions with the SpecModifier too + nil) + + err := shim.Exec(tc.args) + if tc.modifyError != nil || tc.writeError != nil { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + if tc.shouldModify { + require.Equal(t, 1, len(specMock.ModifyCalls())) + } else { + require.Equal(t, 0, len(specMock.ModifyCalls())) + } + if tc.shouldFlush { + require.Equal(t, 1, len(specMock.FlushCalls())) + } else { + require.Equal(t, 0, len(specMock.FlushCalls())) + } + if tc.shouldForward { + require.Equal(t, 1, len(runtimeMock.ExecCalls())) + } else { + require.Equal(t, 0, len(runtimeMock.ExecCalls())) + } + }) + } +}