Move modifier code for inserting nvidia-container-runtime-hook to separate package

This change moves the code defining the insertion of the nvidia-container-runtime
hook to a separate package. This allows for better distinction between the existing
and experimental modifications.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2022-02-22 16:09:45 +02:00
parent 4177fddcc4
commit c6dfc1027d
7 changed files with 233 additions and 231 deletions

View File

@ -11,8 +11,6 @@ import (
const (
configOverride = "XDG_CONFIG_HOME"
configFilePath = "nvidia-container-runtime/config.toml"
hookDefaultFilePath = "/usr/bin/nvidia-container-runtime-hook"
)
var (
@ -49,7 +47,7 @@ func run(argv []string) (rerr error) {
logger.CloseFile()
}()
r, err := newRuntime(argv)
r, err := newNVIDIAContainerRuntime(logger.Logger, cfg, argv)
if err != nil {
return fmt.Errorf("error creating runtime: %v", err)
}

View File

@ -10,6 +10,7 @@ import (
"strings"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-container-runtime/modifier"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require"
@ -170,7 +171,7 @@ func TestDuplicateHook(t *testing.T) {
// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for
// testing.
func addNVIDIAHook(spec *specs.Spec) error {
m := addHookModifier{logger: logger.Logger}
m := modifier.NewStableRuntimeModifier(logger.Logger)
return m.Modify(spec)
}

View File

@ -14,7 +14,7 @@
# limitations under the License.
*/
package main
package modifier
import (
"os"
@ -22,36 +22,31 @@ import (
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/runtime"
"github.com/opencontainers/runtime-spec/specs-go"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
)
// newNvidiaContainerRuntime is a constructor for a standard runtime shim. This uses
// a ModifyingRuntimeWrapper to apply the required modifications before execing to the
// specified low-level runtime
func newNvidiaContainerRuntime(logger *log.Logger, lowlevelRuntime oci.Runtime, ociSpec oci.Spec) (oci.Runtime, error) {
modifier := addHookModifier{logger: logger}
const (
hookDefaultFilePath = "/usr/bin/nvidia-container-runtime-hook"
)
r := runtime.NewModifyingRuntimeWrapper(
logger,
lowlevelRuntime,
ociSpec,
modifier,
)
// NewStableRuntimeModifier creates an OCI spec modifier that inserts the NVIDIA Container Runtime Hook into an OCI
// spec. The specified logger is used to capture log output.
func NewStableRuntimeModifier(logger *logrus.Logger) oci.SpecModifier {
m := stableRuntimeModifier{logger: logger}
return r, nil
return &m
}
// addHookModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a
// stableRuntimeModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a
// prestart hook. If the hook is already present, no modification is made.
type addHookModifier struct {
logger *log.Logger
type stableRuntimeModifier struct {
logger *logrus.Logger
}
// Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook
// as a prestart hook.
func (m addHookModifier) Modify(spec *specs.Spec) error {
func (m stableRuntimeModifier) Modify(spec *specs.Spec) error {
path, err := exec.LookPath("nvidia-container-runtime-hook")
if err != nil {
path = hookDefaultFilePath
@ -61,18 +56,17 @@ func (m addHookModifier) Modify(spec *specs.Spec) error {
}
}
m.logger.Printf("prestart hook path: %s\n", path)
m.logger.Infof("Using prestart hook path: %s", path)
args := []string{path}
if spec.Hooks == nil {
spec.Hooks = &specs.Hooks{}
} else if len(spec.Hooks.Prestart) != 0 {
for _, hook := range spec.Hooks.Prestart {
if !strings.Contains(hook.Path, "nvidia-container-runtime-hook") {
continue
if strings.Contains(hook.Path, "nvidia-container-runtime-hook") {
m.logger.Infof("existing nvidia prestart hook found in OCI spec")
return nil
}
m.logger.Println("existing nvidia prestart hook in OCI spec file")
return nil
}
}

View File

@ -0,0 +1,170 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package modifier
import (
"os"
"path/filepath"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
type testConfig struct {
root string
binPath string
}
var cfg *testConfig
func TestMain(m *testing.M) {
// TEST SETUP
// Determine the module root and the test binary path
var err error
moduleRoot, err := test.GetModuleRoot()
if err != nil {
logrus.Fatalf("error in test setup: could not get module root: %v", err)
}
testBinPath := filepath.Join(moduleRoot, "test", "bin")
// Set the environment variables for the test
os.Setenv("PATH", test.PrependToPath(testBinPath, moduleRoot))
// Store the root and binary paths in the test Config
cfg = &testConfig{
root: moduleRoot,
binPath: testBinPath,
}
// RUN TESTS
exitCode := m.Run()
os.Exit(exitCode)
}
func TestAddHookModifier(t *testing.T) {
logger, logHook := testlog.NewNullLogger()
testHookPath := filepath.Join(cfg.binPath, "nvidia-container-runtime-hook")
testCases := []struct {
description string
spec specs.Spec
expectedError error
expectedSpec specs.Spec
}{
{
description: "empty spec adds hook",
spec: specs.Spec{},
expectedSpec: specs.Spec{
Hooks: &specs.Hooks{
Prestart: []specs.Hook{
{
Path: testHookPath,
Args: []string{testHookPath, "prestart"},
},
},
},
},
},
{
description: "spec with empty hooks adds hook",
spec: specs.Spec{
Hooks: &specs.Hooks{},
},
expectedSpec: specs.Spec{
Hooks: &specs.Hooks{
Prestart: []specs.Hook{
{
Path: testHookPath,
Args: []string{testHookPath, "prestart"},
},
},
},
},
},
{
description: "hook is not replaced",
spec: specs.Spec{
Hooks: &specs.Hooks{
Prestart: []specs.Hook{
{
Path: "nvidia-container-runtime-hook",
},
},
},
},
expectedSpec: specs.Spec{
Hooks: &specs.Hooks{
Prestart: []specs.Hook{
{
Path: "nvidia-container-runtime-hook",
},
},
},
},
},
{
description: "other hooks are not replaced",
spec: specs.Spec{
Hooks: &specs.Hooks{
Prestart: []specs.Hook{
{
Path: "some-hook",
},
},
},
},
expectedSpec: specs.Spec{
Hooks: &specs.Hooks{
Prestart: []specs.Hook{
{
Path: "some-hook",
},
{
Path: testHookPath,
Args: []string{testHookPath, "prestart"},
},
},
},
},
},
}
for _, tc := range testCases {
logHook.Reset()
t.Run(tc.description, func(t *testing.T) {
m := NewStableRuntimeModifier(logger)
err := m.Modify(&tc.spec)
if tc.expectedError != nil {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.EqualValues(t, tc.expectedSpec, tc.spec)
})
}
}

View File

@ -1,190 +0,0 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package main
import (
"fmt"
"strings"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/opencontainers/runtime-spec/specs-go"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestAddNvidiaHook(t *testing.T) {
logger, logHook := testlog.NewNullLogger()
mockRuntime := &oci.RuntimeMock{}
testCases := []struct {
spec *specs.Spec
errorPrefix string
shouldNotAdd bool
}{
{
spec: &specs.Spec{},
},
{
spec: &specs.Spec{
Hooks: &specs.Hooks{},
},
},
{
spec: &specs.Spec{
Hooks: &specs.Hooks{
Prestart: []specs.Hook{{
Path: "some-hook",
}},
},
},
},
{
spec: &specs.Spec{
Hooks: &specs.Hooks{
Prestart: []specs.Hook{{
Path: "nvidia-container-runtime-hook",
}},
},
},
shouldNotAdd: true,
},
}
for i, tc := range testCases {
logHook.Reset()
var numPrestartHooks int
if tc.spec.Hooks != nil {
numPrestartHooks = len(tc.spec.Hooks.Prestart)
}
shim, err := newNvidiaContainerRuntime(
logger,
mockRuntime,
oci.NewMemorySpec(tc.spec),
)
require.NoError(t, err)
err = shim.Exec([]string{"runtime", "create"})
require.NoError(t, err)
if tc.errorPrefix == "" {
require.NoErrorf(t, err, "%d: %v", i, tc)
} else {
require.Truef(t, strings.HasPrefix(err.Error(), tc.errorPrefix), "%d: %v", i, tc)
require.NotNilf(t, tc.spec.Hooks, "%d: %v", i, tc)
require.Equalf(t, 1, nvidiaHookCount(tc.spec.Hooks), "%d: %v", i, tc)
if tc.shouldNotAdd {
require.Equal(t, numPrestartHooks+1, len(tc.spec.Hooks.Poststart), "%d: %v", i, tc)
} else {
require.Equal(t, numPrestartHooks+1, len(tc.spec.Hooks.Poststart), "%d: %v", i, tc)
nvidiaHook := tc.spec.Hooks.Poststart[len(tc.spec.Hooks.Poststart)-1]
// TODO: This assumes that the hook has been set up in the makefile
expectedPath := "/usr/bin/nvidia-container-runtime-hook"
require.Equalf(t, expectedPath, nvidiaHook.Path, "%d: %v", i, tc)
require.Equalf(t, []string{expectedPath, "prestart"}, nvidiaHook.Args, "%d: %v", i, tc)
require.Emptyf(t, nvidiaHook.Env, "%d: %v", i, tc)
require.Nilf(t, nvidiaHook.Timeout, "%d: %v", i, tc)
}
}
}
}
func TestNvidiaContainerRuntime(t *testing.T) {
logger, hook := testlog.NewNullLogger()
mockRuntime := &oci.RuntimeMock{}
testCases := []struct {
shouldModify bool
args []string
modifyError error
writeError error
}{
{
shouldModify: false,
},
{
args: []string{"create"},
shouldModify: true,
},
{
args: []string{"--bundle=create"},
shouldModify: false,
},
{
args: []string{"--bundle", "create"},
shouldModify: false,
},
{
args: []string{"create"},
shouldModify: true,
},
{
args: []string{"create"},
modifyError: fmt.Errorf("error modifying"),
shouldModify: true,
},
{
args: []string{"create"},
writeError: fmt.Errorf("error writing"),
shouldModify: true,
},
}
for i, tc := range testCases {
hook.Reset()
specMock := &oci.SpecMock{
ModifyFunc: func(specModifier oci.SpecModifier) error {
return tc.modifyError
},
FlushFunc: func() error {
return tc.writeError
},
}
shim, err := newNvidiaContainerRuntime(logger, mockRuntime, specMock)
require.NoError(t, err)
err = shim.Exec(tc.args)
if tc.modifyError != nil || tc.writeError != nil {
require.Error(t, err, "%d: %v", i, tc)
} else {
require.NoError(t, err, "%d: %v", i, tc)
}
if tc.shouldModify {
require.Equal(t, 1, len(specMock.ModifyCalls()), "%d: %v", i, tc)
} else {
require.Equal(t, 0, len(specMock.ModifyCalls()), "%d: %v", i, tc)
}
writeExpected := tc.shouldModify && tc.modifyError == nil
if writeExpected {
require.Equal(t, 1, len(specMock.FlushCalls()), "%d: %v", i, tc)
} else {
require.Equal(t, 0, len(specMock.FlushCalls()), "%d: %v", i, tc)
}
}
}

View File

@ -1,5 +1,5 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,7 +19,10 @@ package main
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-container-runtime/modifier"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/runtime"
"github.com/sirupsen/logrus"
)
const (
@ -27,23 +30,28 @@ const (
runcExecutableName = "runc"
)
// newRuntime is a factory method that constructs a runtime based on the selected configuration.
func newRuntime(argv []string) (oci.Runtime, error) {
ociSpec, err := oci.NewSpec(logger.Logger, argv)
// newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger
func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config, argv []string) (oci.Runtime, error) {
ociSpec, err := oci.NewSpec(logger, argv)
if err != nil {
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
}
lowLevelRuntimeCandidates := []string{dockerRuncExecutableName, runcExecutableName}
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger.Logger, lowLevelRuntimeCandidates)
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, lowLevelRuntimeCandidates)
if err != nil {
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
}
r, err := newNvidiaContainerRuntime(logger.Logger, lowLevelRuntime, ociSpec)
if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA Container Runtime: %v", err)
}
specModifier := modifier.NewStableRuntimeModifier(logger)
// Create the wrapping runtime with the specified modifier
r := runtime.NewModifyingRuntimeWrapper(
logger,
lowLevelRuntime,
ociSpec,
specModifier,
)
return r, nil
}

View File

@ -19,12 +19,33 @@ package main
import (
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestConstructor(t *testing.T) {
shim, err := newRuntime([]string{})
func TestFactoryMethod(t *testing.T) {
logger, _ := testlog.NewNullLogger()
require.NoError(t, err)
require.NotNil(t, shim)
testCases := []struct {
description string
config config
argv []string
expectedError bool
}{
{
description: "empty config no error",
config: config{},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
_, err := newNVIDIAContainerRuntime(logger, &tc.config, tc.argv)
if tc.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}