diff --git a/cmd/nvidia-ctk/config/config.go b/cmd/nvidia-ctk/config/config.go index 3f10f727..3c4555f4 100644 --- a/cmd/nvidia-ctk/config/config.go +++ b/cmd/nvidia-ctk/config/config.go @@ -109,7 +109,11 @@ func run(c *cli.Context, opts *options) error { if err != nil { return fmt.Errorf("invalid --set option %v: %w", set, err) } - cfgToml.Set(key, value) + if value == nil { + _ = cfgToml.Delete(key) + } else { + cfgToml.Set(key, value) + } } if err := opts.EnsureOutputFolder(); err != nil { @@ -146,20 +150,25 @@ func setFlagToKeyValue(setFlag string) (string, interface{}, error) { kind := field.Kind() if len(setParts) != 2 { - if kind == reflect.Bool { + if kind == reflect.Bool || (kind == reflect.Pointer && field.Elem().Kind() == reflect.Bool) { return key, true, nil } return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag) } value := setParts[1] + if kind == reflect.Pointer && value != "nil" { + kind = field.Elem().Kind() + } switch kind { + case reflect.Pointer: + return key, nil, nil case reflect.Bool: b, err := strconv.ParseBool(value) if err != nil { return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err) } - return key, b, err + return key, b, nil case reflect.String: return key, value, nil case reflect.Slice: @@ -201,7 +210,7 @@ func getStruct(current reflect.Type, paths ...string) (reflect.StructField, erro if !ok { continue } - if v != tomlField { + if strings.SplitN(v, ",", 2)[0] != tomlField { continue } if len(paths) == 1 { diff --git a/internal/config/config.go b/internal/config/config.go index b94184b4..f3114fba 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -63,6 +63,9 @@ type Config struct { NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"` NVIDIAContainerRuntimeConfig RuntimeConfig `toml:"nvidia-container-runtime"` NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"` + + // Features allows for finer control over optional features. + Features features `toml:"features,omitempty"` } // GetConfigFilePath returns the path to the config file for the configured system diff --git a/internal/config/features.go b/internal/config/features.go new file mode 100644 index 00000000..dfc6b165 --- /dev/null +++ b/internal/config/features.go @@ -0,0 +1,85 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package config + +type featureName string + +const ( + FeatureGDS = featureName("gds") + FeatureMOFED = featureName("mofed") + FeatureNVSWITCH = featureName("nvswitch") + FeatureGDRCopy = featureName("gdrcopy") +) + +// features specifies a set of named features. +type features struct { + GDS *feature `toml:"gds,omitempty"` + MOFED *feature `toml:"mofed,omitempty"` + NVSWITCH *feature `toml:"nvswitch,omitempty"` + GDRCopy *feature `toml:"gdrcopy,omitempty"` +} + +type feature bool + +// IsEnabled checks whether a specified named feature is enabled. +// An optional list of environments to check for feature-specific environment +// variables can also be supplied. +func (fs features) IsEnabled(n featureName, in ...getenver) bool { + featureEnvvars := map[featureName]string{ + FeatureGDS: "NVIDIA_GDS", + FeatureMOFED: "NVIDIA_MOFED", + FeatureNVSWITCH: "NVIDIA_NVSWITCH", + FeatureGDRCopy: "NVIDIA_GDRCOPY", + } + + envvar := featureEnvvars[n] + switch n { + case FeatureGDS: + return fs.GDS.isEnabled(envvar, in...) + case FeatureMOFED: + return fs.MOFED.isEnabled(envvar, in...) + case FeatureNVSWITCH: + return fs.NVSWITCH.isEnabled(envvar, in...) + case FeatureGDRCopy: + return fs.GDRCopy.isEnabled(envvar, in...) + default: + return false + } +} + +// isEnabled checks whether a feature is enabled. +// If the enabled value is explicitly set, this is returned, otherwise the +// associated envvar is checked in the specified getenver for the string "enabled" +// A CUDA container / image can be passed here. +func (f *feature) isEnabled(envvar string, ins ...getenver) bool { + if f != nil { + return bool(*f) + } + if envvar == "" { + return false + } + for _, in := range ins { + if in.Getenv(envvar) == "enabled" { + return true + } + } + return false +} + +type getenver interface { + Getenv(string) string +} diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index 1af023ab..5bed3eaf 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -26,13 +26,6 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) -const ( - nvidiaGDSEnvvar = "NVIDIA_GDS" - nvidiaMOFEDEnvvar = "NVIDIA_MOFED" - nvidiaNVSWITCHEnvvar = "NVIDIA_NVSWITCH" - nvidiaGDRCOPYEnvvar = "NVIDIA_GDRCOPY" -) - // NewFeatureGatedModifier creates the modifiers for optional features. // These include: // @@ -53,7 +46,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image driverRoot := cfg.NVIDIAContainerCLIConfig.Root devRoot := cfg.NVIDIAContainerCLIConfig.Root - if image.Getenv(nvidiaGDSEnvvar) == "enabled" { + if cfg.Features.IsEnabled(config.FeatureGDS, image) { d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err) @@ -61,7 +54,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image discoverers = append(discoverers, d) } - if image.Getenv(nvidiaMOFEDEnvvar) == "enabled" { + if cfg.Features.IsEnabled(config.FeatureMOFED, image) { d, err := discover.NewMOFEDDiscoverer(logger, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err) @@ -69,7 +62,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image discoverers = append(discoverers, d) } - if image.Getenv(nvidiaNVSWITCHEnvvar) == "enabled" { + if cfg.Features.IsEnabled(config.FeatureNVSWITCH, image) { d, err := discover.NewNvSwitchDiscoverer(logger, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err) @@ -77,7 +70,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image discoverers = append(discoverers, d) } - if image.Getenv(nvidiaGDRCOPYEnvvar) == "enabled" { + if cfg.Features.IsEnabled(config.FeatureGDRCopy, image) { d, err := discover.NewGDRCopyDiscoverer(logger, devRoot) if err != nil { return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)