Import modifying runtime abstraction from experimental runtime

This change imports the modifying runtime abstraction from the
experimental branch. This encapsulates the checks for whether
modification is required, and forwards the loaded spec to
the specified modifier. This allows for the same code to be
reused when performing more complex modifications.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2022-01-31 15:11:36 +01:00
parent bf8c3bab72
commit 4177fddcc4
15 changed files with 584 additions and 388 deletions

View File

@ -167,11 +167,11 @@ func TestDuplicateHook(t *testing.T) {
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
}
// addNVIDIAHook is a basic wrapper for nvidiaContainerRunime.addNVIDIAHook that is used for
// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for
// testing.
func addNVIDIAHook(spec *specs.Spec) error {
r := nvidiaContainerRuntime{logger: logger.Logger}
return r.addNVIDIAHook(spec)
m := addHookModifier{logger: logger.Logger}
return m.Modify(spec)
}
func (c testConfig) getRuntimeSpec() (specs.Spec, error) {

View File

@ -17,88 +17,41 @@
package main
import (
"fmt"
"os"
"os/exec"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/runtime"
"github.com/opencontainers/runtime-spec/specs-go"
log "github.com/sirupsen/logrus"
)
// nvidiaContainerRuntime encapsulates the NVIDIA Container Runtime. It wraps the specified runtime, conditionally
// modifying the specified OCI specification before invoking the runtime.
type nvidiaContainerRuntime struct {
logger *log.Logger
runtime oci.Runtime
ociSpec oci.Spec
// newNvidiaContainerRuntime is a constructor for a standard runtime shim. This uses
// a ModifyingRuntimeWrapper to apply the required modifications before execing to the
// specified low-level runtime
func newNvidiaContainerRuntime(logger *log.Logger, lowlevelRuntime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
modifier := addHookModifier{logger: logger}
r := runtime.NewModifyingRuntimeWrapper(
logger,
lowlevelRuntime,
ociSpec,
modifier,
)
return r, nil
}
var _ oci.Runtime = (*nvidiaContainerRuntime)(nil)
// newNvidiaContainerRuntime is a constructor for a standard runtime shim.
func newNvidiaContainerRuntimeWithLogger(logger *log.Logger, runtime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
r := nvidiaContainerRuntime{
logger: logger,
runtime: runtime,
ociSpec: ociSpec,
}
return &r, nil
// addHookModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a
// prestart hook. If the hook is already present, no modification is made.
type addHookModifier struct {
logger *log.Logger
}
// Exec defines the entrypoint for the NVIDIA Container Runtime. A check is performed to see whether modifications
// to the OCI spec are required -- and applicable modifcations applied. The supplied arguments are then
// forwarded to the underlying runtime's Exec method.
func (r nvidiaContainerRuntime) Exec(args []string) error {
if r.modificationRequired(args) {
err := r.modifyOCISpec()
if err != nil {
return fmt.Errorf("error modifying OCI spec: %v", err)
}
}
r.logger.Println("Forwarding command to runtime")
return r.runtime.Exec(args)
}
// modificationRequired checks the intput arguments to determine whether a modification
// to the OCI spec is required.
func (r nvidiaContainerRuntime) modificationRequired(args []string) bool {
if oci.HasCreateSubcommand(args) {
r.logger.Infof("'create' command detected; modification required")
return true
}
r.logger.Infof("No modification required")
return false
}
// modifyOCISpec loads and modifies the OCI spec specified in the nvidiaContainerRuntime
// struct. The spec is modified in-place and written to the same file as the input after
// modifcationas are applied.
func (r nvidiaContainerRuntime) modifyOCISpec() error {
err := r.ociSpec.Load()
if err != nil {
return fmt.Errorf("error loading OCI specification for modification: %v", err)
}
err = r.ociSpec.Modify(r.addNVIDIAHook)
if err != nil {
return fmt.Errorf("error injecting NVIDIA Container Runtime hook: %v", err)
}
err = r.ociSpec.Flush()
if err != nil {
return fmt.Errorf("error writing modified OCI specification: %v", err)
}
return nil
}
// addNVIDIAHook modifies the specified OCI specification in-place, inserting a
// prestart hook.
func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook
// as a prestart hook.
func (m addHookModifier) Modify(spec *specs.Spec) error {
path, err := exec.LookPath("nvidia-container-runtime-hook")
if err != nil {
path = hookDefaultFilePath
@ -108,7 +61,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
}
}
r.logger.Printf("prestart hook path: %s\n", path)
m.logger.Printf("prestart hook path: %s\n", path)
args := []string{path}
if spec.Hooks == nil {
@ -118,7 +71,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") {
continue
}
r.logger.Println("existing nvidia prestart hook in OCI spec file")
m.logger.Println("existing nvidia prestart hook in OCI spec file")
return nil
}
}

View File

@ -29,9 +29,8 @@ import (
func TestAddNvidiaHook(t *testing.T) {
logger, logHook := testlog.NewNullLogger()
shim := nvidiaContainerRuntime{
logger: logger,
}
mockRuntime := &oci.RuntimeMock{}
testCases := []struct {
spec *specs.Spec
@ -75,7 +74,16 @@ func TestAddNvidiaHook(t *testing.T) {
numPrestartHooks = len(tc.spec.Hooks.Prestart)
}
err := shim.addNVIDIAHook(tc.spec)
shim, err := newNvidiaContainerRuntime(
logger,
mockRuntime,
oci.NewMemorySpec(tc.spec),
)
require.NoError(t, err)
err = shim.Exec([]string{"runtime", "create"})
require.NoError(t, err)
if tc.errorPrefix == "" {
require.NoErrorf(t, err, "%d: %v", i, tc)
@ -106,45 +114,39 @@ func TestAddNvidiaHook(t *testing.T) {
func TestNvidiaContainerRuntime(t *testing.T) {
logger, hook := testlog.NewNullLogger()
mockRuntime := &oci.RuntimeMock{}
testCases := []struct {
shim nvidiaContainerRuntime
shouldModify bool
args []string
modifyError error
writeError error
}{
{
shim: nvidiaContainerRuntime{},
shouldModify: false,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"create"},
shouldModify: true,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"--bundle=create"},
shouldModify: false,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"--bundle", "create"},
shouldModify: false,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"create"},
shouldModify: true,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"create"},
modifyError: fmt.Errorf("error modifying"),
shouldModify: true,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"create"},
writeError: fmt.Errorf("error writing"),
shouldModify: true,
@ -152,10 +154,8 @@ func TestNvidiaContainerRuntime(t *testing.T) {
}
for i, tc := range testCases {
tc.shim.logger = logger
hook.Reset()
ociMock := &oci.SpecMock{
specMock := &oci.SpecMock{
ModifyFunc: func(specModifier oci.SpecModifier) error {
return tc.modifyError
},
@ -163,12 +163,11 @@ func TestNvidiaContainerRuntime(t *testing.T) {
return tc.writeError
},
}
require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc)
tc.shim.ociSpec = ociMock
tc.shim.runtime = &MockShim{}
shim, err := newNvidiaContainerRuntime(logger, mockRuntime, specMock)
require.NoError(t, err)
err := tc.shim.Exec(tc.args)
err = shim.Exec(tc.args)
if tc.modifyError != nil || tc.writeError != nil {
require.Error(t, err, "%d: %v", i, tc)
} else {
@ -176,28 +175,16 @@ func TestNvidiaContainerRuntime(t *testing.T) {
}
if tc.shouldModify {
require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
require.Equal(t, 1, len(specMock.ModifyCalls()), "%d: %v", i, tc)
} else {
require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
require.Equal(t, 0, len(specMock.ModifyCalls()), "%d: %v", i, tc)
}
writeExpected := tc.shouldModify && tc.modifyError == nil
if writeExpected {
require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc)
require.Equal(t, 1, len(specMock.FlushCalls()), "%d: %v", i, tc)
} else {
require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc)
require.Equal(t, 0, len(specMock.FlushCalls()), "%d: %v", i, tc)
}
}
}
type MockShim struct {
called bool
args []string
returnError error
}
func (m *MockShim) Exec(args []string) error {
m.called = true
m.args = args
return m.returnError
}

View File

@ -23,52 +23,27 @@ import (
)
const (
ociSpecFileName = "config.json"
dockerRuncExecutableName = "docker-runc"
runcExecutableName = "runc"
)
// newRuntime is a factory method that constructs a runtime based on the selected configuration.
func newRuntime(argv []string) (oci.Runtime, error) {
ociSpec, err := newOCISpec(argv)
ociSpec, err := oci.NewSpec(logger.Logger, argv)
if err != nil {
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
}
runc, err := newRuncRuntime()
lowLevelRuntimeCandidates := []string{dockerRuncExecutableName, runcExecutableName}
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger.Logger, lowLevelRuntimeCandidates)
if err != nil {
return nil, fmt.Errorf("error constructing runc runtime: %v", err)
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
}
r, err := newNvidiaContainerRuntimeWithLogger(logger.Logger, runc, ociSpec)
r, err := newNvidiaContainerRuntime(logger.Logger, lowLevelRuntime, ociSpec)
if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err)
}
return r, nil
}
// newOCISpec constructs an OCI spec for the provided arguments
func newOCISpec(argv []string) (oci.Spec, error) {
bundleDir, err := oci.GetBundleDir(argv)
if err != nil {
return nil, fmt.Errorf("error parsing command line arguments: %v", err)
}
logger.Infof("Using bundle directory: %v", bundleDir)
ociSpecPath := oci.GetSpecFilePath(bundleDir)
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
ociSpec := oci.NewSpecFromFile(ociSpecPath)
return ociSpec, nil
}
// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime
func newRuncRuntime() (oci.Runtime, error) {
return oci.NewLowLevelRuntimeWithLogger(
logger.Logger,
dockerRuncExecutableName,
runcExecutableName,
)
}

View File

@ -25,19 +25,14 @@ import (
// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable.
// The executable specified is taken from the list of supplied candidates, with the first match
// present in the PATH being selected.
func NewLowLevelRuntime(candidates ...string) (Runtime, error) {
return NewLowLevelRuntimeWithLogger(log.StandardLogger(), candidates...)
}
// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger.
func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) {
// present in the PATH being selected. A logger is also specified.
func NewLowLevelRuntime(logger *log.Logger, candidates []string) (Runtime, error) {
runtimePath, err := findRuntime(logger, candidates)
if err != nil {
return nil, fmt.Errorf("error locating runtime: %v", err)
}
return NewRuntimeForPathWithLogger(logger, runtimePath)
return NewRuntimeForPath(logger, runtimePath)
}
// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH.

View File

@ -34,13 +34,8 @@ type pathRuntime struct {
var _ Runtime = (*pathRuntime)(nil)
// NewRuntimeForPath creates a Runtime for the specified path with the standard logger
func NewRuntimeForPath(path string) (Runtime, error) {
return NewRuntimeForPathWithLogger(log.StandardLogger(), path)
}
// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path
func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) {
// NewRuntimeForPath creates a Runtime for the specified logger and path
func NewRuntimeForPath(logger *log.Logger, path string) (Runtime, error) {
info, err := os.Stat(path)
if err != nil {
return nil, fmt.Errorf("invalid path '%v': %v", path, err)

View File

@ -24,19 +24,21 @@ import (
)
func TestPathRuntimeConstructor(t *testing.T) {
r, err := NewRuntimeForPath("////an/invalid/path")
logger, _ := testlog.NewNullLogger()
r, err := NewRuntimeForPath(logger, "////an/invalid/path")
require.Error(t, err)
require.Nil(t, r)
r, err = NewRuntimeForPath("/tmp")
r, err = NewRuntimeForPath(logger, "/tmp")
require.Error(t, err)
require.Nil(t, r)
r, err = NewRuntimeForPath("/dev/null")
r, err = NewRuntimeForPath(logger, "/dev/null")
require.Error(t, err)
require.Nil(t, r)
r, err = NewRuntimeForPath("/bin/sh")
r, err = NewRuntimeForPath(logger, "/bin/sh")
require.NoError(t, err)
f, ok := r.(*pathRuntime)

View File

@ -17,15 +17,20 @@
package oci
import (
oci "github.com/opencontainers/runtime-spec/specs-go"
"fmt"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
)
// SpecModifier is a function that accepts a pointer to an OCI Srec and returns an
// error. The intention is that the function would modify the spec in-place.
type SpecModifier func(*oci.Spec) error
// SpecModifier defines an interace for modifying a (raw) OCI spec
type SpecModifier interface {
// Modify is a method that accepts a pointer to an OCI Srec and returns an
// error. The intention is that the function would modify the spec in-place.
Modify(*specs.Spec) error
}
//go:generate moq -stub -out spec_mock.go . Spec
// Spec defines the operations to be performed on an OCI specification
type Spec interface {
Load() error
@ -33,3 +38,20 @@ type Spec interface {
Modify(SpecModifier) error
LookupEnv(string) (string, bool)
}
// NewSpec creates fileSpec based on the command line arguments passed to the
// application using the specified logger.
func NewSpec(logger *logrus.Logger, args []string) (Spec, error) {
bundleDir, err := GetBundleDir(args)
if err != nil {
return nil, fmt.Errorf("error getting bundle directory: %v", err)
}
logger.Infof("Using bundle directory: %v", bundleDir)
ociSpecPath := GetSpecFilePath(bundleDir)
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
ociSpec := NewFileSpec(ociSpecPath)
return ociSpec, nil
}

View File

@ -21,37 +21,21 @@ import (
"fmt"
"io"
"os"
"strings"
oci "github.com/opencontainers/runtime-spec/specs-go"
"github.com/opencontainers/runtime-spec/specs-go"
)
type fileSpec struct {
*oci.Spec
memorySpec
path string
}
var _ Spec = (*fileSpec)(nil)
// NewSpecFromArgs creates fileSpec based on the command line arguments passed to the
// application
func NewSpecFromArgs(args []string) (Spec, string, error) {
bundleDir, err := GetBundleDir(args)
if err != nil {
return nil, "", fmt.Errorf("error getting bundle directory: %v", err)
}
ociSpecPath := GetSpecFilePath(bundleDir)
ociSpec := NewSpecFromFile(ociSpecPath)
return ociSpec, bundleDir, nil
}
// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec.
// NewFileSpec creates an object that encapsulates a file-backed OCI spec.
// This can be used to read from the file, modify the spec, and write to the
// same file.
func NewSpecFromFile(filepath string) Spec {
func NewFileSpec(filepath string) Spec {
oci := fileSpec{
path: filepath,
}
@ -68,29 +52,31 @@ func (s *fileSpec) Load() error {
}
defer specFile.Close()
return s.loadFrom(specFile)
}
// loadFrom reads the contents of the OCI spec from the specified io.Reader.
func (s *fileSpec) loadFrom(reader io.Reader) error {
decoder := json.NewDecoder(reader)
var spec oci.Spec
err := decoder.Decode(&spec)
spec, err := loadFrom(specFile)
if err != nil {
return fmt.Errorf("error reading OCI specification: %v", err)
return fmt.Errorf("error loading OCI specification from file: %v", err)
}
s.Spec = &spec
s.Spec = spec
return nil
}
// Modify applies the specified SpecModifier to the stored OCI specification.
func (s *fileSpec) Modify(f SpecModifier) error {
if s.Spec == nil {
return fmt.Errorf("no spec loaded for modification")
// loadFrom reads the contents of the OCI spec from the specified io.Reader.
func loadFrom(reader io.Reader) (*specs.Spec, error) {
decoder := json.NewDecoder(reader)
var spec specs.Spec
err := decoder.Decode(&spec)
if err != nil {
return nil, fmt.Errorf("error reading OCI specification: %v", err)
}
return f(s.Spec)
return &spec, nil
}
// Modify applies the specified SpecModifier to the stored OCI specification.
func (s *fileSpec) Modify(m SpecModifier) error {
return s.memorySpec.Modify(m)
}
// Flush writes the stored OCI specification to the filepath specifed by the path member.
@ -106,48 +92,20 @@ func (s fileSpec) Flush() error {
}
defer specFile.Close()
return s.flushTo(specFile)
return flushTo(s.Spec, specFile)
}
// flushTo writes the stored OCI specification to the specified io.Writer.
func (s fileSpec) flushTo(writer io.Writer) error {
if s.Spec == nil {
func flushTo(spec *specs.Spec, writer io.Writer) error {
if spec == nil {
return nil
}
encoder := json.NewEncoder(writer)
err := encoder.Encode(s.Spec)
err := encoder.Encode(spec)
if err != nil {
return fmt.Errorf("error writing OCI specification: %v", err)
}
return nil
}
// LookupEnv mirrors os.LookupEnv for the OCI specification. It
// retrieves the value of the environment variable named
// by the key. If the variable is present in the environment the
// value (which may be empty) is returned and the boolean is true.
// Otherwise the returned value will be empty and the boolean will
// be false.
func (s fileSpec) LookupEnv(key string) (string, bool) {
if s.Spec == nil || s.Spec.Process == nil {
return "", false
}
for _, env := range s.Spec.Process.Env {
if !strings.HasPrefix(env, key) {
continue
}
parts := strings.SplitN(env, "=", 2)
if parts[0] == key {
if len(parts) < 2 {
return "", true
}
return parts[1], true
}
}
return "", false
}

View File

@ -25,111 +25,6 @@ import (
"github.com/stretchr/testify/require"
)
func TestLookupEnv(t *testing.T) {
const envName = "TEST_ENV"
testCases := []struct {
spec *specs.Spec
expectedValue string
expectedExits bool
}{
{
// nil spec
spec: nil,
expectedValue: "",
expectedExits: false,
},
{
// nil process
spec: &specs.Spec{},
expectedValue: "",
expectedExits: false,
},
{
// nil env
spec: &specs.Spec{
Process: &specs.Process{},
},
expectedValue: "",
expectedExits: false,
},
{
// empty env
spec: &specs.Spec{
Process: &specs.Process{Env: []string{}},
},
expectedValue: "",
expectedExits: false,
},
{
// different env set
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}},
},
expectedValue: "",
expectedExits: false,
},
{
// same prefix
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}},
},
expectedValue: "",
expectedExits: false,
},
{
// same suffix
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}},
},
expectedValue: "",
expectedExits: false,
},
{
// set blank
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV="}},
},
expectedValue: "",
expectedExits: true,
},
{
// set no-equals
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV"}},
},
expectedValue: "",
expectedExits: true,
},
{
// set value
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV=something"}},
},
expectedValue: "something",
expectedExits: true,
},
{
// set with equals
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}},
},
expectedValue: "something=somethingelse",
expectedExits: true,
},
}
for i, tc := range testCases {
spec := fileSpec{
Spec: tc.spec,
}
value, exists := spec.LookupEnv(envName)
require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc)
require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc)
}
}
func TestLoadFrom(t *testing.T) {
testCases := []struct {
contents []byte
@ -148,8 +43,8 @@ func TestLoadFrom(t *testing.T) {
}
for i, tc := range testCases {
spec := fileSpec{}
err := spec.loadFrom(bytes.NewReader(tc.contents))
var spec *specs.Spec
spec, err := loadFrom(bytes.NewReader(tc.contents))
if tc.isError {
require.Error(t, err, "%d: %v", i, tc)
@ -158,9 +53,9 @@ func TestLoadFrom(t *testing.T) {
}
if tc.spec == nil {
require.Nil(t, spec.Spec, "%d: %v", i, tc)
require.Nil(t, spec, "%d: %v", i, tc)
} else {
require.EqualValues(t, tc.spec, spec.Spec, "%d: %v", i, tc)
require.EqualValues(t, tc.spec, spec, "%d: %v", i, tc)
}
}
}
@ -183,8 +78,7 @@ func TestFlushTo(t *testing.T) {
for i, tc := range testCases {
buffer := bytes.Buffer{}
spec := fileSpec{Spec: tc.spec}
err := spec.flushTo(&buffer)
err := flushTo(tc.spec, &buffer)
if tc.isError {
require.Error(t, err, "%d: %v", i, tc)
@ -196,53 +90,10 @@ func TestFlushTo(t *testing.T) {
}
// Add a simple test for a writer that returns an error when writing
spec := fileSpec{Spec: &specs.Spec{}}
err := spec.flushTo(errorWriter{})
err := flushTo(&specs.Spec{}, errorWriter{})
require.Error(t, err)
}
func TestModify(t *testing.T) {
testCases := []struct {
spec *specs.Spec
modifierError error
}{
{
spec: nil,
},
{
spec: &specs.Spec{},
},
{
spec: &specs.Spec{},
modifierError: fmt.Errorf("error in modifier"),
},
}
for i, tc := range testCases {
spec := fileSpec{Spec: tc.spec}
modifier := func(spec *specs.Spec) error {
if tc.modifierError == nil {
spec.Version = "updated"
}
return tc.modifierError
}
err := spec.Modify(modifier)
if tc.spec == nil {
require.Error(t, err, "%d: %v", i, tc)
} else if tc.modifierError != nil {
require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc)
} else {
require.NoError(t, err, "%d: %v", i, tc)
require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc)
}
}
}
// errorWriter implements the io.Writer interface, always returning an error when
// writing.
type errorWriter struct{}

View File

@ -0,0 +1,83 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package oci
import (
"fmt"
"strings"
"github.com/opencontainers/runtime-spec/specs-go"
)
type memorySpec struct {
*specs.Spec
}
// NewMemorySpec creates a Spec instance from the specified OCI spec
func NewMemorySpec(spec *specs.Spec) Spec {
s := memorySpec{
Spec: spec,
}
return &s
}
// Load is a no-op for the memorySpec spec
func (s *memorySpec) Load() error {
return nil
}
// Flush is a no-op for the memorySpec spec
func (s *memorySpec) Flush() error {
return nil
}
// Modify applies the specified SpecModifier to the stored OCI specification.
func (s *memorySpec) Modify(m SpecModifier) error {
if s.Spec == nil {
return fmt.Errorf("cannot modify nil spec")
}
return m.Modify(s.Spec)
}
// LookupEnv mirrors os.LookupEnv for the OCI specification. It
// retrieves the value of the environment variable named
// by the key. If the variable is present in the environment the
// value (which may be empty) is returned and the boolean is true.
// Otherwise the returned value will be empty and the boolean will
// be false.
func (s memorySpec) LookupEnv(key string) (string, bool) {
if s.Spec == nil || s.Spec.Process == nil {
return "", false
}
for _, env := range s.Spec.Process.Env {
if !strings.HasPrefix(env, key) {
continue
}
parts := strings.SplitN(env, "=", 2)
if parts[0] == key {
if len(parts) < 2 {
return "", true
}
return parts[1], true
}
}
return "", false
}

View File

@ -0,0 +1,178 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package oci
import (
"fmt"
"testing"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require"
)
func TestLookupEnv(t *testing.T) {
const envName = "TEST_ENV"
testCases := []struct {
spec *specs.Spec
expectedValue string
expectedExits bool
}{
{
// nil spec
spec: nil,
expectedValue: "",
expectedExits: false,
},
{
// nil process
spec: &specs.Spec{},
expectedValue: "",
expectedExits: false,
},
{
// nil env
spec: &specs.Spec{
Process: &specs.Process{},
},
expectedValue: "",
expectedExits: false,
},
{
// empty env
spec: &specs.Spec{
Process: &specs.Process{Env: []string{}},
},
expectedValue: "",
expectedExits: false,
},
{
// different env set
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}},
},
expectedValue: "",
expectedExits: false,
},
{
// same prefix
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}},
},
expectedValue: "",
expectedExits: false,
},
{
// same suffix
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}},
},
expectedValue: "",
expectedExits: false,
},
{
// set blank
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV="}},
},
expectedValue: "",
expectedExits: true,
},
{
// set no-equals
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV"}},
},
expectedValue: "",
expectedExits: true,
},
{
// set value
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV=something"}},
},
expectedValue: "something",
expectedExits: true,
},
{
// set with equals
spec: &specs.Spec{
Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}},
},
expectedValue: "something=somethingelse",
expectedExits: true,
},
}
for i, tc := range testCases {
spec := memorySpec{
Spec: tc.spec,
}
value, exists := spec.LookupEnv(envName)
require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc)
require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc)
}
}
func TestModify(t *testing.T) {
testCases := []struct {
spec *specs.Spec
modifierError error
}{
{
spec: nil,
},
{
spec: &specs.Spec{},
},
{
spec: &specs.Spec{},
modifierError: fmt.Errorf("error in modifier"),
},
}
for i, tc := range testCases {
spec := NewMemorySpec(tc.spec).(*memorySpec)
err := spec.Modify(modifier{tc.modifierError})
if tc.spec == nil {
require.Error(t, err, "%d: %v", i, tc)
} else if tc.modifierError != nil {
require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc)
} else {
require.NoError(t, err, "%d: %v", i, tc)
require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc)
}
}
}
// TODO: Ideally we would generated a mock for the SpecModifer too. This causes
// an import cycle and we define a local type as a work-around.
type modifier struct {
modifierError error
}
func (m modifier) Modify(spec *specs.Spec) error {
if m.modifierError == nil {
spec.Version = "updated"
}
return m.modifierError
}

View File

@ -20,7 +20,7 @@ func TestMaintainSpec(t *testing.T) {
for _, f := range files {
inputSpecPath := filepath.Join(moduleRoot, "test/input", f)
spec := NewSpecFromFile(inputSpecPath).(*fileSpec)
spec := NewFileSpec(inputSpecPath).(*fileSpec)
spec.Load()

View File

@ -0,0 +1,81 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
log "github.com/sirupsen/logrus"
)
type modifyingRuntimeWrapper struct {
logger *log.Logger
runtime oci.Runtime
ociSpec oci.Spec
modifier oci.SpecModifier
}
var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil)
// NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
// before invoking the wrapped runtime.
func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime {
rt := modifyingRuntimeWrapper{
logger: logger,
runtime: runtime,
ociSpec: spec,
modifier: modifier,
}
return &rt
}
// Exec checks whether a modification of the OCI specification is required and modifies it accordingly before exec-ing
// into the wrapped runtime.
func (r *modifyingRuntimeWrapper) Exec(args []string) error {
if oci.HasCreateSubcommand(args) {
err := r.modify()
if err != nil {
return fmt.Errorf("could not apply required modification to OCI specification: %v", err)
}
r.logger.Infof("Applied required modification to OCI specification")
} else {
r.logger.Infof("No modification of OCI specification required")
}
r.logger.Infof("Forwarding command to runtime")
return r.runtime.Exec(args)
}
// modify loads, modifies, and flushes the OCI specification using the defined Modifier
func (r *modifyingRuntimeWrapper) modify() error {
err := r.ociSpec.Load()
if err != nil {
return fmt.Errorf("error loading OCI specification for modification: %v", err)
}
err = r.ociSpec.Modify(r.modifier)
if err != nil {
return fmt.Errorf("error modifying OCI spec: %v", err)
}
err = r.ociSpec.Flush()
if err != nil {
return fmt.Errorf("error writing modified OCI specification: %v", err)
}
return nil
}

View File

@ -0,0 +1,116 @@
/*
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"fmt"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestExec(t *testing.T) {
logger, hook := testlog.NewNullLogger()
testCases := []struct {
description string
shouldModify bool
shouldFlush bool
shouldForward bool
args []string
modifyError error
writeError error
}{
{
description: "no args forwards",
shouldModify: false,
shouldFlush: false,
shouldForward: true,
},
{
description: "create modifies",
args: []string{"create"},
shouldModify: true,
shouldFlush: true,
shouldForward: true,
},
{
description: "modify error does not write or forward",
args: []string{"create"},
modifyError: fmt.Errorf("error modifying"),
shouldModify: true,
shouldFlush: false,
shouldForward: false,
},
{
description: "write error does not forward",
args: []string{"create"},
writeError: fmt.Errorf("error writing"),
shouldModify: true,
shouldFlush: true,
shouldForward: false,
},
}
for _, tc := range testCases {
hook.Reset()
t.Run(tc.description, func(t *testing.T) {
runtimeMock := &oci.RuntimeMock{}
specMock := &oci.SpecMock{
ModifyFunc: func(specModifier oci.SpecModifier) error {
return tc.modifyError
},
FlushFunc: func() error {
return tc.writeError
},
}
shim := NewModifyingRuntimeWrapper(
logger,
runtimeMock,
specMock,
// TODO: We should test the interactions with the SpecModifier too
nil)
err := shim.Exec(tc.args)
if tc.modifyError != nil || tc.writeError != nil {
require.Error(t, err)
} else {
require.NoError(t, err)
}
if tc.shouldModify {
require.Equal(t, 1, len(specMock.ModifyCalls()))
} else {
require.Equal(t, 0, len(specMock.ModifyCalls()))
}
if tc.shouldFlush {
require.Equal(t, 1, len(specMock.FlushCalls()))
} else {
require.Equal(t, 0, len(specMock.FlushCalls()))
}
if tc.shouldForward {
require.Equal(t, 1, len(runtimeMock.ExecCalls()))
} else {
require.Equal(t, 0, len(runtimeMock.ExecCalls()))
}
})
}
}