Import modifying runtime abstraction from experimental runtime

This change imports the modifying runtime abstraction from the
experimental branch. This encapsulates the checks for whether
modification is required, and forwards the loaded spec to
the specified modifier. This allows for the same code to be
reused when performing more complex modifications.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar
2022-01-31 15:11:36 +01:00
parent bf8c3bab72
commit 4177fddcc4
15 changed files with 584 additions and 388 deletions

View File

@@ -167,11 +167,11 @@ func TestDuplicateHook(t *testing.T) {
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
}
// addNVIDIAHook is a basic wrapper for nvidiaContainerRunime.addNVIDIAHook that is used for
// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for
// testing.
func addNVIDIAHook(spec *specs.Spec) error {
r := nvidiaContainerRuntime{logger: logger.Logger}
return r.addNVIDIAHook(spec)
m := addHookModifier{logger: logger.Logger}
return m.Modify(spec)
}
func (c testConfig) getRuntimeSpec() (specs.Spec, error) {

View File

@@ -17,88 +17,41 @@
package main
import (
"fmt"
"os"
"os/exec"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/runtime"
"github.com/opencontainers/runtime-spec/specs-go"
log "github.com/sirupsen/logrus"
)
// nvidiaContainerRuntime encapsulates the NVIDIA Container Runtime. It wraps the specified runtime, conditionally
// modifying the specified OCI specification before invoking the runtime.
type nvidiaContainerRuntime struct {
logger *log.Logger
runtime oci.Runtime
ociSpec oci.Spec
// newNvidiaContainerRuntime is a constructor for a standard runtime shim. This uses
// a ModifyingRuntimeWrapper to apply the required modifications before execing to the
// specified low-level runtime
func newNvidiaContainerRuntime(logger *log.Logger, lowlevelRuntime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
modifier := addHookModifier{logger: logger}
r := runtime.NewModifyingRuntimeWrapper(
logger,
lowlevelRuntime,
ociSpec,
modifier,
)
return r, nil
}
var _ oci.Runtime = (*nvidiaContainerRuntime)(nil)
// newNvidiaContainerRuntime is a constructor for a standard runtime shim.
func newNvidiaContainerRuntimeWithLogger(logger *log.Logger, runtime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
r := nvidiaContainerRuntime{
logger: logger,
runtime: runtime,
ociSpec: ociSpec,
}
return &r, nil
// addHookModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a
// prestart hook. If the hook is already present, no modification is made.
type addHookModifier struct {
logger *log.Logger
}
// Exec defines the entrypoint for the NVIDIA Container Runtime. A check is performed to see whether modifications
// to the OCI spec are required -- and applicable modifcations applied. The supplied arguments are then
// forwarded to the underlying runtime's Exec method.
func (r nvidiaContainerRuntime) Exec(args []string) error {
if r.modificationRequired(args) {
err := r.modifyOCISpec()
if err != nil {
return fmt.Errorf("error modifying OCI spec: %v", err)
}
}
r.logger.Println("Forwarding command to runtime")
return r.runtime.Exec(args)
}
// modificationRequired checks the intput arguments to determine whether a modification
// to the OCI spec is required.
func (r nvidiaContainerRuntime) modificationRequired(args []string) bool {
if oci.HasCreateSubcommand(args) {
r.logger.Infof("'create' command detected; modification required")
return true
}
r.logger.Infof("No modification required")
return false
}
// modifyOCISpec loads and modifies the OCI spec specified in the nvidiaContainerRuntime
// struct. The spec is modified in-place and written to the same file as the input after
// modifcationas are applied.
func (r nvidiaContainerRuntime) modifyOCISpec() error {
err := r.ociSpec.Load()
if err != nil {
return fmt.Errorf("error loading OCI specification for modification: %v", err)
}
err = r.ociSpec.Modify(r.addNVIDIAHook)
if err != nil {
return fmt.Errorf("error injecting NVIDIA Container Runtime hook: %v", err)
}
err = r.ociSpec.Flush()
if err != nil {
return fmt.Errorf("error writing modified OCI specification: %v", err)
}
return nil
}
// addNVIDIAHook modifies the specified OCI specification in-place, inserting a
// prestart hook.
func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook
// as a prestart hook.
func (m addHookModifier) Modify(spec *specs.Spec) error {
path, err := exec.LookPath("nvidia-container-runtime-hook")
if err != nil {
path = hookDefaultFilePath
@@ -108,7 +61,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
}
}
r.logger.Printf("prestart hook path: %s\n", path)
m.logger.Printf("prestart hook path: %s\n", path)
args := []string{path}
if spec.Hooks == nil {
@@ -118,7 +71,7 @@ func (r nvidiaContainerRuntime) addNVIDIAHook(spec *specs.Spec) error {
if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") {
continue
}
r.logger.Println("existing nvidia prestart hook in OCI spec file")
m.logger.Println("existing nvidia prestart hook in OCI spec file")
return nil
}
}

View File

@@ -29,9 +29,8 @@ import (
func TestAddNvidiaHook(t *testing.T) {
logger, logHook := testlog.NewNullLogger()
shim := nvidiaContainerRuntime{
logger: logger,
}
mockRuntime := &oci.RuntimeMock{}
testCases := []struct {
spec *specs.Spec
@@ -75,7 +74,16 @@ func TestAddNvidiaHook(t *testing.T) {
numPrestartHooks = len(tc.spec.Hooks.Prestart)
}
err := shim.addNVIDIAHook(tc.spec)
shim, err := newNvidiaContainerRuntime(
logger,
mockRuntime,
oci.NewMemorySpec(tc.spec),
)
require.NoError(t, err)
err = shim.Exec([]string{"runtime", "create"})
require.NoError(t, err)
if tc.errorPrefix == "" {
require.NoErrorf(t, err, "%d: %v", i, tc)
@@ -106,45 +114,39 @@ func TestAddNvidiaHook(t *testing.T) {
func TestNvidiaContainerRuntime(t *testing.T) {
logger, hook := testlog.NewNullLogger()
mockRuntime := &oci.RuntimeMock{}
testCases := []struct {
shim nvidiaContainerRuntime
shouldModify bool
args []string
modifyError error
writeError error
}{
{
shim: nvidiaContainerRuntime{},
shouldModify: false,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"create"},
shouldModify: true,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"--bundle=create"},
shouldModify: false,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"--bundle", "create"},
shouldModify: false,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"create"},
shouldModify: true,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"create"},
modifyError: fmt.Errorf("error modifying"),
shouldModify: true,
},
{
shim: nvidiaContainerRuntime{},
args: []string{"create"},
writeError: fmt.Errorf("error writing"),
shouldModify: true,
@@ -152,10 +154,8 @@ func TestNvidiaContainerRuntime(t *testing.T) {
}
for i, tc := range testCases {
tc.shim.logger = logger
hook.Reset()
ociMock := &oci.SpecMock{
specMock := &oci.SpecMock{
ModifyFunc: func(specModifier oci.SpecModifier) error {
return tc.modifyError
},
@@ -163,12 +163,11 @@ func TestNvidiaContainerRuntime(t *testing.T) {
return tc.writeError
},
}
require.Equal(t, tc.shouldModify, tc.shim.modificationRequired(tc.args), "%d: %v", i, tc)
tc.shim.ociSpec = ociMock
tc.shim.runtime = &MockShim{}
shim, err := newNvidiaContainerRuntime(logger, mockRuntime, specMock)
require.NoError(t, err)
err := tc.shim.Exec(tc.args)
err = shim.Exec(tc.args)
if tc.modifyError != nil || tc.writeError != nil {
require.Error(t, err, "%d: %v", i, tc)
} else {
@@ -176,28 +175,16 @@ func TestNvidiaContainerRuntime(t *testing.T) {
}
if tc.shouldModify {
require.Equal(t, 1, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
require.Equal(t, 1, len(specMock.ModifyCalls()), "%d: %v", i, tc)
} else {
require.Equal(t, 0, len(ociMock.ModifyCalls()), "%d: %v", i, tc)
require.Equal(t, 0, len(specMock.ModifyCalls()), "%d: %v", i, tc)
}
writeExpected := tc.shouldModify && tc.modifyError == nil
if writeExpected {
require.Equal(t, 1, len(ociMock.FlushCalls()), "%d: %v", i, tc)
require.Equal(t, 1, len(specMock.FlushCalls()), "%d: %v", i, tc)
} else {
require.Equal(t, 0, len(ociMock.FlushCalls()), "%d: %v", i, tc)
require.Equal(t, 0, len(specMock.FlushCalls()), "%d: %v", i, tc)
}
}
}
type MockShim struct {
called bool
args []string
returnError error
}
func (m *MockShim) Exec(args []string) error {
m.called = true
m.args = args
return m.returnError
}

View File

@@ -23,52 +23,27 @@ import (
)
const (
ociSpecFileName = "config.json"
dockerRuncExecutableName = "docker-runc"
runcExecutableName = "runc"
)
// newRuntime is a factory method that constructs a runtime based on the selected configuration.
func newRuntime(argv []string) (oci.Runtime, error) {
ociSpec, err := newOCISpec(argv)
ociSpec, err := oci.NewSpec(logger.Logger, argv)
if err != nil {
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
}
runc, err := newRuncRuntime()
lowLevelRuntimeCandidates := []string{dockerRuncExecutableName, runcExecutableName}
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger.Logger, lowLevelRuntimeCandidates)
if err != nil {
return nil, fmt.Errorf("error constructing runc runtime: %v", err)
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
}
r, err := newNvidiaContainerRuntimeWithLogger(logger.Logger, runc, ociSpec)
r, err := newNvidiaContainerRuntime(logger.Logger, lowLevelRuntime, ociSpec)
if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err)
}
return r, nil
}
// newOCISpec constructs an OCI spec for the provided arguments
func newOCISpec(argv []string) (oci.Spec, error) {
bundleDir, err := oci.GetBundleDir(argv)
if err != nil {
return nil, fmt.Errorf("error parsing command line arguments: %v", err)
}
logger.Infof("Using bundle directory: %v", bundleDir)
ociSpecPath := oci.GetSpecFilePath(bundleDir)
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
ociSpec := oci.NewSpecFromFile(ociSpecPath)
return ociSpec, nil
}
// newRuncRuntime locates the runc binary and wraps it in a SyscallExecRuntime
func newRuncRuntime() (oci.Runtime, error) {
return oci.NewLowLevelRuntimeWithLogger(
logger.Logger,
dockerRuncExecutableName,
runcExecutableName,
)
}