diff --git a/cmd/nvidia-ctk-installer/container/container.go b/cmd/nvidia-ctk-installer/container/container.go index 68e5f1f8..5e838608 100644 --- a/cmd/nvidia-ctk-installer/container/container.go +++ b/cmd/nvidia-ctk-installer/container/container.go @@ -36,8 +36,10 @@ const ( // Options defines the shared options for the CLIs to configure containers runtimes. type Options struct { - Config string - Socket string + Config string + Socket string + // EnabledCDI indicates whether CDI should be enabled. + EnableCDI bool RuntimeName string RuntimeDir string SetAsDefault bool @@ -111,6 +113,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error { } } + if o.EnableCDI { + cfg.EnableCDI() + } + return nil } diff --git a/cmd/nvidia-ctk-installer/container/runtime/containerd/config_v1_test.go b/cmd/nvidia-ctk-installer/container/runtime/containerd/config_v1_test.go index ea06b555..862d5992 100644 --- a/cmd/nvidia-ctk-installer/container/runtime/containerd/config_v1_test.go +++ b/cmd/nvidia-ctk-installer/container/runtime/containerd/config_v1_test.go @@ -410,6 +410,51 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) { } } +func TestUpdateV1EnableCDI(t *testing.T) { + logger, _ := testlog.NewNullLogger() + const runtimeDir = "/test/runtime/dir" + + testCases := []struct { + enableCDI bool + expectedEnableCDIValue interface{} + }{ + {}, + { + enableCDI: false, + expectedEnableCDIValue: nil, + }, + { + enableCDI: true, + expectedEnableCDIValue: true, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc.enableCDI), func(t *testing.T) { + o := &container.Options{ + EnableCDI: tc.enableCDI, + RuntimeName: "nvidia", + RuntimeDir: runtimeDir, + } + + cfg, err := toml.Empty.Load() + require.NoError(t, err) + + v1 := &containerd.ConfigV1{ + Logger: logger, + Tree: cfg, + RuntimeType: runtimeType, + } + + err = o.UpdateConfig(v1) + require.NoError(t, err) + + enableCDIValue := v1.GetPath([]string{"plugins", "cri", "containerd", "enable_cdi"}) + require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue) + }) + } +} + func TestRevertV1Config(t *testing.T) { logger, _ := testlog.NewNullLogger() testCases := []struct { diff --git a/cmd/nvidia-ctk-installer/container/runtime/containerd/config_v2_test.go b/cmd/nvidia-ctk-installer/container/runtime/containerd/config_v2_test.go index e206c59d..a6570e8f 100644 --- a/cmd/nvidia-ctk-installer/container/runtime/containerd/config_v2_test.go +++ b/cmd/nvidia-ctk-installer/container/runtime/containerd/config_v2_test.go @@ -366,6 +366,53 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) { } } +func TestUpdateV2ConfigEnableCDI(t *testing.T) { + logger, _ := testlog.NewNullLogger() + const runtimeDir = "/test/runtime/dir" + + testCases := []struct { + enableCDI bool + expectedEnableCDIValue interface{} + }{ + {}, + { + enableCDI: false, + expectedEnableCDIValue: nil, + }, + { + enableCDI: true, + expectedEnableCDIValue: true, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc.enableCDI), func(t *testing.T) { + o := &container.Options{ + EnableCDI: tc.enableCDI, + RuntimeName: "nvidia", + RuntimeDir: runtimeDir, + SetAsDefault: false, + } + + cfg, err := toml.LoadMap(map[string]interface{}{}) + require.NoError(t, err) + + v2 := &containerd.Config{ + Logger: logger, + Tree: cfg, + RuntimeType: runtimeType, + CRIRuntimePluginName: "io.containerd.grpc.v1.cri", + } + + err = o.UpdateConfig(v2) + require.NoError(t, err) + + enableCDIValue := cfg.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "enable_cdi"}) + require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue) + }) + } +} + func TestRevertV2Config(t *testing.T) { logger, _ := testlog.NewNullLogger() diff --git a/cmd/nvidia-ctk-installer/container/runtime/runtime.go b/cmd/nvidia-ctk-installer/container/runtime/runtime.go index cbe68830..480fdc61 100644 --- a/cmd/nvidia-ctk-installer/container/runtime/runtime.go +++ b/cmd/nvidia-ctk-installer/container/runtime/runtime.go @@ -66,6 +66,12 @@ func Flags(opts *Options) []cli.Flag { Destination: &opts.RestartMode, EnvVars: []string{"RUNTIME_RESTART_MODE"}, }, + &cli.BoolFlag{ + Name: "enable-cdi-in-runtime", + Usage: "Enable CDI in the configured runt ime", + Destination: &opts.EnableCDI, + EnvVars: []string{"RUNTIME_ENABLE_CDI"}, + }, &cli.StringFlag{ Name: "host-root", Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",