mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-04-22 23:24:18 +00:00
Import modifying runtime abstraction from experimental runtime
This change imports the modifying runtime abstraction from the experimental branch. This encapsulates the checks for whether modification is required, and forwards the loaded spec to the specified modifier. This allows for the same code to be reused when performing more complex modifications. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
bf8c3bab72
commit
4177fddcc4
@ -167,11 +167,11 @@ func TestDuplicateHook(t *testing.T) {
|
|||||||
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
|
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
// addNVIDIAHook is a basic wrapper for nvidiaContainerRunime.addNVIDIAHook that is used for
|
// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for
|
||||||
// testing.
|
// testing.
|
||||||
func addNVIDIAHook(spec *specs.Spec) error {
|
func addNVIDIAHook(spec *specs.Spec) error {
|
||||||
r := nvidiaContainerRuntime{logger: logger.Logger}
|
m := addHookModifier{logger: logger.Logger}
|
||||||
return r.addNVIDIAHook(spec)
|
return m.Modify(spec)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c testConfig) getRuntimeSpec() (specs.Spec, error) {
|
func (c testConfig) getRuntimeSpec() (specs.Spec, error) {
|
||||||
|
@ -17,88 +17,41 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||||
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/runtime"
|
||||||
"github.com/opencontainers/runtime-spec/specs-go"
|
"github.com/opencontainers/runtime-spec/specs-go"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nvidiaContainerRuntime encapsulates the NVIDIA Container Runtime. It wraps the specified runtime, conditionally
|
// newNvidiaContainerRuntime is a constructor for a standard runtime shim. This uses
|
||||||
// modifying the specified OCI specification before invoking the runtime.
|
// a ModifyingRuntimeWrapper to apply the required modifications before execing to the
|
||||||
type nvidiaContainerRuntime struct {
|
// specified low-level runtime
|
||||||
logger *log.Logger
|
func newNvidiaContainerRuntime(logger *log.Logger, lowlevelRuntime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
|
||||||
runtime oci.Runtime
|
modifier := addHookModifier{logger: logger}
|
||||||
ociSpec oci.Spec
|
|
||||||
|
r := runtime.NewModifyingRuntimeWrapper(
|
||||||
|
logger,
|
||||||
|
lowlevelRuntime,
|
||||||
|
ociSpec,
|
||||||
|
modifier,
|
||||||
|
)
|
||||||
|
|
||||||
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ oci.Runtime = (*nvidiaContainerRuntime)(nil)
|
// addHookModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a
|
||||||
|
// prestart hook. If the hook is already present, no modification is made.
|
||||||
// newNvidiaContainerRuntime is a constructor for a standard runtime shim.
|
type addHookModifier struct {
|
||||||
func newNvidiaContainerRuntimeWithLogger(logger *log.Logger, runtime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
|
logger *log.Logger
|
||||||
r := nvidiaContainerRuntime{
|
|
||||||
logger: logger,
|
|
||||||
runtime: runtime,
|
|
||||||
ociSpec: ociSpec,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &r, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec defines the entrypoint for the NVIDIA Container Runtime. A check is performed to see whether modifications
|
// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook
|
||||||
// to the OCI spec are required -- and applicable modifcations applied. The supplied arguments are then
|
// as a prestart hook.
|
||||||
// forwarded to the underlying runtime's Exec method.
|
func (m addHookModifier) Modify(spec *specs.Spec) error {
|
||||||
func (r nvidiaContainerRuntime) Exec(args []string) error {
|
|
||||||
if r.modificationRequired(args) {
|
|
||||||
err := r.modifyOCISpec()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error modifying OCI spec: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.logger.Println("Forwarding command to runtime")
|
|
||||||
return r.runtime.Exec(args)
|
|
||||||
}
|
|
||||||
|
|
||||||
// modificationRequired checks the intput arguments to determine whether a modification
|
|
||||||
// to the OCI spec is required.
|
|
||||||
func (r nvidiaContainerRuntime) modificationRequired(args []string) bool {
|
|
||||||
if oci.HasCreateSubcommand(args) {
|
|
||||||
r.logger.Infof("'create' command detected; modification required")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
r.logger.Infof("No modification required")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// modifyOCISpec loads and modifies the OCI spec specified in the nvidiaContainerRuntime
|
|
||||||
// struct. The spec is modified in-place and written to the same file as the input after
|
|
||||||
// modifcationas are applied.
|
|
||||||
func (r nvidiaContainerRuntime) modifyOCISpec() error {
|
|
||||||
err := r.ociSpec.Load()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error loading OCI specification for modification: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.ociSpec.Modify(r.addNVIDIAHook)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error injecting NVIDIA Container Runtime hook: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.ociSpec.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error writing modified OCI specification: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNVIDIAHook modifies the specified OCI specification in-place, inserting a
|
|
||||||
// prestart hook.
|
|
||||||
func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
|
|
||||||
path, err := exec.LookPath("nvidia-container-runtime-hook")
|
path, err := exec.LookPath("nvidia-container-runtime-hook")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
path = hookDefaultFilePath
|
path = hookDefaultFilePath
|
||||||
@ -108,7 +61,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.logger.Printf("prestart hook path: %s\n", path)
|
m.logger.Printf("prestart hook path: %s\n", path)
|
||||||
|
|
||||||
args := []string{path}
|
args := []string{path}
|
||||||
if spec.Hooks == nil {
|
if spec.Hooks == nil {
|
||||||
@ -118,7 +71,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
|
|||||||
if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") {
|
if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
r.logger.Println("existing nvidia prestart hook in OCI spec file")
|
m.logger.Println("existing nvidia prestart hook in OCI spec file")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,9 +29,8 @@ import (
|
|||||||
|
|
||||||
func TestAddNvidiaHook(t *testing.T) {
|
func TestAddNvidiaHook(t *testing.T) {
|
||||||
logger, logHook := testlog.NewNullLogger()
|
logger, logHook := testlog.NewNullLogger()
|
||||||
shim := nvidiaContainerRuntime{
|
|
||||||
logger: logger,
|
mockRuntime := &oci.RuntimeMock{}
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
spec *specs.Spec
|
spec *specs.Spec
|
||||||
@ -75,7 +74,16 @@ func TestAddNvidiaHook(t *testing.T) {
|
|||||||
numPrestartHooks = len(tc.spec.Hooks.Prestart)
|
numPrestartHooks = len(tc.spec.Hooks.Prestart)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := shim.addNVIDIAHook(tc.spec)
|
shim, err := newNvidiaContainerRuntime(
|
||||||
|
logger,
|
||||||
|
mockRuntime,
|
||||||
|
oci.NewMemorySpec(tc.spec),
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = shim.Exec([]string{"runtime", "create"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
if tc.errorPrefix == "" {
|
if tc.errorPrefix == "" {
|
||||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
require.NoErrorf(t, err, "%d: %v", i, tc)
|
||||||
@ -106,45 +114,39 @@ func TestAddNvidiaHook(t *testing.T) {
|
|||||||
func TestNvidiaContainerRuntime(t *testing.T) {
|
func TestNvidiaContainerRuntime(t *testing.T) {
|
||||||
logger, hook := testlog.NewNullLogger()
|
logger, hook := testlog.NewNullLogger()
|
||||||
|
|
||||||
|
mockRuntime := &oci.RuntimeMock{}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
shim nvidiaContainerRuntime
|
|
||||||
shouldModify bool
|
shouldModify bool
|
||||||
args []string
|
args []string
|
||||||
modifyError error
|
modifyError error
|
||||||
writeError error
|
writeError error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
shim: nvidiaContainerRuntime{},
|
|
||||||
shouldModify: false,
|
shouldModify: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
shim: nvidiaContainerRuntime{},
|
|
||||||
args: []string{"create"},
|
args: []string{"create"},
|
||||||
shouldModify: true,
|
shouldModify: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
shim: nvidiaContainerRuntime{},
|
|
||||||
args: []string{"--bundle=create"},
|
args: []string{"--bundle=create"},
|
||||||
shouldModify: false,
|
shouldModify: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
shim: nvidiaContainerRuntime{},
|
|
||||||
args: []string{"--bundle", "create"},
|
args: []string{"--bundle", "create"},
|
||||||
shouldModify: false,
|
shouldModify: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
shim: nvidiaContainerRuntime{},
|
|
||||||
args: []string{"create"},
|
args: []string{"create"},
|
||||||
shouldModify: true,
|
shouldModify: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
shim: nvidiaContainerRuntime{},
|
|
||||||
args: []string{"create"},
|
args: []string{"create"},
|
||||||
modifyError: fmt.Errorf("error modifying"),
|
modifyError: fmt.Errorf("error modifying"),
|
||||||
shouldModify: true,
|
shouldModify: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
shim: nvidiaContainerRuntime{},
|
|
||||||
args: []string{"create"},
|
args: []string{"create"},
|
||||||
writeError: fmt.Errorf("error writing"),
|
writeError: fmt.Errorf("error writing"),
|
||||||
shouldModify: true,
|
shouldModify: true,
|
||||||
@ -152,10 +154,8 @@ func TestNvidiaContainerRuntime(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
tc.shim.logger = logger
|
|
||||||
hook.Reset()
|
hook.Reset()
|
||||||
|
specMock := &oci.SpecMock{
|
||||||
ociMock := &oci.SpecMock{
|
|
||||||
ModifyFunc: func(specModifier oci.SpecModifier) error {
|
ModifyFunc: func(specModifier oci.SpecModifier) error {
|
||||||
return tc.modifyError
|
return tc.modifyError
|
||||||
},
|
},
|
||||||
@ -163,12 +163,11 @@ func TestNvidiaContainerRuntime(t *testing.T) {
|
|||||||
return tc.writeError
|
return tc.writeError
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc)
|
|
||||||
|
|
||||||
tc.shim.ociSpec = ociMock
|
shim, err := newNvidiaContainerRuntime(logger, mockRuntime, specMock)
|
||||||
tc.shim.runtime = &MockShim{}
|
require.NoError(t, err)
|
||||||
|
|
||||||
err := tc.shim.Exec(tc.args)
|
err = shim.Exec(tc.args)
|
||||||
if tc.modifyError != nil || tc.writeError != nil {
|
if tc.modifyError != nil || tc.writeError != nil {
|
||||||
require.Error(t, err, "%d: %v", i, tc)
|
require.Error(t, err, "%d: %v", i, tc)
|
||||||
} else {
|
} else {
|
||||||
@ -176,28 +175,16 @@ func TestNvidiaContainerRuntime(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tc.shouldModify {
|
if tc.shouldModify {
|
||||||
require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
|
require.Equal(t, 1, len(specMock.ModifyCalls()), "%d: %v", i, tc)
|
||||||
} else {
|
} else {
|
||||||
require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
|
require.Equal(t, 0, len(specMock.ModifyCalls()), "%d: %v", i, tc)
|
||||||
}
|
}
|
||||||
|
|
||||||
writeExpected := tc.shouldModify && tc.modifyError == nil
|
writeExpected := tc.shouldModify && tc.modifyError == nil
|
||||||
if writeExpected {
|
if writeExpected {
|
||||||
require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc)
|
require.Equal(t, 1, len(specMock.FlushCalls()), "%d: %v", i, tc)
|
||||||
} else {
|
} else {
|
||||||
require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc)
|
require.Equal(t, 0, len(specMock.FlushCalls()), "%d: %v", i, tc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockShim struct {
|
|
||||||
called bool
|
|
||||||
args []string
|
|
||||||
returnError error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockShim) Exec(args []string) error {
|
|
||||||
m.called = true
|
|
||||||
m.args = args
|
|
||||||
return m.returnError
|
|
||||||
}
|
|
||||||
|
@ -23,52 +23,27 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ociSpecFileName = "config.json"
|
|
||||||
dockerRuncExecutableName = "docker-runc"
|
dockerRuncExecutableName = "docker-runc"
|
||||||
runcExecutableName = "runc"
|
runcExecutableName = "runc"
|
||||||
)
|
)
|
||||||
|
|
||||||
// newRuntime is a factory method that constructs a runtime based on the selected configuration.
|
// newRuntime is a factory method that constructs a runtime based on the selected configuration.
|
||||||
func newRuntime(argv []string) (oci.Runtime, error) {
|
func newRuntime(argv []string) (oci.Runtime, error) {
|
||||||
ociSpec, err := newOCISpec(argv)
|
ociSpec, err := oci.NewSpec(logger.Logger, argv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
|
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
runc, err := newRuncRuntime()
|
lowLevelRuntimeCandidates := []string{dockerRuncExecutableName, runcExecutableName}
|
||||||
|
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger.Logger, lowLevelRuntimeCandidates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error constructing runc runtime: %v", err)
|
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := newNvidiaContainerRuntimeWithLogger(logger.Logger, runc, ociSpec)
|
r, err := newNvidiaContainerRuntime(logger.Logger, lowLevelRuntime, ociSpec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err)
|
return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// newOCISpec constructs an OCI spec for the provided arguments
|
|
||||||
func newOCISpec(argv []string) (oci.Spec, error) {
|
|
||||||
bundleDir, err := oci.GetBundleDir(argv)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing command line arguments: %v", err)
|
|
||||||
}
|
|
||||||
logger.Infof("Using bundle directory: %v", bundleDir)
|
|
||||||
|
|
||||||
ociSpecPath := oci.GetSpecFilePath(bundleDir)
|
|
||||||
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
|
|
||||||
|
|
||||||
ociSpec := oci.NewSpecFromFile(ociSpecPath)
|
|
||||||
|
|
||||||
return ociSpec, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime
|
|
||||||
func newRuncRuntime() (oci.Runtime, error) {
|
|
||||||
return oci.NewLowLevelRuntimeWithLogger(
|
|
||||||
logger.Logger,
|
|
||||||
dockerRuncExecutableName,
|
|
||||||
runcExecutableName,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
@ -25,19 +25,14 @@ import (
|
|||||||
|
|
||||||
// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable.
|
// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable.
|
||||||
// The executable specified is taken from the list of supplied candidates, with the first match
|
// The executable specified is taken from the list of supplied candidates, with the first match
|
||||||
// present in the PATH being selected.
|
// present in the PATH being selected. A logger is also specified.
|
||||||
func NewLowLevelRuntime(candidates ...string) (Runtime, error) {
|
func NewLowLevelRuntime(logger *log.Logger, candidates []string) (Runtime, error) {
|
||||||
return NewLowLevelRuntimeWithLogger(log.StandardLogger(), candidates...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewLowLevelRuntimeWithLogger creates a Runtime as with NewLowLevelRuntime using the specified logger.
|
|
||||||
func NewLowLevelRuntimeWithLogger(logger *log.Logger, candidates ...string) (Runtime, error) {
|
|
||||||
runtimePath, err := findRuntime(logger, candidates)
|
runtimePath, err := findRuntime(logger, candidates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error locating runtime: %v", err)
|
return nil, fmt.Errorf("error locating runtime: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewRuntimeForPathWithLogger(logger, runtimePath)
|
return NewRuntimeForPath(logger, runtimePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH.
|
// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH.
|
||||||
|
@ -34,13 +34,8 @@ type pathRuntime struct {
|
|||||||
|
|
||||||
var _ Runtime = (*pathRuntime)(nil)
|
var _ Runtime = (*pathRuntime)(nil)
|
||||||
|
|
||||||
// NewRuntimeForPath creates a Runtime for the specified path with the standard logger
|
// NewRuntimeForPath creates a Runtime for the specified logger and path
|
||||||
func NewRuntimeForPath(path string) (Runtime, error) {
|
func NewRuntimeForPath(logger *log.Logger, path string) (Runtime, error) {
|
||||||
return NewRuntimeForPathWithLogger(log.StandardLogger(), path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRuntimeForPathWithLogger creates a Runtime for the specified logger and path
|
|
||||||
func NewRuntimeForPathWithLogger(logger *log.Logger, path string) (Runtime, error) {
|
|
||||||
info, err := os.Stat(path)
|
info, err := os.Stat(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid path '%v': %v", path, err)
|
return nil, fmt.Errorf("invalid path '%v': %v", path, err)
|
||||||
|
@ -24,19 +24,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPathRuntimeConstructor(t *testing.T) {
|
func TestPathRuntimeConstructor(t *testing.T) {
|
||||||
r, err := NewRuntimeForPath("////an/invalid/path")
|
logger, _ := testlog.NewNullLogger()
|
||||||
|
|
||||||
|
r, err := NewRuntimeForPath(logger, "////an/invalid/path")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, r)
|
require.Nil(t, r)
|
||||||
|
|
||||||
r, err = NewRuntimeForPath("/tmp")
|
r, err = NewRuntimeForPath(logger, "/tmp")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, r)
|
require.Nil(t, r)
|
||||||
|
|
||||||
r, err = NewRuntimeForPath("/dev/null")
|
r, err = NewRuntimeForPath(logger, "/dev/null")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, r)
|
require.Nil(t, r)
|
||||||
|
|
||||||
r, err = NewRuntimeForPath("/bin/sh")
|
r, err = NewRuntimeForPath(logger, "/bin/sh")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
f, ok := r.(*pathRuntime)
|
f, ok := r.(*pathRuntime)
|
||||||
|
@ -17,15 +17,20 @@
|
|||||||
package oci
|
package oci
|
||||||
|
|
||||||
import (
|
import (
|
||||||
oci "github.com/opencontainers/runtime-spec/specs-go"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/opencontainers/runtime-spec/specs-go"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SpecModifier is a function that accepts a pointer to an OCI Srec and returns an
|
// SpecModifier defines an interace for modifying a (raw) OCI spec
|
||||||
// error. The intention is that the function would modify the spec in-place.
|
type SpecModifier interface {
|
||||||
type SpecModifier func(*oci.Spec) error
|
// Modify is a method that accepts a pointer to an OCI Srec and returns an
|
||||||
|
// error. The intention is that the function would modify the spec in-place.
|
||||||
|
Modify(*specs.Spec) error
|
||||||
|
}
|
||||||
|
|
||||||
//go:generate moq -stub -out spec_mock.go . Spec
|
//go:generate moq -stub -out spec_mock.go . Spec
|
||||||
|
|
||||||
// Spec defines the operations to be performed on an OCI specification
|
// Spec defines the operations to be performed on an OCI specification
|
||||||
type Spec interface {
|
type Spec interface {
|
||||||
Load() error
|
Load() error
|
||||||
@ -33,3 +38,20 @@ type Spec interface {
|
|||||||
Modify(SpecModifier) error
|
Modify(SpecModifier) error
|
||||||
LookupEnv(string) (string, bool)
|
LookupEnv(string) (string, bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewSpec creates fileSpec based on the command line arguments passed to the
|
||||||
|
// application using the specified logger.
|
||||||
|
func NewSpec(logger *logrus.Logger, args []string) (Spec, error) {
|
||||||
|
bundleDir, err := GetBundleDir(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting bundle directory: %v", err)
|
||||||
|
}
|
||||||
|
logger.Infof("Using bundle directory: %v", bundleDir)
|
||||||
|
|
||||||
|
ociSpecPath := GetSpecFilePath(bundleDir)
|
||||||
|
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
|
||||||
|
|
||||||
|
ociSpec := NewFileSpec(ociSpecPath)
|
||||||
|
|
||||||
|
return ociSpec, nil
|
||||||
|
}
|
||||||
|
@ -21,37 +21,21 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
|
|
||||||
oci "github.com/opencontainers/runtime-spec/specs-go"
|
"github.com/opencontainers/runtime-spec/specs-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
type fileSpec struct {
|
type fileSpec struct {
|
||||||
*oci.Spec
|
memorySpec
|
||||||
path string
|
path string
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Spec = (*fileSpec)(nil)
|
var _ Spec = (*fileSpec)(nil)
|
||||||
|
|
||||||
// NewSpecFromArgs creates fileSpec based on the command line arguments passed to the
|
// NewFileSpec creates an object that encapsulates a file-backed OCI spec.
|
||||||
// application
|
|
||||||
func NewSpecFromArgs(args []string) (Spec, string, error) {
|
|
||||||
bundleDir, err := GetBundleDir(args)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", fmt.Errorf("error getting bundle directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ociSpecPath := GetSpecFilePath(bundleDir)
|
|
||||||
|
|
||||||
ociSpec := NewSpecFromFile(ociSpecPath)
|
|
||||||
|
|
||||||
return ociSpec, bundleDir, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSpecFromFile creates an object that encapsulates a file-backed OCI spec.
|
|
||||||
// This can be used to read from the file, modify the spec, and write to the
|
// This can be used to read from the file, modify the spec, and write to the
|
||||||
// same file.
|
// same file.
|
||||||
func NewSpecFromFile(filepath string) Spec {
|
func NewFileSpec(filepath string) Spec {
|
||||||
oci := fileSpec{
|
oci := fileSpec{
|
||||||
path: filepath,
|
path: filepath,
|
||||||
}
|
}
|
||||||
@ -68,29 +52,31 @@ func (s *fileSpec) Load() error {
|
|||||||
}
|
}
|
||||||
defer specFile.Close()
|
defer specFile.Close()
|
||||||
|
|
||||||
return s.loadFrom(specFile)
|
spec, err := loadFrom(specFile)
|
||||||
}
|
|
||||||
|
|
||||||
// loadFrom reads the contents of the OCI spec from the specified io.Reader.
|
|
||||||
func (s *fileSpec) loadFrom(reader io.Reader) error {
|
|
||||||
decoder := json.NewDecoder(reader)
|
|
||||||
|
|
||||||
var spec oci.Spec
|
|
||||||
err := decoder.Decode(&spec)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error reading OCI specification: %v", err)
|
return fmt.Errorf("error loading OCI specification from file: %v", err)
|
||||||
}
|
}
|
||||||
|
s.Spec = spec
|
||||||
s.Spec = &spec
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Modify applies the specified SpecModifier to the stored OCI specification.
|
// loadFrom reads the contents of the OCI spec from the specified io.Reader.
|
||||||
func (s *fileSpec) Modify(f SpecModifier) error {
|
func loadFrom(reader io.Reader) (*specs.Spec, error) {
|
||||||
if s.Spec == nil {
|
decoder := json.NewDecoder(reader)
|
||||||
return fmt.Errorf("no spec loaded for modification")
|
|
||||||
|
var spec specs.Spec
|
||||||
|
|
||||||
|
err := decoder.Decode(&spec)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error reading OCI specification: %v", err)
|
||||||
}
|
}
|
||||||
return f(s.Spec)
|
|
||||||
|
return &spec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify applies the specified SpecModifier to the stored OCI specification.
|
||||||
|
func (s *fileSpec) Modify(m SpecModifier) error {
|
||||||
|
return s.memorySpec.Modify(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush writes the stored OCI specification to the filepath specifed by the path member.
|
// Flush writes the stored OCI specification to the filepath specifed by the path member.
|
||||||
@ -106,48 +92,20 @@ func (s fileSpec) Flush() error {
|
|||||||
}
|
}
|
||||||
defer specFile.Close()
|
defer specFile.Close()
|
||||||
|
|
||||||
return s.flushTo(specFile)
|
return flushTo(s.Spec, specFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// flushTo writes the stored OCI specification to the specified io.Writer.
|
// flushTo writes the stored OCI specification to the specified io.Writer.
|
||||||
func (s fileSpec) flushTo(writer io.Writer) error {
|
func flushTo(spec *specs.Spec, writer io.Writer) error {
|
||||||
if s.Spec == nil {
|
if spec == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
encoder := json.NewEncoder(writer)
|
encoder := json.NewEncoder(writer)
|
||||||
|
|
||||||
err := encoder.Encode(s.Spec)
|
err := encoder.Encode(spec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error writing OCI specification: %v", err)
|
return fmt.Errorf("error writing OCI specification: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupEnv mirrors os.LookupEnv for the OCI specification. It
|
|
||||||
// retrieves the value of the environment variable named
|
|
||||||
// by the key. If the variable is present in the environment the
|
|
||||||
// value (which may be empty) is returned and the boolean is true.
|
|
||||||
// Otherwise the returned value will be empty and the boolean will
|
|
||||||
// be false.
|
|
||||||
func (s fileSpec) LookupEnv(key string) (string, bool) {
|
|
||||||
if s.Spec == nil || s.Spec.Process == nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, env := range s.Spec.Process.Env {
|
|
||||||
if !strings.HasPrefix(env, key) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.SplitN(env, "=", 2)
|
|
||||||
if parts[0] == key {
|
|
||||||
if len(parts) < 2 {
|
|
||||||
return "", true
|
|
||||||
}
|
|
||||||
return parts[1], true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
@ -25,111 +25,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLookupEnv(t *testing.T) {
|
|
||||||
const envName = "TEST_ENV"
|
|
||||||
testCases := []struct {
|
|
||||||
spec *specs.Spec
|
|
||||||
expectedValue string
|
|
||||||
expectedExits bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
// nil spec
|
|
||||||
spec: nil,
|
|
||||||
expectedValue: "",
|
|
||||||
expectedExits: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// nil process
|
|
||||||
spec: &specs.Spec{},
|
|
||||||
expectedValue: "",
|
|
||||||
expectedExits: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// nil env
|
|
||||||
spec: &specs.Spec{
|
|
||||||
Process: &specs.Process{},
|
|
||||||
},
|
|
||||||
expectedValue: "",
|
|
||||||
expectedExits: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// empty env
|
|
||||||
spec: &specs.Spec{
|
|
||||||
Process: &specs.Process{Env: []string{}},
|
|
||||||
},
|
|
||||||
expectedValue: "",
|
|
||||||
expectedExits: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// different env set
|
|
||||||
spec: &specs.Spec{
|
|
||||||
Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}},
|
|
||||||
},
|
|
||||||
expectedValue: "",
|
|
||||||
expectedExits: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// same prefix
|
|
||||||
spec: &specs.Spec{
|
|
||||||
Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}},
|
|
||||||
},
|
|
||||||
expectedValue: "",
|
|
||||||
expectedExits: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// same suffix
|
|
||||||
spec: &specs.Spec{
|
|
||||||
Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}},
|
|
||||||
},
|
|
||||||
expectedValue: "",
|
|
||||||
expectedExits: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// set blank
|
|
||||||
spec: &specs.Spec{
|
|
||||||
Process: &specs.Process{Env: []string{"TEST_ENV="}},
|
|
||||||
},
|
|
||||||
expectedValue: "",
|
|
||||||
expectedExits: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// set no-equals
|
|
||||||
spec: &specs.Spec{
|
|
||||||
Process: &specs.Process{Env: []string{"TEST_ENV"}},
|
|
||||||
},
|
|
||||||
expectedValue: "",
|
|
||||||
expectedExits: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// set value
|
|
||||||
spec: &specs.Spec{
|
|
||||||
Process: &specs.Process{Env: []string{"TEST_ENV=something"}},
|
|
||||||
},
|
|
||||||
expectedValue: "something",
|
|
||||||
expectedExits: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// set with equals
|
|
||||||
spec: &specs.Spec{
|
|
||||||
Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}},
|
|
||||||
},
|
|
||||||
expectedValue: "something=somethingelse",
|
|
||||||
expectedExits: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tc := range testCases {
|
|
||||||
spec := fileSpec{
|
|
||||||
Spec: tc.spec,
|
|
||||||
}
|
|
||||||
|
|
||||||
value, exists := spec.LookupEnv(envName)
|
|
||||||
|
|
||||||
require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc)
|
|
||||||
require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadFrom(t *testing.T) {
|
func TestLoadFrom(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
contents []byte
|
contents []byte
|
||||||
@ -148,8 +43,8 @@ func TestLoadFrom(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
spec := fileSpec{}
|
var spec *specs.Spec
|
||||||
err := spec.loadFrom(bytes.NewReader(tc.contents))
|
spec, err := loadFrom(bytes.NewReader(tc.contents))
|
||||||
|
|
||||||
if tc.isError {
|
if tc.isError {
|
||||||
require.Error(t, err, "%d: %v", i, tc)
|
require.Error(t, err, "%d: %v", i, tc)
|
||||||
@ -158,9 +53,9 @@ func TestLoadFrom(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tc.spec == nil {
|
if tc.spec == nil {
|
||||||
require.Nil(t, spec.Spec, "%d: %v", i, tc)
|
require.Nil(t, spec, "%d: %v", i, tc)
|
||||||
} else {
|
} else {
|
||||||
require.EqualValues(t, tc.spec, spec.Spec, "%d: %v", i, tc)
|
require.EqualValues(t, tc.spec, spec, "%d: %v", i, tc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -183,8 +78,7 @@ func TestFlushTo(t *testing.T) {
|
|||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
buffer := bytes.Buffer{}
|
buffer := bytes.Buffer{}
|
||||||
|
|
||||||
spec := fileSpec{Spec: tc.spec}
|
err := flushTo(tc.spec, &buffer)
|
||||||
err := spec.flushTo(&buffer)
|
|
||||||
|
|
||||||
if tc.isError {
|
if tc.isError {
|
||||||
require.Error(t, err, "%d: %v", i, tc)
|
require.Error(t, err, "%d: %v", i, tc)
|
||||||
@ -196,53 +90,10 @@ func TestFlushTo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add a simple test for a writer that returns an error when writing
|
// Add a simple test for a writer that returns an error when writing
|
||||||
spec := fileSpec{Spec: &specs.Spec{}}
|
err := flushTo(&specs.Spec{}, errorWriter{})
|
||||||
err := spec.flushTo(errorWriter{})
|
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModify(t *testing.T) {
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
spec *specs.Spec
|
|
||||||
modifierError error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
spec: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
spec: &specs.Spec{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
spec: &specs.Spec{},
|
|
||||||
modifierError: fmt.Errorf("error in modifier"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tc := range testCases {
|
|
||||||
spec := fileSpec{Spec: tc.spec}
|
|
||||||
|
|
||||||
modifier := func(spec *specs.Spec) error {
|
|
||||||
if tc.modifierError == nil {
|
|
||||||
spec.Version = "updated"
|
|
||||||
}
|
|
||||||
return tc.modifierError
|
|
||||||
}
|
|
||||||
|
|
||||||
err := spec.Modify(modifier)
|
|
||||||
|
|
||||||
if tc.spec == nil {
|
|
||||||
require.Error(t, err, "%d: %v", i, tc)
|
|
||||||
} else if tc.modifierError != nil {
|
|
||||||
require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
|
|
||||||
require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err, "%d: %v", i, tc)
|
|
||||||
require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// errorWriter implements the io.Writer interface, always returning an error when
|
// errorWriter implements the io.Writer interface, always returning an error when
|
||||||
// writing.
|
// writing.
|
||||||
type errorWriter struct{}
|
type errorWriter struct{}
|
||||||
|
83
internal/oci/spec_memory.go
Normal file
83
internal/oci/spec_memory.go
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
/**
|
||||||
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
**/
|
||||||
|
|
||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/opencontainers/runtime-spec/specs-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
type memorySpec struct {
|
||||||
|
*specs.Spec
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemorySpec creates a Spec instance from the specified OCI spec
|
||||||
|
func NewMemorySpec(spec *specs.Spec) Spec {
|
||||||
|
s := memorySpec{
|
||||||
|
Spec: spec,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load is a no-op for the memorySpec spec
|
||||||
|
func (s *memorySpec) Load() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush is a no-op for the memorySpec spec
|
||||||
|
func (s *memorySpec) Flush() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify applies the specified SpecModifier to the stored OCI specification.
|
||||||
|
func (s *memorySpec) Modify(m SpecModifier) error {
|
||||||
|
if s.Spec == nil {
|
||||||
|
return fmt.Errorf("cannot modify nil spec")
|
||||||
|
}
|
||||||
|
return m.Modify(s.Spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupEnv mirrors os.LookupEnv for the OCI specification. It
|
||||||
|
// retrieves the value of the environment variable named
|
||||||
|
// by the key. If the variable is present in the environment the
|
||||||
|
// value (which may be empty) is returned and the boolean is true.
|
||||||
|
// Otherwise the returned value will be empty and the boolean will
|
||||||
|
// be false.
|
||||||
|
func (s memorySpec) LookupEnv(key string) (string, bool) {
|
||||||
|
if s.Spec == nil || s.Spec.Process == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, env := range s.Spec.Process.Env {
|
||||||
|
if !strings.HasPrefix(env, key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(env, "=", 2)
|
||||||
|
if parts[0] == key {
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "", true
|
||||||
|
}
|
||||||
|
return parts[1], true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
178
internal/oci/spec_memory_test.go
Normal file
178
internal/oci/spec_memory_test.go
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
/**
|
||||||
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
**/
|
||||||
|
|
||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/opencontainers/runtime-spec/specs-go"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLookupEnv(t *testing.T) {
|
||||||
|
const envName = "TEST_ENV"
|
||||||
|
testCases := []struct {
|
||||||
|
spec *specs.Spec
|
||||||
|
expectedValue string
|
||||||
|
expectedExits bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// nil spec
|
||||||
|
spec: nil,
|
||||||
|
expectedValue: "",
|
||||||
|
expectedExits: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// nil process
|
||||||
|
spec: &specs.Spec{},
|
||||||
|
expectedValue: "",
|
||||||
|
expectedExits: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// nil env
|
||||||
|
spec: &specs.Spec{
|
||||||
|
Process: &specs.Process{},
|
||||||
|
},
|
||||||
|
expectedValue: "",
|
||||||
|
expectedExits: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// empty env
|
||||||
|
spec: &specs.Spec{
|
||||||
|
Process: &specs.Process{Env: []string{}},
|
||||||
|
},
|
||||||
|
expectedValue: "",
|
||||||
|
expectedExits: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// different env set
|
||||||
|
spec: &specs.Spec{
|
||||||
|
Process: &specs.Process{Env: []string{"SOMETHING_ELSE=foo"}},
|
||||||
|
},
|
||||||
|
expectedValue: "",
|
||||||
|
expectedExits: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// same prefix
|
||||||
|
spec: &specs.Spec{
|
||||||
|
Process: &specs.Process{Env: []string{"TEST_ENV_BUT_NOT=foo"}},
|
||||||
|
},
|
||||||
|
expectedValue: "",
|
||||||
|
expectedExits: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// same suffix
|
||||||
|
spec: &specs.Spec{
|
||||||
|
Process: &specs.Process{Env: []string{"NOT_TEST_ENV=foo"}},
|
||||||
|
},
|
||||||
|
expectedValue: "",
|
||||||
|
expectedExits: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// set blank
|
||||||
|
spec: &specs.Spec{
|
||||||
|
Process: &specs.Process{Env: []string{"TEST_ENV="}},
|
||||||
|
},
|
||||||
|
expectedValue: "",
|
||||||
|
expectedExits: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// set no-equals
|
||||||
|
spec: &specs.Spec{
|
||||||
|
Process: &specs.Process{Env: []string{"TEST_ENV"}},
|
||||||
|
},
|
||||||
|
expectedValue: "",
|
||||||
|
expectedExits: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// set value
|
||||||
|
spec: &specs.Spec{
|
||||||
|
Process: &specs.Process{Env: []string{"TEST_ENV=something"}},
|
||||||
|
},
|
||||||
|
expectedValue: "something",
|
||||||
|
expectedExits: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// set with equals
|
||||||
|
spec: &specs.Spec{
|
||||||
|
Process: &specs.Process{Env: []string{"TEST_ENV=something=somethingelse"}},
|
||||||
|
},
|
||||||
|
expectedValue: "something=somethingelse",
|
||||||
|
expectedExits: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range testCases {
|
||||||
|
spec := memorySpec{
|
||||||
|
Spec: tc.spec,
|
||||||
|
}
|
||||||
|
|
||||||
|
value, exists := spec.LookupEnv(envName)
|
||||||
|
|
||||||
|
require.Equal(t, tc.expectedValue, value, "%d: %v", i, tc)
|
||||||
|
require.Equal(t, tc.expectedExits, exists, "%d: %v", i, tc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModify(t *testing.T) {
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
spec *specs.Spec
|
||||||
|
modifierError error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
spec: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
spec: &specs.Spec{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
spec: &specs.Spec{},
|
||||||
|
modifierError: fmt.Errorf("error in modifier"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range testCases {
|
||||||
|
spec := NewMemorySpec(tc.spec).(*memorySpec)
|
||||||
|
|
||||||
|
err := spec.Modify(modifier{tc.modifierError})
|
||||||
|
|
||||||
|
if tc.spec == nil {
|
||||||
|
require.Error(t, err, "%d: %v", i, tc)
|
||||||
|
} else if tc.modifierError != nil {
|
||||||
|
require.EqualError(t, err, tc.modifierError.Error(), "%d: %v", i, tc)
|
||||||
|
require.EqualValues(t, &specs.Spec{}, spec.Spec, "%d: %v", i, tc)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err, "%d: %v", i, tc)
|
||||||
|
require.Equal(t, "updated", spec.Spec.Version, "%d: %v", i, tc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Ideally we would generated a mock for the SpecModifer too. This causes
|
||||||
|
// an import cycle and we define a local type as a work-around.
|
||||||
|
type modifier struct {
|
||||||
|
modifierError error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m modifier) Modify(spec *specs.Spec) error {
|
||||||
|
if m.modifierError == nil {
|
||||||
|
spec.Version = "updated"
|
||||||
|
}
|
||||||
|
return m.modifierError
|
||||||
|
}
|
@ -20,7 +20,7 @@ func TestMaintainSpec(t *testing.T) {
|
|||||||
for _, f := range files {
|
for _, f := range files {
|
||||||
inputSpecPath := filepath.Join(moduleRoot, "test/input", f)
|
inputSpecPath := filepath.Join(moduleRoot, "test/input", f)
|
||||||
|
|
||||||
spec := NewSpecFromFile(inputSpecPath).(*fileSpec)
|
spec := NewFileSpec(inputSpecPath).(*fileSpec)
|
||||||
|
|
||||||
spec.Load()
|
spec.Load()
|
||||||
|
|
||||||
|
81
internal/runtime/runtime_modifier.go
Normal file
81
internal/runtime/runtime_modifier.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
/*
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package runtime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type modifyingRuntimeWrapper struct {
|
||||||
|
logger *log.Logger
|
||||||
|
runtime oci.Runtime
|
||||||
|
ociSpec oci.Spec
|
||||||
|
modifier oci.SpecModifier
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ oci.Runtime = (*modifyingRuntimeWrapper)(nil)
|
||||||
|
|
||||||
|
// NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
|
||||||
|
// before invoking the wrapped runtime.
|
||||||
|
func NewModifyingRuntimeWrapper(logger *log.Logger, runtime oci.Runtime, spec oci.Spec, modifier oci.SpecModifier) oci.Runtime {
|
||||||
|
rt := modifyingRuntimeWrapper{
|
||||||
|
logger: logger,
|
||||||
|
runtime: runtime,
|
||||||
|
ociSpec: spec,
|
||||||
|
modifier: modifier,
|
||||||
|
}
|
||||||
|
return &rt
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec checks whether a modification of the OCI specification is required and modifies it accordingly before exec-ing
|
||||||
|
// into the wrapped runtime.
|
||||||
|
func (r *modifyingRuntimeWrapper) Exec(args []string) error {
|
||||||
|
if oci.HasCreateSubcommand(args) {
|
||||||
|
err := r.modify()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not apply required modification to OCI specification: %v", err)
|
||||||
|
}
|
||||||
|
r.logger.Infof("Applied required modification to OCI specification")
|
||||||
|
} else {
|
||||||
|
r.logger.Infof("No modification of OCI specification required")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Infof("Forwarding command to runtime")
|
||||||
|
return r.runtime.Exec(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// modify loads, modifies, and flushes the OCI specification using the defined Modifier
|
||||||
|
func (r *modifyingRuntimeWrapper) modify() error {
|
||||||
|
err := r.ociSpec.Load()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error loading OCI specification for modification: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.ociSpec.Modify(r.modifier)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error modifying OCI spec: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.ociSpec.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error writing modified OCI specification: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
116
internal/runtime/runtime_modifier_test.go
Normal file
116
internal/runtime/runtime_modifier_test.go
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
/*
|
||||||
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package runtime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||||
|
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExec(t *testing.T) {
|
||||||
|
logger, hook := testlog.NewNullLogger()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
description string
|
||||||
|
shouldModify bool
|
||||||
|
shouldFlush bool
|
||||||
|
shouldForward bool
|
||||||
|
args []string
|
||||||
|
modifyError error
|
||||||
|
writeError error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
description: "no args forwards",
|
||||||
|
shouldModify: false,
|
||||||
|
shouldFlush: false,
|
||||||
|
shouldForward: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "create modifies",
|
||||||
|
args: []string{"create"},
|
||||||
|
shouldModify: true,
|
||||||
|
shouldFlush: true,
|
||||||
|
shouldForward: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "modify error does not write or forward",
|
||||||
|
args: []string{"create"},
|
||||||
|
modifyError: fmt.Errorf("error modifying"),
|
||||||
|
shouldModify: true,
|
||||||
|
shouldFlush: false,
|
||||||
|
shouldForward: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "write error does not forward",
|
||||||
|
args: []string{"create"},
|
||||||
|
writeError: fmt.Errorf("error writing"),
|
||||||
|
shouldModify: true,
|
||||||
|
shouldFlush: true,
|
||||||
|
shouldForward: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
hook.Reset()
|
||||||
|
|
||||||
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
|
runtimeMock := &oci.RuntimeMock{}
|
||||||
|
specMock := &oci.SpecMock{
|
||||||
|
ModifyFunc: func(specModifier oci.SpecModifier) error {
|
||||||
|
return tc.modifyError
|
||||||
|
},
|
||||||
|
FlushFunc: func() error {
|
||||||
|
return tc.writeError
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
shim := NewModifyingRuntimeWrapper(
|
||||||
|
logger,
|
||||||
|
runtimeMock,
|
||||||
|
specMock,
|
||||||
|
// TODO: We should test the interactions with the SpecModifier too
|
||||||
|
nil)
|
||||||
|
|
||||||
|
err := shim.Exec(tc.args)
|
||||||
|
if tc.modifyError != nil || tc.writeError != nil {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.shouldModify {
|
||||||
|
require.Equal(t, 1, len(specMock.ModifyCalls()))
|
||||||
|
} else {
|
||||||
|
require.Equal(t, 0, len(specMock.ModifyCalls()))
|
||||||
|
}
|
||||||
|
if tc.shouldFlush {
|
||||||
|
require.Equal(t, 1, len(specMock.FlushCalls()))
|
||||||
|
} else {
|
||||||
|
require.Equal(t, 0, len(specMock.FlushCalls()))
|
||||||
|
}
|
||||||
|
if tc.shouldForward {
|
||||||
|
require.Equal(t, 1, len(runtimeMock.ExecCalls()))
|
||||||
|
} else {
|
||||||
|
require.Equal(t, 0, len(runtimeMock.ExecCalls()))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user