mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-22 00:08:11 +00:00
Import modifying runtime abstraction from experimental runtime
This change imports the modifying runtime abstraction from the experimental branch. This encapsulates the checks for whether modification is required, and forwards the loaded spec to the specified modifier. This allows for the same code to be reused when performing more complex modifications. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
bf8c3bab72
commit
4177fddcc4
@ -167,11 +167,11 @@ func TestDuplicateHook(t *testing.T) {
|
||||
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
|
||||
}
|
||||
|
||||
// addNVIDIAHook is a basic wrapper for nvidiaContainerRunime.addNVIDIAHook that is used for
|
||||
// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for
|
||||
// testing.
|
||||
func addNVIDIAHook(spec *specs.Spec) error {
|
||||
r := nvidiaContainerRuntime{logger: logger.Logger}
|
||||
return r.addNVIDIAHook(spec)
|
||||
m := addHookModifier{logger: logger.Logger}
|
||||
return m.Modify(spec)
|
||||
}
|
||||
|
||||
func (c testConfig) getRuntimeSpec() (specs.Spec, error) {
|
||||
|
@ -17,88 +17,41 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/runtime"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// nvidiaContainerRuntime encapsulates the NVIDIA Container Runtime. It wraps the specified runtime, conditionally
|
||||
// modifying the specified OCI specification before invoking the runtime.
|
||||
type nvidiaContainerRuntime struct {
|
||||
logger *log.Logger
|
||||
runtime oci.Runtime
|
||||
ociSpec oci.Spec
|
||||
// newNvidiaContainerRuntime is a constructor for a standard runtime shim. This uses
|
||||
// a ModifyingRuntimeWrapper to apply the required modifications before execing to the
|
||||
// specified low-level runtime
|
||||
func newNvidiaContainerRuntime(logger *log.Logger, lowlevelRuntime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
|
||||
modifier := addHookModifier{logger: logger}
|
||||
|
||||
r := runtime.NewModifyingRuntimeWrapper(
|
||||
logger,
|
||||
lowlevelRuntime,
|
||||
ociSpec,
|
||||
modifier,
|
||||
)
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
var _ oci.Runtime = (*nvidiaContainerRuntime)(nil)
|
||||
|
||||
// newNvidiaContainerRuntime is a constructor for a standard runtime shim.
|
||||
func newNvidiaContainerRuntimeWithLogger(logger *log.Logger, runtime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
|
||||
r := nvidiaContainerRuntime{
|
||||
logger: logger,
|
||||
runtime: runtime,
|
||||
ociSpec: ociSpec,
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
// addHookModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a
|
||||
// prestart hook. If the hook is already present, no modification is made.
|
||||
type addHookModifier struct {
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Exec defines the entrypoint for the NVIDIA Container Runtime. A check is performed to see whether modifications
|
||||
// to the OCI spec are required -- and applicable modifcations applied. The supplied arguments are then
|
||||
// forwarded to the underlying runtime's Exec method.
|
||||
func (r nvidiaContainerRuntime) Exec(args []string) error {
|
||||
if r.modificationRequired(args) {
|
||||
err := r.modifyOCISpec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error modifying OCI spec: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Println("Forwarding command to runtime")
|
||||
return r.runtime.Exec(args)
|
||||
}
|
||||
|
||||
// modificationRequired checks the intput arguments to determine whether a modification
|
||||
// to the OCI spec is required.
|
||||
func (r nvidiaContainerRuntime) modificationRequired(args []string) bool {
|
||||
if oci.HasCreateSubcommand(args) {
|
||||
r.logger.Infof("'create' command detected; modification required")
|
||||
return true
|
||||
}
|
||||
|
||||
r.logger.Infof("No modification required")
|
||||
return false
|
||||
}
|
||||
|
||||
// modifyOCISpec loads and modifies the OCI spec specified in the nvidiaContainerRuntime
|
||||
// struct. The spec is modified in-place and written to the same file as the input after
|
||||
// modifcationas are applied.
|
||||
func (r nvidiaContainerRuntime) modifyOCISpec() error {
|
||||
err := r.ociSpec.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading OCI specification for modification: %v", err)
|
||||
}
|
||||
|
||||
err = r.ociSpec.Modify(r.addNVIDIAHook)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error injecting NVIDIA Container Runtime hook: %v", err)
|
||||
}
|
||||
|
||||
err = r.ociSpec.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing modified OCI specification: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNVIDIAHook modifies the specified OCI specification in-place, inserting a
|
||||
// prestart hook.
|
||||
func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
|
||||
// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook
|
||||
// as a prestart hook.
|
||||
func (m addHookModifier) Modify(spec *specs.Spec) error {
|
||||
path, err := exec.LookPath("nvidia-container-runtime-hook")
|
||||
if err != nil {
|
||||
path = hookDefaultFilePath
|
||||
@ -108,7 +61,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Printf("prestart hook path: %s\n", path)
|
||||
m.logger.Printf("prestart hook path: %s\n", path)
|
||||
|
||||
args := []string{path}
|
||||
if spec.Hooks == nil {
|
||||
@ -118,7 +71,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
|
||||
if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") {
|
||||
continue
|
||||
}
|
||||
r.logger.Println("existing nvidia prestart hook in OCI spec file")
|
||||
m.logger.Println("existing nvidia prestart hook in OCI spec file")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -29,9 +29,8 @@ import (
|
||||
|
||||
func TestAddNvidiaHook(t *testing.T) {
|
||||
logger, logHook := testlog.NewNullLogger()
|
||||
shim := nvidiaContainerRuntime{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
mockRuntime := &oci.RuntimeMock{}
|
||||
|
||||
testCases := []struct {
|
||||
spec *specs.Spec
|
||||
@ -75,7 +74,16 @@ func TestAddNvidiaHook(t *testing.T) {
|
||||
numPrestartHooks = len(tc.spec.Hooks.Prestart)
|
||||
}
|
||||
|
||||
err := shim.addNVIDIAHook(tc.spec)
|
||||
shim, err := newNvidiaContainerRuntime(
|
||||
logger,
|
||||
mockRuntime,
|
||||
oci.NewMemorySpec(tc.spec),
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
err = shim.Exec([]string{"runtime", "create"})
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.errorPrefix == "" {
|
||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
||||
@ -106,45 +114,39 @@ func TestAddNvidiaHook(t *testing.T) {
|
||||
func TestNvidiaContainerRuntime(t *testing.T) {
|
||||
logger, hook := testlog.NewNullLogger()
|
||||
|
||||
mockRuntime := &oci.RuntimeMock{}
|
||||
|
||||
testCases := []struct {
|
||||
shim nvidiaContainerRuntime
|
||||
shouldModify bool
|
||||
args []string
|
||||
modifyError error
|
||||
writeError error
|
||||
}{
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
shouldModify: false,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"create"},
|
||||
shouldModify: true,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"--bundle=create"},
|
||||
shouldModify: false,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"--bundle", "create"},
|
||||
shouldModify: false,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"create"},
|
||||
shouldModify: true,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"create"},
|
||||
modifyError: fmt.Errorf("error modifying"),
|
||||
shouldModify: true,
|
||||
},
|
||||
{
|
||||
shim: nvidiaContainerRuntime{},
|
||||
args: []string{"create"},
|
||||
writeError: fmt.Errorf("error writing"),
|
||||
shouldModify: true,
|
||||
@ -152,10 +154,8 @@ func TestNvidiaContainerRuntime(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
tc.shim.logger = logger
|
||||
hook.Reset()
|
||||
|
||||
ociMock := &oci.SpecMock{
|
||||
specMock := &oci.SpecMock{
|
||||
ModifyFunc: func(specModifier oci.SpecModifier) error {
|
||||
return tc.modifyError
|
||||
},
|
||||
@ -163,12 +163,11 @@ func TestNvidiaContainerRuntime(t *testing.T) {
|
||||
return tc.writeError
|
||||
},
|
||||
}
|
||||
require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc)
|
||||
|
||||
tc.shim.ociSpec = ociMock
|
||||
tc.shim.runtime = &MockShim{}
|
||||
shim, err := newNvidiaContainerRuntime(logger, mockRuntime, specMock)
|
||||
require.NoError(t, err)
|
||||
|
||||
err := tc.shim.Exec(tc.args)
|
||||
err = shim.Exec(tc.args)
|
||||
if tc.modifyError != nil || tc.writeError != nil {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
} else {
|
||||
@ -176,28 +175,16 @@ func TestNvidiaContainerRuntime(t *testing.T) {
|
||||
}
|
||||
|
||||
if tc.shouldModify {
|
||||
require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
|
||||
require.Equal(t, 1, len(specMock.ModifyCalls()), "%d: %v", i, tc)
|
||||
} else {
|
||||
require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
|
||||
require.Equal(t, 0, len(specMock.ModifyCalls()), "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
writeExpected := tc.shouldModify && tc.modifyError == nil
|
||||
if writeExpected {
|
||||
require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc)
|
||||
require.Equal(t, 1, len(specMock.FlushCalls()), "%d: %v", i, tc)
|
||||
} else {
|
||||
require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc)
|
||||
require.Equal(t, 0, len(specMock.FlushCalls()), "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type MockShim struct {
|
||||
called bool
|
||||
args []string
|
||||
returnError error
|
||||
}
|
||||
|
||||
func (m *MockShim) Exec(args []string) error {
|
||||
m.called = true
|
||||
m.args = args
|
||||
return m.returnError
|
||||
}
|
||||
|
@ -23,52 +23,27 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ociSpecFileName = "config.json"
|
||||
dockerRuncExecutableName = "docker-runc"
|
||||
runcExecutableName = "runc"
|
||||
)
|
||||
|
||||
// newRuntime is a factory method that constructs a runtime based on the selected configuration.
|
||||
func newRuntime(argv []string) (oci.Runtime, error) {
|
||||
ociSpec, err := newOCISpec(argv)
|
||||
ociSpec, err := oci.NewSpec(logger.Logger, argv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
|
||||
}
|
||||
|
||||
runc, err := newRuncRuntime()
|
||||
lowLevelRuntimeCandidates := []string{dockerRuncExecutableName, runcExecutableName}
|
||||
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger.Logger, lowLevelRuntimeCandidates)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing runc runtime: %v", err)
|
||||
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
|
||||
}
|
||||
|
||||
r, err := newNvidiaContainerRuntimeWithLogger(logger.Logger, runc, ociSpec)
|
||||
r, err := newNvidiaContainerRuntime(logger.Logger, lowLevelRuntime, ociSpec)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// newOCISpec constructs an OCI spec for the provided arguments
|
||||
func newOCISpec(argv []string) (oci.Spec, error) {
|
||||
bundleDir, err := oci.GetBundleDir(argv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing command line arguments: %v", err)
|
||||
}
|
||||
logger.Infof("Using bundle directory: %v", bundleDir)
|
||||
|
||||
ociSpecPath := oci.GetSpecFilePath(bundleDir)
|
||||
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
|
||||
|
||||
ociSpec := oci.NewSpecFromFile(ociSpecPath)
|
||||
|
||||
return ociSpec, nil
|
||||
}
|
||||
|
||||
// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime
|
||||
func newRuncRuntime() (oci.Runtime, error) {
|
||||
return oci.NewLowLevelRuntimeWithLogger(
|
||||
logger.Logger,
|
||||
dockerRuncExecutableName,
|
||||
runcExecutableName,
|
||||
)
|
||||
}
|
||||
|
@ -25,19 +25,14 @@ import (
|
||||
|
||||
// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable.
|
||||
// The executable specified is taken from the list of supplied candidates, with the first match
|
||||
// present in the PATH being selected.
|
||||
func NewLowLevelRuntime(candidates ...string) (Runtime, error) {
|
||||
return NewLowLevelRuntimeWithLogger(log.StandardLogger(), candidates...)
|
||||
}
|
||||
|
||||
// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger.
|
||||
func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) {
|
||||
// present in the PATH being selected. A logger is also specified.
|
||||
func NewLowLevelRuntime(logger *log.Logger, candidates []string) (Runtime, error) {
|
||||
runtimePath, err := findRuntime(logger, candidates)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error locating runtime: %v", err)
|
||||
}
|
||||
|
||||
return NewRuntimeForPathWithLogger(logger, runtimePath)
|
||||
return NewRuntimeForPath(logger, runtimePath)
|
||||
}
|
||||
|
||||
// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH.
|
||||
|
@ -34,13 +34,8 @@ type pathRuntime struct {
|
||||
|
||||
var _ Runtime = (*pathRuntime)(nil)
|
||||
|
||||
// NewRuntimeForPath creates a Runtime for the specified path with the standard logger
|
||||
func NewRuntimeForPath(path string) (Runtime, error) {
|
||||
return NewRuntimeForPathWithLogger(log.StandardLogger(), path)
|
||||
}
|
||||
|
||||
// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path
|
||||
func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) {
|
||||
// NewRuntimeForPath creates a Runtime for the specified logger and path
|
||||
func NewRuntimeForPath(logger *log.Logger, path string) (Runtime, error) {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid path '%v': %v", path, err)
|
||||
|
@ -24,19 +24,21 @@ import (
|
||||
)
|
||||
|
||||
func TestPathRuntimeConstructor(t *testing.T) {
|
||||
r, err := NewRuntimeForPath("////an/invalid/path")
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
|
||||
r, err := NewRuntimeForPath(logger, "////an/invalid/path")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, r)
|
||||
|
||||
r, err = NewRuntimeForPath("/tmp")
|
||||
r, err = NewRuntimeForPath(logger, "/tmp")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, r)
|
||||
|
||||
r, err = NewRuntimeForPath("/dev/null")
|
||||
r, err = NewRuntimeForPath(logger, "/dev/null")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, r)
|
||||
|
||||
r, err = NewRuntimeForPath("/bin/sh")
|
||||
r, err = NewRuntimeForPath(logger, "/bin/sh")
|
||||
require.NoError(t, err)
|
||||
|
||||
f, ok := r.(*pathRuntime)
|
||||
|
@ -17,15 +17,20 @@
|
||||
package oci
|
||||
|
||||
import (
|
||||
oci "github.com/opencontainers/runtime-spec/specs-go"
|
||||
"fmt"
|
||||
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SpecModifier is a function that accepts a pointer to an OCI Srec and returns an
|
||||
// error. The intention is that the function would modify the spec in-place.
|
||||
type SpecModifier func(*oci.Spec) error
|
||||
// SpecModifier defines an interace for modifying a (raw) OCI spec
|
||||
type SpecModifier interface {
|
||||
// Modify is a method that accepts a pointer to an OCI Srec and returns an
|
||||
// error. The intention is that the function would modify the spec in-place.
|
||||
Modify(*specs.Spec) error
|
||||
}
|
||||
|
||||
//go:generate moq -stub -out spec_mock.go . Spec
|
||||
|
||||
// Spec defines the operations to be performed on an OCI specification
|
||||
type Spec interface {
|
||||
Load() error
|
||||
@ -33,3 +38,20 @@ type Spec interface {
|
||||
Modify(SpecModifier) error
|
||||
LookupEnv(string) (string, bool)
|
||||
}
|
||||
|
||||
// NewSpec creates fileSpec based on the command line arguments passed to the
|
||||
// application using the specified logger.
|
||||
func NewSpec(logger *logrus.Logger, args []string) (Spec, error) {
|
||||
bundleDir, err := GetBundleDir(args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting bundle directory: %v", err)
|
||||
}
|
||||
logger.Infof("Using bundle directory: %v", bundleDir)
|
||||
|
||||
ociSpecPath := GetSpecFilePath(bundleDir)
|
||||
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
|
||||
|
||||
ociSpec := NewFileSpec(ociSpecPath)
|
||||
|
||||
return ociSpec, nil
|
||||
}
|
||||
|
@ -21,37 +21,21 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
oci "github.com/opencontainers/runtime-spec/specs-go"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
)
|
||||
|
||||
type fileSpec struct {
|
||||
*oci.Spec
|
||||
memorySpec
|
||||
path string
|
||||
}
|
||||
|
||||
var _ Spec = (*fileSpec)(nil)
|
||||
|
||||
// NewSpecFromArgs creates fileSpec based on the command line arguments passed to the
|
||||
// application
|
||||
func NewSpecFromArgs(args []string) (Spec, string, error) {
|
||||
bundleDir, err := GetBundleDir(args)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error getting bundle directory: %v", err)
|
||||
}
|
||||
|
||||
ociSpecPath := GetSpecFilePath(bundleDir)
|
||||
|
||||
ociSpec := NewSpecFromFile(ociSpecPath)
|
||||
|
||||
return ociSpec, bundleDir, nil
|
||||
}
|
||||
|
||||
// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec.
|
||||
// NewFileSpec creates an object that encapsulates a file-backed OCI spec.
|
||||
// This can be used to read from the file, modify the spec, and write to the
|
||||
// same file.
|
||||
func NewSpecFromFile(filepath string) Spec {
|
||||
func NewFileSpec(filepath string) Spec {
|
||||
oci := fileSpec{
|
||||
path: filepath,
|
||||
}
|
||||
@ -68,29 +52,31 @@ func (s *fileSpec) Load() error {
|
||||
}
|
||||
defer specFile.Close()
|
||||
|
||||
return s.loadFrom(specFile)
|
||||
}
|
||||
|
||||
// loadFrom reads the contents of the OCI spec from the specified io.Reader.
|
||||
func (s *fileSpec) loadFrom(reader io.Reader) error {
|
||||
decoder := json.NewDecoder(reader)
|
||||
|
||||
var spec oci.Spec
|
||||
err := decoder.Decode(&spec)
|
||||
spec, err := loadFrom(specFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading OCI specification: %v", err)
|
||||
return fmt.Errorf("error loading OCI specification from file: %v", err)
|
||||
}
|
||||
|
||||
s.Spec = &spec
|
||||
s.Spec = spec
|
||||
return nil
|
||||
}
|
||||
|
||||
// Modify applies the specified SpecModifier to the stored OCI specification.
|
||||
func (s *fileSpec) Modify(f SpecModifier) error {
|
||||
if s.Spec == nil {
|
||||
return fmt.Errorf("no spec loaded for modification")
|
||||
// loadFrom reads the contents of the OCI spec from the specified io.Reader.
|
||||
func loadFrom(reader io.Reader) (*specs.Spec, error) {
|
||||
decoder := json.NewDecoder(reader)
|
||||
|
||||
var spec specs.Spec
|
||||
|
||||
err := decoder.Decode(&spec)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading OCI specification: %v", err)
|
||||
}
|
||||
return f(s.Spec)
|
||||
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
// Modify applies the specified SpecModifier to the stored OCI specification.
|
||||
func (s *fileSpec) Modify(m SpecModifier) error {
|
||||
return s.memorySpec.Modify(m)
|
||||
}
|
||||
|
||||
// Flush writes the stored OCI specification to the filepath specifed by the path member.
|
||||
@ -106,48 +92,20 @@ func (s fileSpec) Flush() error {
|
||||
}
|
||||
defer specFile.Close()
|
||||
|
||||
return s.flushTo(specFile)
|
||||
return flushTo(s.Spec, specFile)
|
||||
}
|
||||
|
||||
// flushTo writes the stored OCI specification to the specified io.Writer.
|
||||
func (s fileSpec) flushTo(writer io.Writer) error {
|
||||
if s.Spec == nil {
|
||||
func flushTo(spec *specs.Spec, writer io.Writer) error {
|
||||
if spec == nil {
|
||||
return nil
|
||||
}
|
||||
encoder := json.NewEncoder(writer)
|
||||
|
||||
err := encoder.Encode(s.Spec)
|
||||
err := encoder.Encode(spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing OCI specification: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookupEnv mirrors os.LookupEnv for the OCI specification. It
|
||||
// retrieves the value of the environment variable named
|
||||
// by the key. If the variable is present in the environment the
|
||||
// value (which may be empty) is returned and the boolean is true.
|
||||
// Otherwise the returned value will be empty and the boolean will
|
||||
// be false.
|
||||
func (s fileSpec) LookupEnv(key string) (string, bool) {
|
||||
if s.Spec == nil || s.Spec.Process == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, env := range s.Spec.Process.Env {
|
||||
if !strings.HasPrefix(env, key) {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
if parts[0] == key {
|
||||
if len(parts) < 2 {
|
||||
return "", true
|
||||
}
|
||||
return parts[1], true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
@ -25,111 +25,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLookupEnv(t *testing.T) {
|
||||
const envName = "TEST_ENV"
|
||||
testCases := []struct {
|
||||
spec *specs.Spec
|
||||
expectedValue string
|
||||
expectedExits bool
|
||||
}{
|
||||
{
|
||||
// nil spec
|
||||
spec: nil,
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// nil process
|
||||
spec: &specs.Spec{},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// nil env
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// empty env
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// different env set
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// same prefix
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// same suffix
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// set blank
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV="}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: true,
|
||||
},
|
||||
{
|
||||
// set no-equals
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: true,
|
||||
},
|
||||
{
|
||||
// set value
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV=something"}},
|
||||
},
|
||||
expectedValue: "something",
|
||||
expectedExits: true,
|
||||
},
|
||||
{
|
||||
// set with equals
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}},
|
||||
},
|
||||
expectedValue: "something=somethingelse",
|
||||
expectedExits: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
spec := fileSpec{
|
||||
Spec: tc.spec,
|
||||
}
|
||||
|
||||
value, exists := spec.LookupEnv(envName)
|
||||
|
||||
require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc)
|
||||
require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFrom(t *testing.T) {
|
||||
testCases := []struct {
|
||||
contents []byte
|
||||
@ -148,8 +43,8 @@ func TestLoadFrom(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
spec := fileSpec{}
|
||||
err := spec.loadFrom(bytes.NewReader(tc.contents))
|
||||
var spec *specs.Spec
|
||||
spec, err := loadFrom(bytes.NewReader(tc.contents))
|
||||
|
||||
if tc.isError {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
@ -158,9 +53,9 @@ func TestLoadFrom(t *testing.T) {
|
||||
}
|
||||
|
||||
if tc.spec == nil {
|
||||
require.Nil(t, spec.Spec, "%d: %v", i, tc)
|
||||
require.Nil(t, spec, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.EqualValues(t, tc.spec, spec.Spec, "%d: %v", i, tc)
|
||||
require.EqualValues(t, tc.spec, spec, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -183,8 +78,7 @@ func TestFlushTo(t *testing.T) {
|
||||
for i, tc := range testCases {
|
||||
buffer := bytes.Buffer{}
|
||||
|
||||
spec := fileSpec{Spec: tc.spec}
|
||||
err := spec.flushTo(&buffer)
|
||||
err := flushTo(tc.spec, &buffer)
|
||||
|
||||
if tc.isError {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
@ -196,53 +90,10 @@ func TestFlushTo(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add a simple test for a writer that returns an error when writing
|
||||
spec := fileSpec{Spec: &specs.Spec{}}
|
||||
err := spec.flushTo(errorWriter{})
|
||||
err := flushTo(&specs.Spec{}, errorWriter{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestModify(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
spec *specs.Spec
|
||||
modifierError error
|
||||
}{
|
||||
{
|
||||
spec: nil,
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{},
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{},
|
||||
modifierError: fmt.Errorf("error in modifier"),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
spec := fileSpec{Spec: tc.spec}
|
||||
|
||||
modifier := func(spec *specs.Spec) error {
|
||||
if tc.modifierError == nil {
|
||||
spec.Version = "updated"
|
||||
}
|
||||
return tc.modifierError
|
||||
}
|
||||
|
||||
err := spec.Modify(modifier)
|
||||
|
||||
if tc.spec == nil {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
} else if tc.modifierError != nil {
|
||||
require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
|
||||
require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoError(t, err, "%d: %v", i, tc)
|
||||
require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// errorWriter implements the io.Writer interface, always returning an error when
|
||||
// writing.
|
||||
type errorWriter struct{}
|
||||
|
83
internal/oci/spec_memory.go
Normal file
83
internal/oci/spec_memory.go
Normal file
@ -0,0 +1,83 @@
|
||||
/**
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
)
|
||||
|
||||
type memorySpec struct {
|
||||
*specs.Spec
|
||||
}
|
||||
|
||||
// NewMemorySpec creates a Spec instance from the specified OCI spec
|
||||
func NewMemorySpec(spec *specs.Spec) Spec {
|
||||
s := memorySpec{
|
||||
Spec: spec,
|
||||
}
|
||||
|
||||
return &s
|
||||
}
|
||||
|
||||
// Load is a no-op for the memorySpec spec
|
||||
func (s *memorySpec) Load() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush is a no-op for the memorySpec spec
|
||||
func (s *memorySpec) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Modify applies the specified SpecModifier to the stored OCI specification.
|
||||
func (s *memorySpec) Modify(m SpecModifier) error {
|
||||
if s.Spec == nil {
|
||||
return fmt.Errorf("cannot modify nil spec")
|
||||
}
|
||||
return m.Modify(s.Spec)
|
||||
}
|
||||
|
||||
// LookupEnv mirrors os.LookupEnv for the OCI specification. It
|
||||
// retrieves the value of the environment variable named
|
||||
// by the key. If the variable is present in the environment the
|
||||
// value (which may be empty) is returned and the boolean is true.
|
||||
// Otherwise the returned value will be empty and the boolean will
|
||||
// be false.
|
||||
func (s memorySpec) LookupEnv(key string) (string, bool) {
|
||||
if s.Spec == nil || s.Spec.Process == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, env := range s.Spec.Process.Env {
|
||||
if !strings.HasPrefix(env, key) {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
if parts[0] == key {
|
||||
if len(parts) < 2 {
|
||||
return "", true
|
||||
}
|
||||
return parts[1], true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
178
internal/oci/spec_memory_test.go
Normal file
178
internal/oci/spec_memory_test.go
Normal file
@ -0,0 +1,178 @@
|
||||
/**
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package oci
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLookupEnv(t *testing.T) {
|
||||
const envName = "TEST_ENV"
|
||||
testCases := []struct {
|
||||
spec *specs.Spec
|
||||
expectedValue string
|
||||
expectedExits bool
|
||||
}{
|
||||
{
|
||||
// nil spec
|
||||
spec: nil,
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// nil process
|
||||
spec: &specs.Spec{},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// nil env
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// empty env
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// different env set
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// same prefix
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// same suffix
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: false,
|
||||
},
|
||||
{
|
||||
// set blank
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV="}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: true,
|
||||
},
|
||||
{
|
||||
// set no-equals
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV"}},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectedExits: true,
|
||||
},
|
||||
{
|
||||
// set value
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV=something"}},
|
||||
},
|
||||
expectedValue: "something",
|
||||
expectedExits: true,
|
||||
},
|
||||
{
|
||||
// set with equals
|
||||
spec: &specs.Spec{
|
||||
Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}},
|
||||
},
|
||||
expectedValue: "something=somethingelse",
|
||||
expectedExits: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
spec := memorySpec{
|
||||
Spec: tc.spec,
|
||||
}
|
||||
|
||||
value, exists := spec.LookupEnv(envName)
|
||||
|
||||
require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc)
|
||||
require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModify(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
spec *specs.Spec
|
||||
modifierError error
|
||||
}{
|
||||
{
|
||||
spec: nil,
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{},
|
||||
},
|
||||
{
|
||||
spec: &specs.Spec{},
|
||||
modifierError: fmt.Errorf("error in modifier"),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
spec := NewMemorySpec(tc.spec).(*memorySpec)
|
||||
|
||||
err := spec.Modify(modifier{tc.modifierError})
|
||||
|
||||
if tc.spec == nil {
|
||||
require.Error(t, err, "%d: %v", i, tc)
|
||||
} else if tc.modifierError != nil {
|
||||
require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
|
||||
require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoError(t, err, "%d: %v", i, tc)
|
||||
require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Ideally we would generated a mock for the SpecModifer too. This causes
|
||||
// an import cycle and we define a local type as a work-around.
|
||||
type modifier struct {
|
||||
modifierError error
|
||||
}
|
||||
|
||||
func (m modifier) Modify(spec *specs.Spec) error {
|
||||
if m.modifierError == nil {
|
||||
spec.Version = "updated"
|
||||
}
|
||||
return m.modifierError
|
||||
}
|
@ -20,7 +20,7 @@ func TestMaintainSpec(t *testing.T) {
|
||||
for _, f := range files {
|
||||
inputSpecPath := filepath.Join(moduleRoot, "test/input", f)
|
||||
|
||||
spec := NewSpecFromFile(inputSpecPath).(*fileSpec)
|
||||
spec := NewFileSpec(inputSpecPath).(*fileSpec)
|
||||
|
||||
spec.Load()
|
||||
|
||||
|
81
internal/runtime/runtime_modifier.go
Normal file
81
internal/runtime/runtime_modifier.go
Normal file
@ -0,0 +1,81 @@
|
||||
/*
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type modifyingRuntimeWrapper struct {
|
||||
logger *log.Logger
|
||||
runtime oci.Runtime
|
||||
ociSpec oci.Spec
|
||||
modifier oci.SpecModifier
|
||||
}
|
||||
|
||||
var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil)
|
||||
|
||||
// NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
|
||||
// before invoking the wrapped runtime.
|
||||
func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime {
|
||||
rt := modifyingRuntimeWrapper{
|
||||
logger: logger,
|
||||
runtime: runtime,
|
||||
ociSpec: spec,
|
||||
modifier: modifier,
|
||||
}
|
||||
return &rt
|
||||
}
|
||||
|
||||
// Exec checks whether a modification of the OCI specification is required and modifies it accordingly before exec-ing
|
||||
// into the wrapped runtime.
|
||||
func (r *modifyingRuntimeWrapper) Exec(args []string) error {
|
||||
if oci.HasCreateSubcommand(args) {
|
||||
err := r.modify()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not apply required modification to OCI specification: %v", err)
|
||||
}
|
||||
r.logger.Infof("Applied required modification to OCI specification")
|
||||
} else {
|
||||
r.logger.Infof("No modification of OCI specification required")
|
||||
}
|
||||
|
||||
r.logger.Infof("Forwarding command to runtime")
|
||||
return r.runtime.Exec(args)
|
||||
}
|
||||
|
||||
// modify loads, modifies, and flushes the OCI specification using the defined Modifier
|
||||
func (r *modifyingRuntimeWrapper) modify() error {
|
||||
err := r.ociSpec.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading OCI specification for modification: %v", err)
|
||||
}
|
||||
|
||||
err = r.ociSpec.Modify(r.modifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error modifying OCI spec: %v", err)
|
||||
}
|
||||
|
||||
err = r.ociSpec.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing modified OCI specification: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
116
internal/runtime/runtime_modifier_test.go
Normal file
116
internal/runtime/runtime_modifier_test.go
Normal file
@ -0,0 +1,116 @@
|
||||
/*
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
logger, hook := testlog.NewNullLogger()
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
shouldModify bool
|
||||
shouldFlush bool
|
||||
shouldForward bool
|
||||
args []string
|
||||
modifyError error
|
||||
writeError error
|
||||
}{
|
||||
{
|
||||
description: "no args forwards",
|
||||
shouldModify: false,
|
||||
shouldFlush: false,
|
||||
shouldForward: true,
|
||||
},
|
||||
{
|
||||
description: "create modifies",
|
||||
args: []string{"create"},
|
||||
shouldModify: true,
|
||||
shouldFlush: true,
|
||||
shouldForward: true,
|
||||
},
|
||||
{
|
||||
description: "modify error does not write or forward",
|
||||
args: []string{"create"},
|
||||
modifyError: fmt.Errorf("error modifying"),
|
||||
shouldModify: true,
|
||||
shouldFlush: false,
|
||||
shouldForward: false,
|
||||
},
|
||||
{
|
||||
description: "write error does not forward",
|
||||
args: []string{"create"},
|
||||
writeError: fmt.Errorf("error writing"),
|
||||
shouldModify: true,
|
||||
shouldFlush: true,
|
||||
shouldForward: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
hook.Reset()
|
||||
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
runtimeMock := &oci.RuntimeMock{}
|
||||
specMock := &oci.SpecMock{
|
||||
ModifyFunc: func(specModifier oci.SpecModifier) error {
|
||||
return tc.modifyError
|
||||
},
|
||||
FlushFunc: func() error {
|
||||
return tc.writeError
|
||||
},
|
||||
}
|
||||
|
||||
shim := NewModifyingRuntimeWrapper(
|
||||
logger,
|
||||
runtimeMock,
|
||||
specMock,
|
||||
// TODO: We should test the interactions with the SpecModifier too
|
||||
nil)
|
||||
|
||||
err := shim.Exec(tc.args)
|
||||
if tc.modifyError != nil || tc.writeError != nil {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if tc.shouldModify {
|
||||
require.Equal(t, 1, len(specMock.ModifyCalls()))
|
||||
} else {
|
||||
require.Equal(t, 0, len(specMock.ModifyCalls()))
|
||||
}
|
||||
if tc.shouldFlush {
|
||||
require.Equal(t, 1, len(specMock.FlushCalls()))
|
||||
} else {
|
||||
require.Equal(t, 0, len(specMock.FlushCalls()))
|
||||
}
|
||||
if tc.shouldForward {
|
||||
require.Equal(t, 1, len(runtimeMock.ExecCalls()))
|
||||
} else {
|
||||
require.Equal(t, 0, len(runtimeMock.ExecCalls()))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user