diff --git a/internal/config/config.go b/internal/config/config.go index 33b8ba4d..a6646846 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,6 +18,8 @@ package config import ( "bufio" + "errors" + "fmt" "os" "path/filepath" "strings" @@ -51,6 +53,8 @@ var ( NVIDIAContainerToolkitExecutable = "nvidia-container-toolkit" ) +var errInvalidConfig = errors.New("invalid config value") + // Config represents the contents of the config.toml file for the NVIDIA Container Toolkit // Note: This is currently duplicated by the HookConfig in cmd/nvidia-container-toolkit/hook_config.go type Config struct { @@ -127,7 +131,18 @@ func GetDefault() (*Config, error) { return &d, nil } -func getLdConfigPath() string { +// assertValid checks for a valid config. +func (c *Config) assertValid() error { + if !c.Features.AllowLDConfigFromContainer.IsEnabled() && !strings.HasPrefix(c.NVIDIAContainerCLIConfig.Ldconfig, "@") { + return fmt.Errorf("%w: nvidia-container-cli.ldconfig value %q is not host-relative (does not start with a '@')", errInvalidConfig, c.NVIDIAContainerCLIConfig.Ldconfig) + } + return nil +} + +// getLdConfigPath allows us to override this function for testing. +var getLdConfigPath = getLdConfigPathStub + +func getLdConfigPathStub() string { return NormalizeLDConfigPath("@/sbin/ldconfig") } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 0873ebd2..67f132f7 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -44,23 +44,21 @@ func TestGetConfigWithCustomConfig(t *testing.T) { func TestGetConfig(t *testing.T) { testCases := []struct { - description string - contents []string - expectedError error - inspectLdconfig bool - distIdsLike []string - expectedConfig *Config + description string + contents []string + expectedError error + distIdsLike []string + expectedConfig *Config }{ { - description: "empty config is default", - inspectLdconfig: true, + description: "empty config is default", expectedConfig: &Config{ AcceptEnvvarUnprivileged: true, SupportedDriverCapabilities: "compat32,compute,display,graphics,ngx,utility,video", NVIDIAContainerCLIConfig: ContainerCLIConfig{ Root: "", LoadKmods: true, - Ldconfig: "WAS_CHECKED", + Ldconfig: "@/test/ld/config/path", }, NVIDIAContainerRuntimeConfig: RuntimeConfig{ DebugFilePath: "/dev/null", @@ -93,7 +91,7 @@ func TestGetConfig(t *testing.T) { "supported-driver-capabilities = \"compute,utility\"", "nvidia-container-cli.root = \"/bar/baz\"", "nvidia-container-cli.load-kmods = false", - "nvidia-container-cli.ldconfig = \"/foo/bar/ldconfig\"", + "nvidia-container-cli.ldconfig = \"@/foo/bar/ldconfig\"", "nvidia-container-cli.user = \"foo:bar\"", "nvidia-container-runtime.debug = \"/foo/bar\"", "nvidia-container-runtime.discover-mode = \"not-legacy\"", @@ -113,7 +111,7 @@ func TestGetConfig(t *testing.T) { NVIDIAContainerCLIConfig: ContainerCLIConfig{ Root: "/bar/baz", LoadKmods: false, - Ldconfig: "/foo/bar/ldconfig", + Ldconfig: "@/foo/bar/ldconfig", User: "foo:bar", }, NVIDIAContainerRuntimeConfig: RuntimeConfig{ @@ -146,6 +144,53 @@ func TestGetConfig(t *testing.T) { }, }, }, + { + description: "feature allows ldconfig to be overridden", + contents: []string{ + "[nvidia-container-cli]", + "ldconfig = \"/foo/bar/ldconfig\"", + "[features]", + "allow-ldconfig-from-container = true", + }, + expectedConfig: &Config{ + AcceptEnvvarUnprivileged: true, + SupportedDriverCapabilities: "compat32,compute,display,graphics,ngx,utility,video", + NVIDIAContainerCLIConfig: ContainerCLIConfig{ + Ldconfig: "/foo/bar/ldconfig", + LoadKmods: true, + }, + NVIDIAContainerRuntimeConfig: RuntimeConfig{ + DebugFilePath: "/dev/null", + LogLevel: "info", + Runtimes: []string{"docker-runc", "runc", "crun"}, + Mode: "auto", + Modes: modesConfig{ + CSV: csvModeConfig{ + MountSpecPath: "/etc/nvidia-container-runtime/host-files-for-container.d", + }, + CDI: cdiModeConfig{ + DefaultKind: "nvidia.com/gpu", + AnnotationPrefixes: []string{ + "cdi.k8s.io/", + }, + SpecDirs: []string{ + "/etc/cdi", + "/var/run/cdi", + }, + }, + }, + }, + NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{ + Path: "nvidia-container-runtime-hook", + }, + NVIDIACTKConfig: CTKConfig{ + Path: "nvidia-ctk", + }, + Features: features{ + AllowLDConfigFromContainer: ptr(feature(true)), + }, + }, + }, { description: "config options set in section", contents: []string{ @@ -154,7 +199,7 @@ func TestGetConfig(t *testing.T) { "[nvidia-container-cli]", "root = \"/bar/baz\"", "load-kmods = false", - "ldconfig = \"/foo/bar/ldconfig\"", + "ldconfig = \"@/foo/bar/ldconfig\"", "user = \"foo:bar\"", "[nvidia-container-runtime]", "debug = \"/foo/bar\"", @@ -179,7 +224,7 @@ func TestGetConfig(t *testing.T) { NVIDIAContainerCLIConfig: ContainerCLIConfig{ Root: "/bar/baz", LoadKmods: false, - Ldconfig: "/foo/bar/ldconfig", + Ldconfig: "@/foo/bar/ldconfig", User: "foo:bar", }, NVIDIAContainerRuntimeConfig: RuntimeConfig{ @@ -213,16 +258,15 @@ func TestGetConfig(t *testing.T) { }, }, { - description: "suse config", - distIdsLike: []string{"suse", "opensuse"}, - inspectLdconfig: true, + description: "suse config", + distIdsLike: []string{"suse", "opensuse"}, expectedConfig: &Config{ AcceptEnvvarUnprivileged: true, SupportedDriverCapabilities: "compat32,compute,display,graphics,ngx,utility,video", NVIDIAContainerCLIConfig: ContainerCLIConfig{ Root: "", LoadKmods: true, - Ldconfig: "WAS_CHECKED", + Ldconfig: "@/test/ld/config/path", User: "root:video", }, NVIDIAContainerRuntimeConfig: RuntimeConfig{ @@ -250,9 +294,8 @@ func TestGetConfig(t *testing.T) { }, }, { - description: "suse config overrides user", - distIdsLike: []string{"suse", "opensuse"}, - inspectLdconfig: true, + description: "suse config overrides user", + distIdsLike: []string{"suse", "opensuse"}, contents: []string{ "nvidia-container-cli.user = \"foo:bar\"", }, @@ -262,7 +305,7 @@ func TestGetConfig(t *testing.T) { NVIDIAContainerCLIConfig: ContainerCLIConfig{ Root: "", LoadKmods: true, - Ldconfig: "WAS_CHECKED", + Ldconfig: "@/test/ld/config/path", User: "foo:bar", }, NVIDIAContainerRuntimeConfig: RuntimeConfig{ @@ -293,6 +336,7 @@ func TestGetConfig(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { + defer setGetLdConfigPathForTest()() defer setGetDistIDLikeForTest(tc.distIdsLike)() reader := strings.NewReader(strings.Join(tc.contents, "\n")) @@ -305,21 +349,63 @@ func TestGetConfig(t *testing.T) { cfg, err := tomlCfg.Config() require.NoError(t, err) - // We first handle the ldconfig path since this is currently system-dependent. - if tc.inspectLdconfig { - ldconfig := cfg.NVIDIAContainerCLIConfig.Ldconfig - require.True(t, strings.HasPrefix(ldconfig, "@/sbin/ldconfig")) - remaining := strings.TrimPrefix(ldconfig, "@/sbin/ldconfig") - require.True(t, remaining == ".real" || remaining == "") - - cfg.NVIDIAContainerCLIConfig.Ldconfig = "WAS_CHECKED" - } - require.EqualValues(t, tc.expectedConfig, cfg) }) } } +func TestAssertValid(t *testing.T) { + defer setGetLdConfigPathForTest()() + + testCases := []struct { + description string + config *Config + expectedError error + }{ + { + description: "default is valid", + config: func() *Config { + config, _ := GetDefault() + return config + }(), + }, + { + description: "alternative host ldconfig path is valid", + config: &Config{ + NVIDIAContainerCLIConfig: ContainerCLIConfig{ + Ldconfig: "@/some/host/path", + }, + }, + }, + { + description: "non-host path is invalid", + config: &Config{ + NVIDIAContainerCLIConfig: ContainerCLIConfig{ + Ldconfig: "/non/host/path", + }, + }, + expectedError: errInvalidConfig, + }, + { + description: "feature flag allows non-host path", + config: &Config{ + NVIDIAContainerCLIConfig: ContainerCLIConfig{ + Ldconfig: "/non/host/path", + }, + Features: features{ + AllowLDConfigFromContainer: ptr(feature(true)), + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + require.ErrorIs(t, tc.config.assertValid(), tc.expectedError) + }) + } +} + // setGetDistIDsLikeForTest overrides the distribution IDs that would normally be read from the /etc/os-release file. func setGetDistIDLikeForTest(ids []string) func() { if ids == nil { @@ -335,3 +421,18 @@ func setGetDistIDLikeForTest(ids []string) func() { getDistIDLike = original } } + +// prt returns a reference to whatever type is passed into it +func ptr[T any](x T) *T { + return &x +} + +func setGetLdConfigPathForTest() func() { + previous := getLdConfigPath + getLdConfigPath = func() string { + return "@/test/ld/config/path" + } + return func() { + getLdConfigPath = previous + } +} diff --git a/internal/config/features.go b/internal/config/features.go index dfc6b165..240f6f87 100644 --- a/internal/config/features.go +++ b/internal/config/features.go @@ -19,10 +19,11 @@ package config type featureName string const ( - FeatureGDS = featureName("gds") - FeatureMOFED = featureName("mofed") - FeatureNVSWITCH = featureName("nvswitch") - FeatureGDRCopy = featureName("gdrcopy") + FeatureGDS = featureName("gds") + FeatureMOFED = featureName("mofed") + FeatureNVSWITCH = featureName("nvswitch") + FeatureGDRCopy = featureName("gdrcopy") + FeatureAllowLDConfigFromContainer = featureName("allow-ldconfig-from-container") ) // features specifies a set of named features. @@ -31,53 +32,58 @@ type features struct { MOFED *feature `toml:"mofed,omitempty"` NVSWITCH *feature `toml:"nvswitch,omitempty"` GDRCopy *feature `toml:"gdrcopy,omitempty"` + // AllowLDConfigFromContainer allows non-host ldconfig paths to be used. + // If this feature flag is not set to 'true' only host-rooted config paths + // (i.e. paths starting with an '@' are considered valid) + AllowLDConfigFromContainer *feature `toml:"allow-ldconfig-from-container,omitempty"` } type feature bool -// IsEnabled checks whether a specified named feature is enabled. +// IsEnabledInEnvironment checks whether a specified named feature is enabled. // An optional list of environments to check for feature-specific environment // variables can also be supplied. -func (fs features) IsEnabled(n featureName, in ...getenver) bool { - featureEnvvars := map[featureName]string{ - FeatureGDS: "NVIDIA_GDS", - FeatureMOFED: "NVIDIA_MOFED", - FeatureNVSWITCH: "NVIDIA_NVSWITCH", - FeatureGDRCopy: "NVIDIA_GDRCOPY", - } - - envvar := featureEnvvars[n] +func (fs features) IsEnabledInEnvironment(n featureName, in ...getenver) bool { switch n { + // Features with envvar overrides case FeatureGDS: - return fs.GDS.isEnabled(envvar, in...) + return fs.GDS.isEnabledWithEnvvarOverride("NVIDIA_GDS", in...) case FeatureMOFED: - return fs.MOFED.isEnabled(envvar, in...) + return fs.MOFED.isEnabledWithEnvvarOverride("NVIDIA_MOFED", in...) case FeatureNVSWITCH: - return fs.NVSWITCH.isEnabled(envvar, in...) + return fs.NVSWITCH.isEnabledWithEnvvarOverride("NVIDIA_NVSWITCH", in...) case FeatureGDRCopy: - return fs.GDRCopy.isEnabled(envvar, in...) + return fs.GDRCopy.isEnabledWithEnvvarOverride("NVIDIA_GDRCOPY", in...) + // Features without envvar overrides + case FeatureAllowLDConfigFromContainer: + return fs.AllowLDConfigFromContainer.IsEnabled() default: return false } } -// isEnabled checks whether a feature is enabled. -// If the enabled value is explicitly set, this is returned, otherwise the -// associated envvar is checked in the specified getenver for the string "enabled" -// A CUDA container / image can be passed here. -func (f *feature) isEnabled(envvar string, ins ...getenver) bool { +// IsEnabled checks whether a feature is enabled. +func (f *feature) IsEnabled() bool { if f != nil { return bool(*f) } - if envvar == "" { - return false - } - for _, in := range ins { - if in.Getenv(envvar) == "enabled" { - return true + return false +} + +// isEnabledWithEnvvarOverride checks whether a feature is enabled and allows an envvar to overide the feature. +// If the enabled value is explicitly set, this is returned, otherwise the +// associated envvar is checked in the specified getenver for the string "enabled" +// A CUDA container / image can be passed here. +func (f *feature) isEnabledWithEnvvarOverride(envvar string, ins ...getenver) bool { + if envvar != "" { + for _, in := range ins { + if in.Getenv(envvar) == "enabled" { + return true + } } } - return false + + return f.IsEnabled() } type getenver interface { diff --git a/internal/config/toml.go b/internal/config/toml.go index a1d37428..4df2a099 100644 --- a/internal/config/toml.go +++ b/internal/config/toml.go @@ -108,6 +108,19 @@ func loadConfigTomlFrom(reader io.Reader) (*Toml, error) { // Config returns the typed config associated with the toml tree. func (t *Toml) Config() (*Config, error) { + cfg, err := t.configNoOverrides() + if err != nil { + return nil, err + } + if err := cfg.assertValid(); err != nil { + return nil, err + } + return cfg, nil +} + +// configNoOverrides returns the typed config associated with the toml tree. +// This config does not include feature-specific overrides. +func (t *Toml) configNoOverrides() (*Config, error) { cfg, err := GetDefault() if err != nil { return nil, err diff --git a/internal/config/toml_test.go b/internal/config/toml_test.go index e017db15..f7c649f7 100644 --- a/internal/config/toml_test.go +++ b/internal/config/toml_test.go @@ -198,9 +198,12 @@ func TestTomlContents(t *testing.T) { } func TestConfigFromToml(t *testing.T) { + defer setGetLdConfigPathForTest()() + testCases := []struct { description string contents map[string]interface{} + expectedError error expectedConfig *Config }{ { @@ -226,13 +229,39 @@ func TestConfigFromToml(t *testing.T) { return c }(), }, + { + description: "invalid ldconfig value raises error", + contents: map[string]interface{}{ + "nvidia-container-cli": map[string]interface{}{ + "ldconfig": "/some/ldconfig/path", + }, + }, + expectedError: errInvalidConfig, + }, + { + description: "feature allows ldconfig override", + contents: map[string]interface{}{ + "nvidia-container-cli": map[string]interface{}{ + "ldconfig": "/some/ldconfig/path", + }, + "features": map[string]interface{}{ + "allow-ldconfig-from-container": true, + }, + }, + expectedConfig: func() *Config { + c, _ := GetDefault() + c.NVIDIAContainerCLIConfig.Ldconfig = "/some/ldconfig/path" + c.Features.AllowLDConfigFromContainer = ptr(feature(true)) + return c + }(), + }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { tomlCfg := fromMap(tc.contents) config, err := tomlCfg.Config() - require.NoError(t, err) + require.ErrorIs(t, err, tc.expectedError) require.EqualValues(t, tc.expectedConfig, config) }) } diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index 5bed3eaf..70322b35 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -46,7 +46,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image driverRoot := cfg.NVIDIAContainerCLIConfig.Root devRoot := cfg.NVIDIAContainerCLIConfig.Root - if cfg.Features.IsEnabled(config.FeatureGDS, image) { + if cfg.Features.IsEnabledInEnvironment(config.FeatureGDS, image) { d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err) @@ -54,7 +54,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image discoverers = append(discoverers, d) } - if cfg.Features.IsEnabled(config.FeatureMOFED, image) { + if cfg.Features.IsEnabledInEnvironment(config.FeatureMOFED, image) { d, err := discover.NewMOFEDDiscoverer(logger, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err) @@ -62,7 +62,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image discoverers = append(discoverers, d) } - if cfg.Features.IsEnabled(config.FeatureNVSWITCH, image) { + if cfg.Features.IsEnabledInEnvironment(config.FeatureNVSWITCH, image) { d, err := discover.NewNvSwitchDiscoverer(logger, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err) @@ -70,7 +70,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image discoverers = append(discoverers, d) } - if cfg.Features.IsEnabled(config.FeatureGDRCopy, image) { + if cfg.Features.IsEnabledInEnvironment(config.FeatureGDRCopy, image) { d, err := discover.NewGDRCopyDiscoverer(logger, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)