diff --git a/cmd/nvidia-container-runtime/main_test.go b/cmd/nvidia-container-runtime/main_test.go index d6ba14ae..ff507f3b 100644 --- a/cmd/nvidia-container-runtime/main_test.go +++ b/cmd/nvidia-container-runtime/main_test.go @@ -80,11 +80,6 @@ func TestBadInput(t *testing.T) { t.Fatal(err) } - cmdRun := exec.Command(nvidiaRuntime, "run", "--bundle") - t.Logf("executing: %s\n", strings.Join(cmdRun.Args, " ")) - output, err := cmdRun.CombinedOutput() - require.Errorf(t, err, "runtime should return an error", "output=%v", string(output)) - cmdCreate := exec.Command(nvidiaRuntime, "create", "--bundle") t.Logf("executing: %s\n", strings.Join(cmdCreate.Args, " ")) err = cmdCreate.Run() diff --git a/cmd/nvidia-container-runtime/runtime_factory.go b/cmd/nvidia-container-runtime/runtime_factory.go index 43d03769..ca06ca17 100644 --- a/cmd/nvidia-container-runtime/runtime_factory.go +++ b/cmd/nvidia-container-runtime/runtime_factory.go @@ -33,10 +33,6 @@ const ( // newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv []string) (oci.Runtime, error) { - ociSpec, err := oci.NewSpec(logger, argv) - if err != nil { - return nil, fmt.Errorf("error constructing OCI specification: %v", err) - } lowLevelRuntimeCandidates := []string{dockerRuncExecutableName, runcExecutableName} lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, lowLevelRuntimeCandidates) @@ -44,7 +40,17 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [ return nil, fmt.Errorf("error constructing low-level runtime: %v", err) } - specModifier, err := newSpecModifier(logger, cfg, ociSpec) + if !oci.HasCreateSubcommand(argv) { + logger.Debugf("Skipping modifier for non-create subcommand") + return lowLevelRuntime, nil + } + + ociSpec, err := oci.NewSpec(logger, argv) + if err != nil { + return nil, fmt.Errorf("error constructing OCI specification: %v", err) + } + + specModifier, err := newSpecModifier(logger, cfg, ociSpec, argv) if err != nil { return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err) } @@ -61,7 +67,7 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [ } // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. -func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { +func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) { if !cfg.NVIDIAContainerRuntimeConfig.Experimental { return modifier.NewStableRuntimeModifier(logger), nil }