diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index 81225d14..0a834d0d 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -45,17 +45,7 @@ const ( // NewCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper. // The modifications are defined by CSV MountSpecs. -func NewCSVModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - image, err := image.NewCUDAImageFromSpec(rawSpec) - if err != nil { - return nil, err - } - +func NewCSVModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { logger.Infof("No modification required; no devices requested") return nil, nil diff --git a/internal/modifier/csv_test.go b/internal/modifier/csv_test.go index b02180f2..3cd6ce56 100644 --- a/internal/modifier/csv_test.go +++ b/internal/modifier/csv_test.go @@ -17,11 +17,10 @@ package modifier import ( - "fmt" "testing" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" - "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/opencontainers/runtime-spec/specs-go" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" @@ -31,54 +30,32 @@ func TestNewCSVModifier(t *testing.T) { logger, _ := testlog.NewNullLogger() testCases := []struct { - description string - cfg *config.Config - spec oci.Spec - visibleDevices string - expectedError error - expectedNil bool + description string + cfg *config.Config + image image.CUDA + expectedError error + expectedNil bool }{ { - description: "spec load error returns error", - spec: &oci.SpecMock{ - LoadFunc: func() (*specs.Spec, error) { - return nil, fmt.Errorf("load failed") - }, - }, - expectedError: fmt.Errorf("load failed"), + description: "visible devices not set returns nil", + image: image.CUDA{}, + expectedNil: true, }, { - description: "visible devices not set returns nil", - visibleDevices: "NOT_SET", - expectedNil: true, + description: "visible devices empty returns nil", + image: image.CUDA{"NVIDIA_VISIBLE_DEVICES": ""}, + expectedNil: true, }, { - description: "visible devices empty returns nil", - visibleDevices: "", - expectedNil: true, - }, - { - description: "visible devices 'void' returns nil", - visibleDevices: "void", - expectedNil: true, + description: "visible devices 'void' returns nil", + image: image.CUDA{"NVIDIA_VISIBLE_DEVICES": "void"}, + expectedNil: true, }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - spec := tc.spec - if spec == nil { - spec = &oci.SpecMock{ - LookupEnvFunc: func(s string) (string, bool) { - if tc.visibleDevices != "NOT_SET" && s == visibleDevicesEnvvar { - return tc.visibleDevices, true - } - return "", false - }, - } - } - - m, err := NewCSVModifier(logger, tc.cfg, spec) + m, err := NewCSVModifier(logger, tc.cfg, tc.image) if tc.expectedError != nil { require.Error(t, err) } else { diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index 2ff3a1f5..81ab94f1 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -73,7 +73,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp } mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) - modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec) + modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, image) if err != nil { return nil, err } @@ -106,12 +106,12 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp return modifiers, nil } -func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { +func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, ociSpec oci.Spec, image image.CUDA) (oci.SpecModifier, error) { switch mode { case "legacy": return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil case "csv": - return modifier.NewCSVModifier(logger, cfg, ociSpec) + return modifier.NewCSVModifier(logger, cfg, image) case "cdi": return modifier.NewCDIModifier(logger, cfg, ociSpec) }