diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index a6129432..d3dc84e7 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -68,20 +68,10 @@ func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image return nil, fmt.Errorf("failed to get CDI spec: %v", err) } - cdiModifier, err := cdi.New( + return cdi.New( cdi.WithLogger(logger), cdi.WithSpec(spec.Raw()), ) - if err != nil { - return nil, fmt.Errorf("failed to construct CDI modifier: %v", err) - } - - modifiers := Merge( - nvidiaContainerRuntimeHookRemover{logger}, - cdiModifier, - ) - - return modifiers, nil } func checkRequirements(logger logger.Interface, image image.CUDA) error { diff --git a/internal/modifier/csv_test.go b/internal/modifier/csv_test.go index 8e5f60b0..cdd41ce7 100644 --- a/internal/modifier/csv_test.go +++ b/internal/modifier/csv_test.go @@ -19,7 +19,6 @@ package modifier import ( "testing" - "github.com/opencontainers/runtime-spec/specs-go" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" @@ -74,66 +73,3 @@ func TestNewCSVModifier(t *testing.T) { }) } } - -func TestCSVModifierRemovesHook(t *testing.T) { - logger, _ := testlog.NewNullLogger() - - testCases := []struct { - description string - spec *specs.Spec - expectedError error - expectedSpec *specs.Spec - }{ - { - description: "modification removes existing nvidia-container-runtime-hook", - spec: &specs.Spec{ - Hooks: &specs.Hooks{ - Prestart: []specs.Hook{ - { - Path: "/path/to/nvidia-container-runtime-hook", - Args: []string{"/path/to/nvidia-container-runtime-hook", "prestart"}, - }, - }, - }, - }, - expectedSpec: &specs.Spec{ - Hooks: &specs.Hooks{ - Prestart: []specs.Hook{}, - }, - }, - }, - { - description: "modification removes existing nvidia-container-toolkit", - spec: &specs.Spec{ - Hooks: &specs.Hooks{ - Prestart: []specs.Hook{ - { - Path: "/path/to/nvidia-container-toolkit", - Args: []string{"/path/to/nvidia-container-toolkit", "prestart"}, - }, - }, - }, - }, - expectedSpec: &specs.Spec{ - Hooks: &specs.Hooks{ - Prestart: []specs.Hook{}, - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - m := nvidiaContainerRuntimeHookRemover{logger: logger} - - err := m.Modify(tc.spec) - if tc.expectedError != nil { - require.Error(t, err) - } else { - require.NoError(t, err) - } - - require.Empty(t, tc.spec.Hooks.Prestart) - }) - } -} diff --git a/internal/modifier/hook_remover.go b/internal/modifier/hook_remover.go index 059cd8af..24cf7666 100644 --- a/internal/modifier/hook_remover.go +++ b/internal/modifier/hook_remover.go @@ -33,6 +33,13 @@ type nvidiaContainerRuntimeHookRemover struct { var _ oci.SpecModifier = (*nvidiaContainerRuntimeHookRemover)(nil) +// NewNvidiaContainerRuntimeHookRemover creates a modifier that removes any NVIDIA Container Runtime hooks from the provided spec. +func NewNvidiaContainerRuntimeHookRemover(logger logger.Interface) oci.SpecModifier { + return nvidiaContainerRuntimeHookRemover{ + logger: logger, + } +} + // Modify removes any NVIDIA Container Runtime hooks from the provided spec func (m nvidiaContainerRuntimeHookRemover) Modify(spec *specs.Spec) error { if spec == nil { diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index a4426459..2b5cd9c6 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -85,6 +85,8 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp switch modifierType { case "mode": modifiers = append(modifiers, modeModifier) + case "nvidia-hook-remover": + modifiers = append(modifiers, modifier.NewNvidiaContainerRuntimeHookRemover(logger)) case "graphics": graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver) if err != nil { @@ -121,10 +123,10 @@ func supportedModifierTypes(mode string) []string { switch mode { case "cdi": // For CDI mode we make no additional modifications. - return []string{"mode"} + return []string{"nvidia-hook-remover", "mode"} case "csv": // For CSV mode we support mode and feature-gated modification. - return []string{"mode", "feature-gated"} + return []string{"nvidia-hook-remover", "mode", "feature-gated"} default: return []string{"mode", "graphics", "feature-gated"} } diff --git a/internal/runtime/runtime_factory_test.go b/internal/runtime/runtime_factory_test.go index e2a976a2..ae95e710 100644 --- a/internal/runtime/runtime_factory_test.go +++ b/internal/runtime/runtime_factory_test.go @@ -30,6 +30,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/test" ) @@ -165,3 +166,181 @@ func TestFactoryMethod(t *testing.T) { }) } } + +func TestNewSpecModifier(t *testing.T) { + logger, _ := testlog.NewNullLogger() + driver := root.New( + root.WithDriverRoot("/nvidia/driver/root"), + ) + testCases := []struct { + description string + config *config.Config + spec *specs.Spec + expectedSpec *specs.Spec + }{ + { + description: "csv mode removes nvidia-container-runtime-hook", + config: &config.Config{ + NVIDIAContainerRuntimeConfig: config.RuntimeConfig{ + Mode: "csv", + }, + }, + spec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{ + { + Path: "/path/to/nvidia-container-runtime-hook", + Args: []string{"/path/to/nvidia-container-runtime-hook", "prestart"}, + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: nil, + }, + }, + }, + { + description: "csv mode removes nvidia-container-toolkit", + config: &config.Config{ + NVIDIAContainerRuntimeConfig: config.RuntimeConfig{ + Mode: "csv", + }, + }, + spec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{ + { + Path: "/path/to/nvidia-container-toolkit", + Args: []string{"/path/to/nvidia-container-toolkit", "prestart"}, + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: nil, + }, + }, + }, + { + description: "cdi mode removes nvidia-container-runtime-hook", + config: &config.Config{ + NVIDIAContainerRuntimeConfig: config.RuntimeConfig{ + Mode: "cdi", + }, + }, + spec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{ + { + Path: "/path/to/nvidia-container-runtime-hook", + Args: []string{"/path/to/nvidia-container-runtime-hook", "prestart"}, + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: nil, + }, + }, + }, + { + description: "cdi mode removes nvidia-container-toolkit", + config: &config.Config{ + NVIDIAContainerRuntimeConfig: config.RuntimeConfig{ + Mode: "cdi", + }, + }, + spec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{ + { + Path: "/path/to/nvidia-container-toolkit", + Args: []string{"/path/to/nvidia-container-toolkit", "prestart"}, + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: nil, + }, + }, + }, + { + description: "legacy mode keeps nvidia-container-runtime-hook", + config: &config.Config{ + NVIDIAContainerRuntimeConfig: config.RuntimeConfig{ + Mode: "legacy", + }, + }, + spec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{ + { + Path: "/path/to/nvidia-container-runtime-hook", + Args: []string{"/path/to/nvidia-container-runtime-hook", "prestart"}, + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{ + { + Path: "/path/to/nvidia-container-runtime-hook", + Args: []string{"/path/to/nvidia-container-runtime-hook", "prestart"}, + }, + }, + }, + }, + }, + { + description: "legacy mode keeps nvidia-container-toolkit", + config: &config.Config{ + NVIDIAContainerRuntimeConfig: config.RuntimeConfig{ + Mode: "legacy", + }, + }, + spec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{ + { + Path: "/path/to/nvidia-container-toolkit", + Args: []string{"/path/to/nvidia-container-toolkit", "prestart"}, + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{ + { + Path: "/path/to/nvidia-container-toolkit", + Args: []string{"/path/to/nvidia-container-toolkit", "prestart"}, + }, + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + spec := &oci.SpecMock{ + LoadFunc: func() (*specs.Spec, error) { + return tc.spec, nil + }, + } + m, err := newSpecModifier(logger, tc.config, spec, driver) + require.NoError(t, err) + + err = m.Modify(tc.spec) + require.NoError(t, err) + require.EqualValues(t, tc.expectedSpec, tc.spec) + }) + } +} diff --git a/tests/e2e/nvidia-container-toolkit_test.go b/tests/e2e/nvidia-container-toolkit_test.go index 609fd18b..1895aff0 100644 --- a/tests/e2e/nvidia-container-toolkit_test.go +++ b/tests/e2e/nvidia-container-toolkit_test.go @@ -74,6 +74,12 @@ var _ = Describe("docker", Ordered, func() { Expect(containerOutput).To(Equal(hostOutput)) }) + It("should support automatic CDI spec generation with the --gpus flag", func(ctx context.Context) { + containerOutput, _, err := r.Run("docker run --rm -i --gpus=all --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=runtime.nvidia.com/gpu=all ubuntu nvidia-smi -L") + Expect(err).ToNot(HaveOccurred()) + Expect(containerOutput).To(Equal(hostOutput)) + }) + It("should support the --gpus flag using the nvidia-container-runtime", func(ctx context.Context) { containerOutput, _, err := r.Run("docker run --rm -i --runtime=nvidia --gpus all ubuntu nvidia-smi -L") Expect(err).ToNot(HaveOccurred())