diff --git a/cmd/nvidia-container-runtime/modifier/csv.go b/cmd/nvidia-container-runtime/modifier/csv.go index b9210fcd..06a1306d 100644 --- a/cmd/nvidia-container-runtime/modifier/csv.go +++ b/cmd/nvidia-container-runtime/modifier/csv.go @@ -24,10 +24,9 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/cuda" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" - "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/requirements" - "github.com/opencontainers/runtime-spec/specs-go" "github.com/sirupsen/logrus" ) @@ -104,33 +103,22 @@ func NewCSVModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec) d := discover.NewList(csvDiscoverer, ldcacheUpdateHook, createSymlinksHook) - return newModifierFromDiscoverer(logger, d) + return newCSVModifierFromDiscoverer(logger, d) } -// newModifierFromDiscoverer created a modifier that aplies the discovered -// modifications to an OCI spec if require by the runtime wrapper. -func newModifierFromDiscoverer(logger *logrus.Logger, d discover.Discover) (oci.SpecModifier, error) { - m := csvMode{ - logger: logger, - discoverer: d, - } - return &m, nil -} - -// Modify applies the required modifications to the incomming OCI spec. These modifications -// are applied in-place. -func (m csvMode) Modify(spec *specs.Spec) error { - err := nvidiaContainerRuntimeHookRemover{m.logger}.Modify(spec) +// newCSVModifierFromDiscoverer is used to test with dependency injection +func newCSVModifierFromDiscoverer(logger *logrus.Logger, d discover.Discover) (oci.SpecModifier, error) { + discoverModifier, err := modifier.NewModifierFromDiscoverer(logger, d) if err != nil { - return fmt.Errorf("failed to remove existing hooks: %v", err) + return nil, fmt.Errorf("failed to construct modifier: %v", err) } - specEdits, err := edits.NewSpecEdits(m.logger, m.discoverer) - if err != nil { - return fmt.Errorf("failed to get required container edits: %v", err) - } + modifiers := modifier.Merge( + nvidiaContainerRuntimeHookRemover{logger}, + discoverModifier, + ) - return specEdits.Modify(spec) + return modifiers, nil } func checkRequirements(logger *logrus.Logger, image *image.CUDA) error { diff --git a/cmd/nvidia-container-runtime/modifier/csv_test.go b/cmd/nvidia-container-runtime/modifier/csv_test.go index e8cbf3e3..ac3a0dee 100644 --- a/cmd/nvidia-container-runtime/modifier/csv_test.go +++ b/cmd/nvidia-container-runtime/modifier/csv_test.go @@ -268,7 +268,7 @@ func TestExperimentalModifier(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - m, err := newModifierFromDiscoverer(logger, tc.discover) + m, err := newCSVModifierFromDiscoverer(logger, tc.discover) require.NoError(t, err) err = m.Modify(tc.spec)