diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index 80f94abc..a417907e 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -61,10 +61,15 @@ func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv [ // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) { - modeModifier, err := newModeModifier(logger, cfg, ociSpec, argv) + mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode) + modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, argv) if err != nil { return nil, err } + // For CDI mode we make no additional modifications. + if mode == "cdi" { + return modeModifier, nil + } graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, ociSpec) if err != nil { @@ -90,8 +95,8 @@ func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec return modifiers, nil } -func newModeModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) { - switch info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode) { +func newModeModifier(logger *logrus.Logger, mode string, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) { + switch mode { case "legacy": return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil case "csv":