diff --git a/cmd/nvidia-ctk/runtime/configure/configure.go b/cmd/nvidia-ctk/runtime/configure/configure.go index d2528853..aa8a496c 100644 --- a/cmd/nvidia-ctk/runtime/configure/configure.go +++ b/cmd/nvidia-ctk/runtime/configure/configure.go @@ -163,7 +163,7 @@ func (m command) build() *cli.Command { }, &cli.BoolFlag{ Name: "cdi.enabled", - Aliases: []string{"cdi.enable"}, + Aliases: []string{"cdi.enable", "enable-cdi"}, Usage: "Enable CDI in the configured runtime", Destination: &config.cdi.enabled, }, @@ -292,9 +292,8 @@ func (m command) configureConfigFile(c *cli.Context, config *config) error { return fmt.Errorf("unable to update config: %v", err) } - err = enableCDI(config, cfg) - if err != nil { - return fmt.Errorf("failed to enable CDI in %s: %w", config.runtime, err) + if config.cdi.enabled { + cfg.EnableCDI() } outputPath := config.getOutputConfigPath() @@ -354,19 +353,3 @@ func (m *command) configureOCIHook(c *cli.Context, config *config) error { } return nil } - -// enableCDI enables the use of CDI in the corresponding container engine -func enableCDI(config *config, cfg engine.Interface) error { - if !config.cdi.enabled { - return nil - } - switch config.runtime { - case "containerd": - cfg.Set("enable_cdi", true) - case "docker": - cfg.Set("features", map[string]bool{"cdi": true}) - default: - return fmt.Errorf("enabling CDI in %s is not supported", config.runtime) - } - return nil -} diff --git a/pkg/config/engine/api.go b/pkg/config/engine/api.go index 8c7d1b50..d27f09e9 100644 --- a/pkg/config/engine/api.go +++ b/pkg/config/engine/api.go @@ -20,10 +20,10 @@ package engine type Interface interface { AddRuntime(string, string, bool) error DefaultRuntime() string + EnableCDI() GetRuntimeConfig(string) (RuntimeConfig, error) RemoveRuntime(string) error Save(string) (int64, error) - Set(string, interface{}) String() string } diff --git a/pkg/config/engine/containerd/config.go b/pkg/config/engine/containerd/config.go index 52a336ba..62468fe5 100644 --- a/pkg/config/engine/containerd/config.go +++ b/pkg/config/engine/containerd/config.go @@ -96,13 +96,6 @@ func (c *Config) getRuntimeAnnotations(path []string) ([]string, error) { return annotations, nil } -// Set sets the specified containerd option. -func (c *Config) Set(key string, value interface{}) { - config := *c.Tree - config.SetPath([]string{"plugins", c.CRIRuntimePluginName, key}, value) - *c.Tree = config -} - // DefaultRuntime returns the default runtime for the cri-o config func (c Config) DefaultRuntime() string { if runtime, ok := c.GetPath([]string{"plugins", c.CRIRuntimePluginName, "containerd", "default_runtime_name"}).(string); ok { @@ -111,6 +104,13 @@ func (c Config) DefaultRuntime() string { return "" } +// EnableCDI sets the enable_cdi field in the Containerd config to true. +func (c *Config) EnableCDI() { + config := *c.Tree + config.SetPath([]string{"plugins", c.CRIRuntimePluginName, "enable_cdi"}, true) + *c.Tree = config +} + // RemoveRuntime removes a runtime from the docker config func (c *Config) RemoveRuntime(name string) error { if c == nil || c.Tree == nil { diff --git a/pkg/config/engine/containerd/config_v1.go b/pkg/config/engine/containerd/config_v1.go index 10b6d087..2189a8de 100644 --- a/pkg/config/engine/containerd/config_v1.go +++ b/pkg/config/engine/containerd/config_v1.go @@ -143,13 +143,6 @@ func (c *ConfigV1) RemoveRuntime(name string) error { return nil } -// Set sets the specified containerd option. -func (c *ConfigV1) Set(key string, value interface{}) { - config := *c.Tree - config.SetPath([]string{"plugins", "cri", "containerd", key}, value) - *c.Tree = config -} - // Save writes the config to a file func (c ConfigV1) Save(path string) (int64, error) { return (Config)(c).Save(path) @@ -165,3 +158,9 @@ func (c *ConfigV1) GetRuntimeConfig(name string) (engine.RuntimeConfig, error) { tree: runtimeData, }, nil } + +func (c *ConfigV1) EnableCDI() { + config := *c.Tree + config.SetPath([]string{"plugins", "cri", "containerd", "enable_cdi"}, true) + *c.Tree = config +} diff --git a/pkg/config/engine/crio/crio.go b/pkg/config/engine/crio/crio.go index 3d5629d7..c0cc60be 100644 --- a/pkg/config/engine/crio/crio.go +++ b/pkg/config/engine/crio/crio.go @@ -153,6 +153,9 @@ func (c *Config) GetRuntimeConfig(name string) (engine.RuntimeConfig, error) { }, nil } +// EnableCDI is a no-op for CRI-O since it always enabled where supported. +func (c *Config) EnableCDI() {} + // CommandLineSource returns the CLI-based crio config loader func CommandLineSource(hostRoot string) toml.Loader { return toml.LoadFirst( diff --git a/pkg/config/engine/docker/docker.go b/pkg/config/engine/docker/docker.go index eda700d0..86512df9 100644 --- a/pkg/config/engine/docker/docker.go +++ b/pkg/config/engine/docker/docker.go @@ -103,6 +103,24 @@ func (c Config) DefaultRuntime() string { return r } +// EnableCDI sets features.cdi to true in the docker config. +func (c *Config) EnableCDI() { + if c == nil { + return + } + config := *c + + features, ok := config["features"].(map[string]bool) + if !ok { + features = make(map[string]bool) + } + features["cdi"] = true + + config["features"] = features + + *c = config +} + // RemoveRuntime removes a runtime from the docker config func (c *Config) RemoveRuntime(name string) error { if c == nil { @@ -132,11 +150,6 @@ func (c *Config) RemoveRuntime(name string) error { return nil } -// Set sets the specified docker option -func (c *Config) Set(key string, value interface{}) { - (*c)[key] = value -} - // Save writes the config to the specified path func (c Config) Save(path string) (int64, error) { output, err := json.MarshalIndent(c, "", " ") diff --git a/tools/container/container.go b/tools/container/container.go index c2c50c5b..4b694c4c 100644 --- a/tools/container/container.go +++ b/tools/container/container.go @@ -36,8 +36,10 @@ const ( // Options defines the shared options for the CLIs to configure containers runtimes. type Options struct { - Config string - Socket string + Config string + Socket string + // EnabledCDI indicates whether CDI should be enabled. + EnableCDI bool RuntimeName string RuntimeDir string SetAsDefault bool @@ -111,6 +113,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error { } } + if o.EnableCDI { + cfg.EnableCDI() + } + return nil } diff --git a/tools/container/nvidia-toolkit/run.go b/tools/container/nvidia-toolkit/run.go index 265814a5..6e1edaef 100644 --- a/tools/container/nvidia-toolkit/run.go +++ b/tools/container/nvidia-toolkit/run.go @@ -129,14 +129,14 @@ func main() { log.Infof("Completed %v", c.Name) } -func validateFlags(_ *cli.Context, o *options) error { +func validateFlags(c *cli.Context, o *options) error { if filepath.Base(o.pidFile) != toolkitPidFilename { return fmt.Errorf("invalid toolkit.pid path %v", o.pidFile) } if err := toolkit.ValidateOptions(&o.toolkitOptions, o.toolkitRoot()); err != nil { return err } - if err := runtime.ValidateOptions(&o.runtimeOptions, o.runtime, o.toolkitRoot()); err != nil { + if err := runtime.ValidateOptions(c, &o.runtimeOptions, o.runtime, o.toolkitRoot(), &o.toolkitOptions); err != nil { return err } return nil diff --git a/tools/container/runtime/containerd/config_v1_test.go b/tools/container/runtime/containerd/config_v1_test.go index 7042744f..90b24972 100644 --- a/tools/container/runtime/containerd/config_v1_test.go +++ b/tools/container/runtime/containerd/config_v1_test.go @@ -410,6 +410,51 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) { } } +func TestUpdateV1EnableCDI(t *testing.T) { + logger, _ := testlog.NewNullLogger() + const runtimeDir = "/test/runtime/dir" + + testCases := []struct { + enableCDI bool + expectedEnableCDIValue interface{} + }{ + {}, + { + enableCDI: false, + expectedEnableCDIValue: nil, + }, + { + enableCDI: true, + expectedEnableCDIValue: true, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc.enableCDI), func(t *testing.T) { + o := &container.Options{ + EnableCDI: tc.enableCDI, + RuntimeName: "nvidia", + RuntimeDir: runtimeDir, + } + + cfg, err := toml.Empty.Load() + require.NoError(t, err) + + v1 := &containerd.ConfigV1{ + Logger: logger, + Tree: cfg, + RuntimeType: runtimeType, + } + + err = o.UpdateConfig(v1) + require.NoError(t, err) + + enableCDIValue := v1.GetPath([]string{"plugins", "cri", "containerd", "enable_cdi"}) + require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue) + }) + } +} + func TestRevertV1Config(t *testing.T) { logger, _ := testlog.NewNullLogger() testCases := []struct { diff --git a/tools/container/runtime/containerd/config_v2_test.go b/tools/container/runtime/containerd/config_v2_test.go index 92eea1fa..ff747a98 100644 --- a/tools/container/runtime/containerd/config_v2_test.go +++ b/tools/container/runtime/containerd/config_v2_test.go @@ -366,6 +366,53 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) { } } +func TestUpdateV2ConfigEnableCDI(t *testing.T) { + logger, _ := testlog.NewNullLogger() + const runtimeDir = "/test/runtime/dir" + + testCases := []struct { + enableCDI bool + expectedEnableCDIValue interface{} + }{ + {}, + { + enableCDI: false, + expectedEnableCDIValue: nil, + }, + { + enableCDI: true, + expectedEnableCDIValue: true, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc.enableCDI), func(t *testing.T) { + o := &container.Options{ + EnableCDI: tc.enableCDI, + RuntimeName: "nvidia", + RuntimeDir: runtimeDir, + SetAsDefault: false, + } + + cfg, err := toml.LoadMap(map[string]interface{}{}) + require.NoError(t, err) + + v2 := &containerd.Config{ + Logger: logger, + Tree: cfg, + RuntimeType: runtimeType, + CRIRuntimePluginName: "io.containerd.grpc.v1.cri", + } + + err = o.UpdateConfig(v2) + require.NoError(t, err) + + enableCDIValue := cfg.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "enable_cdi"}) + require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue) + }) + } +} + func TestRevertV2Config(t *testing.T) { logger, _ := testlog.NewNullLogger() diff --git a/tools/container/runtime/runtime.go b/tools/container/runtime/runtime.go index 865f92e8..c6aed501 100644 --- a/tools/container/runtime/runtime.go +++ b/tools/container/runtime/runtime.go @@ -25,6 +25,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/tools/container/runtime/containerd" "github.com/NVIDIA/nvidia-container-toolkit/tools/container/runtime/crio" "github.com/NVIDIA/nvidia-container-toolkit/tools/container/runtime/docker" + "github.com/NVIDIA/nvidia-container-toolkit/tools/container/toolkit" ) const ( @@ -66,6 +67,12 @@ func Flags(opts *Options) []cli.Flag { Destination: &opts.RestartMode, EnvVars: []string{"RUNTIME_RESTART_MODE"}, }, + &cli.BoolFlag{ + Name: "enable-cdi-in-runtime", + Usage: "Enable CDI in the configured runt ime", + Destination: &opts.EnableCDI, + EnvVars: []string{"RUNTIME_ENABLE_CDI"}, + }, &cli.StringFlag{ Name: "host-root", Usage: "Specify the path to the host root to be used when restarting the runtime using systemd", @@ -98,10 +105,14 @@ func Flags(opts *Options) []cli.Flag { } // ValidateOptions checks whether the specified options are valid -func ValidateOptions(opts *Options, runtime string, toolkitRoot string) error { +func ValidateOptions(c *cli.Context, opts *Options, runtime string, toolkitRoot string, to *toolkit.Options) error { // We set this option here to ensure that it is available in future calls. opts.RuntimeDir = toolkitRoot + if !c.IsSet("enable-cdi-in-runtime") { + opts.EnableCDI = to.CDI.Enabled + } + // Apply the runtime-specific config changes. switch runtime { case containerd.Name: diff --git a/tools/container/toolkit/toolkit.go b/tools/container/toolkit/toolkit.go index 43e68ca5..c8530f20 100644 --- a/tools/container/toolkit/toolkit.go +++ b/tools/container/toolkit/toolkit.go @@ -48,6 +48,14 @@ const ( toolkitPidFilename = "toolkit.pid" ) +type cdiOptions struct { + Enabled bool + outputDir string + kind string + vendor string + class string +} + type Options struct { DriverRoot string DevRoot string @@ -67,11 +75,8 @@ type Options struct { ContainerCLIDebug string - cdiEnabled bool - cdiOutputDir string - cdiKind string - cdiVendor string - cdiClass string + // CDI stores the CDI options for the toolkit. + CDI cdiOptions createDeviceNodes cli.StringSlice @@ -174,21 +179,21 @@ func Flags(opts *Options) []cli.Flag { Name: "cdi-enabled", Aliases: []string{"enable-cdi"}, Usage: "enable the generation of a CDI specification", - Destination: &opts.cdiEnabled, + Destination: &opts.CDI.Enabled, EnvVars: []string{"CDI_ENABLED", "ENABLE_CDI"}, }, &cli.StringFlag{ Name: "cdi-output-dir", Usage: "the directory where the CDI output files are to be written. If this is set to '', no CDI specification is generated.", Value: "/var/run/cdi", - Destination: &opts.cdiOutputDir, + Destination: &opts.CDI.outputDir, EnvVars: []string{"CDI_OUTPUT_DIR"}, }, &cli.StringFlag{ Name: "cdi-kind", Usage: "the vendor string to use for the generated CDI specification", Value: "management.nvidia.com/gpu", - Destination: &opts.cdiKind, + Destination: &opts.CDI.kind, EnvVars: []string{"CDI_KIND"}, }, &cli.BoolFlag{ @@ -221,19 +226,19 @@ func ValidateOptions(opts *Options, toolkitRoot string) error { return fmt.Errorf("invalid --toolkit-root option: %v", toolkitRoot) } - vendor, class := parser.ParseQualifier(opts.cdiKind) + vendor, class := parser.ParseQualifier(opts.CDI.kind) if err := parser.ValidateVendorName(vendor); err != nil { return fmt.Errorf("invalid CDI vendor name: %v", err) } if err := parser.ValidateClassName(class); err != nil { return fmt.Errorf("invalid CDI class name: %v", err) } - opts.cdiVendor = vendor - opts.cdiClass = class + opts.CDI.vendor = vendor + opts.CDI.class = class - if opts.cdiEnabled && opts.cdiOutputDir == "" { + if opts.CDI.Enabled && opts.CDI.outputDir == "" { log.Warning("Skipping CDI spec generation (no output directory specified)") - opts.cdiEnabled = false + opts.CDI.Enabled = false } isDisabled := false @@ -246,7 +251,7 @@ func ValidateOptions(opts *Options, toolkitRoot string) error { break } } - if !opts.cdiEnabled && !isDisabled { + if !opts.CDI.Enabled && !isDisabled { log.Info("disabling device node creation since --cdi-enabled=false") isDisabled = true } @@ -761,7 +766,7 @@ func createDeviceNodes(opts *Options) error { // generateCDISpec generates a CDI spec for use in management containers func generateCDISpec(opts *Options, nvidiaCDIHookPath string) error { - if !opts.cdiEnabled { + if !opts.CDI.Enabled { return nil } log.Info("Generating CDI spec for management containers") @@ -770,8 +775,8 @@ func generateCDISpec(opts *Options, nvidiaCDIHookPath string) error { nvcdi.WithDriverRoot(opts.DriverRootCtrPath), nvcdi.WithDevRoot(opts.DevRootCtrPath), nvcdi.WithNVIDIACDIHookPath(nvidiaCDIHookPath), - nvcdi.WithVendor(opts.cdiVendor), - nvcdi.WithClass(opts.cdiClass), + nvcdi.WithVendor(opts.CDI.vendor), + nvcdi.WithClass(opts.CDI.class), ) if err != nil { return fmt.Errorf("failed to create CDI library for management containers: %v", err) @@ -796,7 +801,7 @@ func generateCDISpec(opts *Options, nvidiaCDIHookPath string) error { if err != nil { return fmt.Errorf("failed to generate CDI name for management containers: %v", err) } - err = spec.Save(filepath.Join(opts.cdiOutputDir, name)) + err = spec.Save(filepath.Join(opts.CDI.outputDir, name)) if err != nil { return fmt.Errorf("failed to save CDI spec for management containers: %v", err) }