Merge pull request #953 from elezar/backport-enable-cdi-container

Add option in toolkit container to enable CDI in runtime
This commit is contained in:
Evan Lezar 2025-03-06 10:50:43 +02:00 committed by GitHub
commit 997f23cf11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 175 additions and 63 deletions

View File

@ -163,7 +163,7 @@ func (m command) build() *cli.Command {
},
&cli.BoolFlag{
Name: "cdi.enabled",
Aliases: []string{"cdi.enable"},
Aliases: []string{"cdi.enable", "enable-cdi"},
Usage: "Enable CDI in the configured runtime",
Destination: &config.cdi.enabled,
},
@ -292,9 +292,8 @@ func (m command) configureConfigFile(c *cli.Context, config *config) error {
return fmt.Errorf("unable to update config: %v", err)
}
err = enableCDI(config, cfg)
if err != nil {
return fmt.Errorf("failed to enable CDI in %s: %w", config.runtime, err)
if config.cdi.enabled {
cfg.EnableCDI()
}
outputPath := config.getOutputConfigPath()
@ -354,19 +353,3 @@ func (m *command) configureOCIHook(c *cli.Context, config *config) error {
}
return nil
}
// enableCDI enables the use of CDI in the corresponding container engine
func enableCDI(config *config, cfg engine.Interface) error {
if !config.cdi.enabled {
return nil
}
switch config.runtime {
case "containerd":
cfg.Set("enable_cdi", true)
case "docker":
cfg.Set("features", map[string]bool{"cdi": true})
default:
return fmt.Errorf("enabling CDI in %s is not supported", config.runtime)
}
return nil
}

View File

@ -20,10 +20,10 @@ package engine
type Interface interface {
AddRuntime(string, string, bool) error
DefaultRuntime() string
EnableCDI()
GetRuntimeConfig(string) (RuntimeConfig, error)
RemoveRuntime(string) error
Save(string) (int64, error)
Set(string, interface{})
String() string
}

View File

@ -96,13 +96,6 @@ func (c *Config) getRuntimeAnnotations(path []string) ([]string, error) {
return annotations, nil
}
// Set sets the specified containerd option.
func (c *Config) Set(key string, value interface{}) {
config := *c.Tree
config.SetPath([]string{"plugins", c.CRIRuntimePluginName, key}, value)
*c.Tree = config
}
// DefaultRuntime returns the default runtime for the cri-o config
func (c Config) DefaultRuntime() string {
if runtime, ok := c.GetPath([]string{"plugins", c.CRIRuntimePluginName, "containerd", "default_runtime_name"}).(string); ok {
@ -111,6 +104,13 @@ func (c Config) DefaultRuntime() string {
return ""
}
// EnableCDI sets the enable_cdi field in the Containerd config to true.
func (c *Config) EnableCDI() {
config := *c.Tree
config.SetPath([]string{"plugins", c.CRIRuntimePluginName, "enable_cdi"}, true)
*c.Tree = config
}
// RemoveRuntime removes a runtime from the docker config
func (c *Config) RemoveRuntime(name string) error {
if c == nil || c.Tree == nil {

View File

@ -143,13 +143,6 @@ func (c *ConfigV1) RemoveRuntime(name string) error {
return nil
}
// Set sets the specified containerd option.
func (c *ConfigV1) Set(key string, value interface{}) {
config := *c.Tree
config.SetPath([]string{"plugins", "cri", "containerd", key}, value)
*c.Tree = config
}
// Save writes the config to a file
func (c ConfigV1) Save(path string) (int64, error) {
return (Config)(c).Save(path)
@ -165,3 +158,9 @@ func (c *ConfigV1) GetRuntimeConfig(name string) (engine.RuntimeConfig, error) {
tree: runtimeData,
}, nil
}
func (c *ConfigV1) EnableCDI() {
config := *c.Tree
config.SetPath([]string{"plugins", "cri", "containerd", "enable_cdi"}, true)
*c.Tree = config
}

View File

@ -153,6 +153,9 @@ func (c *Config) GetRuntimeConfig(name string) (engine.RuntimeConfig, error) {
}, nil
}
// EnableCDI is a no-op for CRI-O since it always enabled where supported.
func (c *Config) EnableCDI() {}
// CommandLineSource returns the CLI-based crio config loader
func CommandLineSource(hostRoot string) toml.Loader {
return toml.LoadFirst(

View File

@ -103,6 +103,24 @@ func (c Config) DefaultRuntime() string {
return r
}
// EnableCDI sets features.cdi to true in the docker config.
func (c *Config) EnableCDI() {
if c == nil {
return
}
config := *c
features, ok := config["features"].(map[string]bool)
if !ok {
features = make(map[string]bool)
}
features["cdi"] = true
config["features"] = features
*c = config
}
// RemoveRuntime removes a runtime from the docker config
func (c *Config) RemoveRuntime(name string) error {
if c == nil {
@ -132,11 +150,6 @@ func (c *Config) RemoveRuntime(name string) error {
return nil
}
// Set sets the specified docker option
func (c *Config) Set(key string, value interface{}) {
(*c)[key] = value
}
// Save writes the config to the specified path
func (c Config) Save(path string) (int64, error) {
output, err := json.MarshalIndent(c, "", " ")

View File

@ -36,8 +36,10 @@ const (
// Options defines the shared options for the CLIs to configure containers runtimes.
type Options struct {
Config string
Socket string
Config string
Socket string
// EnabledCDI indicates whether CDI should be enabled.
EnableCDI bool
RuntimeName string
RuntimeDir string
SetAsDefault bool
@ -111,6 +113,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error {
}
}
if o.EnableCDI {
cfg.EnableCDI()
}
return nil
}

View File

@ -129,14 +129,14 @@ func main() {
log.Infof("Completed %v", c.Name)
}
func validateFlags(_ *cli.Context, o *options) error {
func validateFlags(c *cli.Context, o *options) error {
if filepath.Base(o.pidFile) != toolkitPidFilename {
return fmt.Errorf("invalid toolkit.pid path %v", o.pidFile)
}
if err := toolkit.ValidateOptions(&o.toolkitOptions, o.toolkitRoot()); err != nil {
return err
}
if err := runtime.ValidateOptions(&o.runtimeOptions, o.runtime, o.toolkitRoot()); err != nil {
if err := runtime.ValidateOptions(c, &o.runtimeOptions, o.runtime, o.toolkitRoot(), &o.toolkitOptions); err != nil {
return err
}
return nil

View File

@ -410,6 +410,51 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
}
}
func TestUpdateV1EnableCDI(t *testing.T) {
logger, _ := testlog.NewNullLogger()
const runtimeDir = "/test/runtime/dir"
testCases := []struct {
enableCDI bool
expectedEnableCDIValue interface{}
}{
{},
{
enableCDI: false,
expectedEnableCDIValue: nil,
},
{
enableCDI: true,
expectedEnableCDIValue: true,
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%v", tc.enableCDI), func(t *testing.T) {
o := &container.Options{
EnableCDI: tc.enableCDI,
RuntimeName: "nvidia",
RuntimeDir: runtimeDir,
}
cfg, err := toml.Empty.Load()
require.NoError(t, err)
v1 := &containerd.ConfigV1{
Logger: logger,
Tree: cfg,
RuntimeType: runtimeType,
}
err = o.UpdateConfig(v1)
require.NoError(t, err)
enableCDIValue := v1.GetPath([]string{"plugins", "cri", "containerd", "enable_cdi"})
require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue)
})
}
}
func TestRevertV1Config(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {

View File

@ -366,6 +366,53 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
}
}
func TestUpdateV2ConfigEnableCDI(t *testing.T) {
logger, _ := testlog.NewNullLogger()
const runtimeDir = "/test/runtime/dir"
testCases := []struct {
enableCDI bool
expectedEnableCDIValue interface{}
}{
{},
{
enableCDI: false,
expectedEnableCDIValue: nil,
},
{
enableCDI: true,
expectedEnableCDIValue: true,
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%v", tc.enableCDI), func(t *testing.T) {
o := &container.Options{
EnableCDI: tc.enableCDI,
RuntimeName: "nvidia",
RuntimeDir: runtimeDir,
SetAsDefault: false,
}
cfg, err := toml.LoadMap(map[string]interface{}{})
require.NoError(t, err)
v2 := &containerd.Config{
Logger: logger,
Tree: cfg,
RuntimeType: runtimeType,
CRIRuntimePluginName: "io.containerd.grpc.v1.cri",
}
err = o.UpdateConfig(v2)
require.NoError(t, err)
enableCDIValue := cfg.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "enable_cdi"})
require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue)
})
}
}
func TestRevertV2Config(t *testing.T) {
logger, _ := testlog.NewNullLogger()

View File

@ -25,6 +25,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/runtime/containerd"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/runtime/crio"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/runtime/docker"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/toolkit"
)
const (
@ -66,6 +67,12 @@ func Flags(opts *Options) []cli.Flag {
Destination: &opts.RestartMode,
EnvVars: []string{"RUNTIME_RESTART_MODE"},
},
&cli.BoolFlag{
Name: "enable-cdi-in-runtime",
Usage: "Enable CDI in the configured runt ime",
Destination: &opts.EnableCDI,
EnvVars: []string{"RUNTIME_ENABLE_CDI"},
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",
@ -98,10 +105,14 @@ func Flags(opts *Options) []cli.Flag {
}
// ValidateOptions checks whether the specified options are valid
func ValidateOptions(opts *Options, runtime string, toolkitRoot string) error {
func ValidateOptions(c *cli.Context, opts *Options, runtime string, toolkitRoot string, to *toolkit.Options) error {
// We set this option here to ensure that it is available in future calls.
opts.RuntimeDir = toolkitRoot
if !c.IsSet("enable-cdi-in-runtime") {
opts.EnableCDI = to.CDI.Enabled
}
// Apply the runtime-specific config changes.
switch runtime {
case containerd.Name:

View File

@ -48,6 +48,14 @@ const (
toolkitPidFilename = "toolkit.pid"
)
type cdiOptions struct {
Enabled bool
outputDir string
kind string
vendor string
class string
}
type Options struct {
DriverRoot string
DevRoot string
@ -67,11 +75,8 @@ type Options struct {
ContainerCLIDebug string
cdiEnabled bool
cdiOutputDir string
cdiKind string
cdiVendor string
cdiClass string
// CDI stores the CDI options for the toolkit.
CDI cdiOptions
createDeviceNodes cli.StringSlice
@ -174,21 +179,21 @@ func Flags(opts *Options) []cli.Flag {
Name: "cdi-enabled",
Aliases: []string{"enable-cdi"},
Usage: "enable the generation of a CDI specification",
Destination: &opts.cdiEnabled,
Destination: &opts.CDI.Enabled,
EnvVars: []string{"CDI_ENABLED", "ENABLE_CDI"},
},
&cli.StringFlag{
Name: "cdi-output-dir",
Usage: "the directory where the CDI output files are to be written. If this is set to '', no CDI specification is generated.",
Value: "/var/run/cdi",
Destination: &opts.cdiOutputDir,
Destination: &opts.CDI.outputDir,
EnvVars: []string{"CDI_OUTPUT_DIR"},
},
&cli.StringFlag{
Name: "cdi-kind",
Usage: "the vendor string to use for the generated CDI specification",
Value: "management.nvidia.com/gpu",
Destination: &opts.cdiKind,
Destination: &opts.CDI.kind,
EnvVars: []string{"CDI_KIND"},
},
&cli.BoolFlag{
@ -221,19 +226,19 @@ func ValidateOptions(opts *Options, toolkitRoot string) error {
return fmt.Errorf("invalid --toolkit-root option: %v", toolkitRoot)
}
vendor, class := parser.ParseQualifier(opts.cdiKind)
vendor, class := parser.ParseQualifier(opts.CDI.kind)
if err := parser.ValidateVendorName(vendor); err != nil {
return fmt.Errorf("invalid CDI vendor name: %v", err)
}
if err := parser.ValidateClassName(class); err != nil {
return fmt.Errorf("invalid CDI class name: %v", err)
}
opts.cdiVendor = vendor
opts.cdiClass = class
opts.CDI.vendor = vendor
opts.CDI.class = class
if opts.cdiEnabled && opts.cdiOutputDir == "" {
if opts.CDI.Enabled && opts.CDI.outputDir == "" {
log.Warning("Skipping CDI spec generation (no output directory specified)")
opts.cdiEnabled = false
opts.CDI.Enabled = false
}
isDisabled := false
@ -246,7 +251,7 @@ func ValidateOptions(opts *Options, toolkitRoot string) error {
break
}
}
if !opts.cdiEnabled && !isDisabled {
if !opts.CDI.Enabled && !isDisabled {
log.Info("disabling device node creation since --cdi-enabled=false")
isDisabled = true
}
@ -761,7 +766,7 @@ func createDeviceNodes(opts *Options) error {
// generateCDISpec generates a CDI spec for use in management containers
func generateCDISpec(opts *Options, nvidiaCDIHookPath string) error {
if !opts.cdiEnabled {
if !opts.CDI.Enabled {
return nil
}
log.Info("Generating CDI spec for management containers")
@ -770,8 +775,8 @@ func generateCDISpec(opts *Options, nvidiaCDIHookPath string) error {
nvcdi.WithDriverRoot(opts.DriverRootCtrPath),
nvcdi.WithDevRoot(opts.DevRootCtrPath),
nvcdi.WithNVIDIACDIHookPath(nvidiaCDIHookPath),
nvcdi.WithVendor(opts.cdiVendor),
nvcdi.WithClass(opts.cdiClass),
nvcdi.WithVendor(opts.CDI.vendor),
nvcdi.WithClass(opts.CDI.class),
)
if err != nil {
return fmt.Errorf("failed to create CDI library for management containers: %v", err)
@ -796,7 +801,7 @@ func generateCDISpec(opts *Options, nvidiaCDIHookPath string) error {
if err != nil {
return fmt.Errorf("failed to generate CDI name for management containers: %v", err)
}
err = spec.Save(filepath.Join(opts.cdiOutputDir, name))
err = spec.Save(filepath.Join(opts.CDI.outputDir, name))
if err != nil {
return fmt.Errorf("failed to save CDI spec for management containers: %v", err)
}