diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index faaf0b51..f53a649a 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -89,6 +89,12 @@ func doPrestart() { rootfs := getRootfsPath(container) args := []string{getCLIPath(cli)} + + // Only include the nvidia-persistenced socket if it is explicitly enabled. + if !hook.Features.IncludePersistencedSocket.IsEnabled() { + args = append(args, "--no-persistenced") + } + if cli.Root != "" { args = append(args, fmt.Sprintf("--root=%s", cli.Root)) } diff --git a/internal/config/features.go b/internal/config/features.go index dfc6b165..cbbecf3c 100644 --- a/internal/config/features.go +++ b/internal/config/features.go @@ -19,10 +19,11 @@ package config type featureName string const ( - FeatureGDS = featureName("gds") - FeatureMOFED = featureName("mofed") - FeatureNVSWITCH = featureName("nvswitch") - FeatureGDRCopy = featureName("gdrcopy") + FeatureGDS = featureName("gds") + FeatureMOFED = featureName("mofed") + FeatureNVSWITCH = featureName("nvswitch") + FeatureGDRCopy = featureName("gdrcopy") + FeatureIncludePersistencedSocket = featureName("include-persistenced-socket") ) // features specifies a set of named features. @@ -31,53 +32,57 @@ type features struct { MOFED *feature `toml:"mofed,omitempty"` NVSWITCH *feature `toml:"nvswitch,omitempty"` GDRCopy *feature `toml:"gdrcopy,omitempty"` + // IncludePersistencedSocket enables the injection of the nvidia-persistenced + // socket into containers. + IncludePersistencedSocket *feature `toml:"include-persistenced-socket,omitempty"` } type feature bool -// IsEnabled checks whether a specified named feature is enabled. +// IsEnabledInEnvironment checks whether a specified named feature is enabled. // An optional list of environments to check for feature-specific environment // variables can also be supplied. -func (fs features) IsEnabled(n featureName, in ...getenver) bool { - featureEnvvars := map[featureName]string{ - FeatureGDS: "NVIDIA_GDS", - FeatureMOFED: "NVIDIA_MOFED", - FeatureNVSWITCH: "NVIDIA_NVSWITCH", - FeatureGDRCopy: "NVIDIA_GDRCOPY", - } - - envvar := featureEnvvars[n] +func (fs features) IsEnabledInEnvironment(n featureName, in ...getenver) bool { switch n { + // Features with envvar overrides case FeatureGDS: - return fs.GDS.isEnabled(envvar, in...) + return fs.GDS.isEnabledWithEnvvarOverride("NVIDIA_GDS", in...) case FeatureMOFED: - return fs.MOFED.isEnabled(envvar, in...) + return fs.MOFED.isEnabledWithEnvvarOverride("NVIDIA_MOFED", in...) case FeatureNVSWITCH: - return fs.NVSWITCH.isEnabled(envvar, in...) + return fs.NVSWITCH.isEnabledWithEnvvarOverride("NVIDIA_NVSWITCH", in...) case FeatureGDRCopy: - return fs.GDRCopy.isEnabled(envvar, in...) + return fs.GDRCopy.isEnabledWithEnvvarOverride("NVIDIA_GDRCOPY", in...) + // Features without envvar overrides + case FeatureIncludePersistencedSocket: + return fs.IncludePersistencedSocket.IsEnabled() default: return false } } -// isEnabled checks whether a feature is enabled. -// If the enabled value is explicitly set, this is returned, otherwise the -// associated envvar is checked in the specified getenver for the string "enabled" -// A CUDA container / image can be passed here. -func (f *feature) isEnabled(envvar string, ins ...getenver) bool { +// IsEnabled checks whether a feature is enabled. +func (f *feature) IsEnabled() bool { if f != nil { return bool(*f) } - if envvar == "" { - return false - } - for _, in := range ins { - if in.Getenv(envvar) == "enabled" { - return true + return false +} + +// isEnabledWithEnvvarOverride checks whether a feature is enabled and allows an envvar to overide the feature. +// If the enabled value is explicitly set, this is returned, otherwise the +// associated envvar is checked in the specified getenver for the string "enabled" +// A CUDA container / image can be passed here. +func (f *feature) isEnabledWithEnvvarOverride(envvar string, ins ...getenver) bool { + if envvar != "" { + for _, in := range ins { + if in.Getenv(envvar) == "enabled" { + return true + } } } - return false + + return f.IsEnabled() } type getenver interface { diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index 5bed3eaf..70322b35 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -46,7 +46,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image driverRoot := cfg.NVIDIAContainerCLIConfig.Root devRoot := cfg.NVIDIAContainerCLIConfig.Root - if cfg.Features.IsEnabled(config.FeatureGDS, image) { + if cfg.Features.IsEnabledInEnvironment(config.FeatureGDS, image) { d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err) @@ -54,7 +54,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image discoverers = append(discoverers, d) } - if cfg.Features.IsEnabled(config.FeatureMOFED, image) { + if cfg.Features.IsEnabledInEnvironment(config.FeatureMOFED, image) { d, err := discover.NewMOFEDDiscoverer(logger, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err) @@ -62,7 +62,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image discoverers = append(discoverers, d) } - if cfg.Features.IsEnabled(config.FeatureNVSWITCH, image) { + if cfg.Features.IsEnabledInEnvironment(config.FeatureNVSWITCH, image) { d, err := discover.NewNvSwitchDiscoverer(logger, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err) @@ -70,7 +70,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image discoverers = append(discoverers, d) } - if cfg.Features.IsEnabled(config.FeatureGDRCopy, image) { + if cfg.Features.IsEnabledInEnvironment(config.FeatureGDRCopy, image) { d, err := discover.NewGDRCopyDiscoverer(logger, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)