Merge pull request #362 from elezar/add-feature-flags

Add support for feature flags
This commit is contained in:
Evan Lezar 2024-04-03 12:04:18 +02:00 committed by GitHub
commit 413da20838
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 105 additions and 15 deletions

View File

@ -109,7 +109,11 @@ func run(c *cli.Context, opts *options) error {
if err != nil { if err != nil {
return fmt.Errorf("invalid --set option %v: %w", set, err) return fmt.Errorf("invalid --set option %v: %w", set, err)
} }
cfgToml.Set(key, value) if value == nil {
_ = cfgToml.Delete(key)
} else {
cfgToml.Set(key, value)
}
} }
if err := opts.EnsureOutputFolder(); err != nil { if err := opts.EnsureOutputFolder(); err != nil {
@ -146,20 +150,25 @@ func setFlagToKeyValue(setFlag string) (string, interface{}, error) {
kind := field.Kind() kind := field.Kind()
if len(setParts) != 2 { if len(setParts) != 2 {
if kind == reflect.Bool { if kind == reflect.Bool || (kind == reflect.Pointer && field.Elem().Kind() == reflect.Bool) {
return key, true, nil return key, true, nil
} }
return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag) return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag)
} }
value := setParts[1] value := setParts[1]
if kind == reflect.Pointer && value != "nil" {
kind = field.Elem().Kind()
}
switch kind { switch kind {
case reflect.Pointer:
return key, nil, nil
case reflect.Bool: case reflect.Bool:
b, err := strconv.ParseBool(value) b, err := strconv.ParseBool(value)
if err != nil { if err != nil {
return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err) return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err)
} }
return key, b, err return key, b, nil
case reflect.String: case reflect.String:
return key, value, nil return key, value, nil
case reflect.Slice: case reflect.Slice:
@ -201,7 +210,7 @@ func getStruct(current reflect.Type, paths ...string) (reflect.StructField, erro
if !ok { if !ok {
continue continue
} }
if v != tomlField { if strings.SplitN(v, ",", 2)[0] != tomlField {
continue continue
} }
if len(paths) == 1 { if len(paths) == 1 {

View File

@ -63,6 +63,9 @@ type Config struct {
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"` NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
NVIDIAContainerRuntimeConfig RuntimeConfig `toml:"nvidia-container-runtime"` NVIDIAContainerRuntimeConfig RuntimeConfig `toml:"nvidia-container-runtime"`
NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"` NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"`
// Features allows for finer control over optional features.
Features features `toml:"features,omitempty"`
} }
// GetConfigFilePath returns the path to the config file for the configured system // GetConfigFilePath returns the path to the config file for the configured system

View File

@ -0,0 +1,85 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package config
type featureName string
const (
FeatureGDS = featureName("gds")
FeatureMOFED = featureName("mofed")
FeatureNVSWITCH = featureName("nvswitch")
FeatureGDRCopy = featureName("gdrcopy")
)
// features specifies a set of named features.
type features struct {
GDS *feature `toml:"gds,omitempty"`
MOFED *feature `toml:"mofed,omitempty"`
NVSWITCH *feature `toml:"nvswitch,omitempty"`
GDRCopy *feature `toml:"gdrcopy,omitempty"`
}
type feature bool
// IsEnabled checks whether a specified named feature is enabled.
// An optional list of environments to check for feature-specific environment
// variables can also be supplied.
func (fs features) IsEnabled(n featureName, in ...getenver) bool {
featureEnvvars := map[featureName]string{
FeatureGDS: "NVIDIA_GDS",
FeatureMOFED: "NVIDIA_MOFED",
FeatureNVSWITCH: "NVIDIA_NVSWITCH",
FeatureGDRCopy: "NVIDIA_GDRCOPY",
}
envvar := featureEnvvars[n]
switch n {
case FeatureGDS:
return fs.GDS.isEnabled(envvar, in...)
case FeatureMOFED:
return fs.MOFED.isEnabled(envvar, in...)
case FeatureNVSWITCH:
return fs.NVSWITCH.isEnabled(envvar, in...)
case FeatureGDRCopy:
return fs.GDRCopy.isEnabled(envvar, in...)
default:
return false
}
}
// isEnabled checks whether a feature is enabled.
// If the enabled value is explicitly set, this is returned, otherwise the
// associated envvar is checked in the specified getenver for the string "enabled"
// A CUDA container / image can be passed here.
func (f *feature) isEnabled(envvar string, ins ...getenver) bool {
if f != nil {
return bool(*f)
}
if envvar == "" {
return false
}
for _, in := range ins {
if in.Getenv(envvar) == "enabled" {
return true
}
}
return false
}
type getenver interface {
Getenv(string) string
}

View File

@ -26,13 +26,6 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
) )
const (
nvidiaGDSEnvvar = "NVIDIA_GDS"
nvidiaMOFEDEnvvar = "NVIDIA_MOFED"
nvidiaNVSWITCHEnvvar = "NVIDIA_NVSWITCH"
nvidiaGDRCOPYEnvvar = "NVIDIA_GDRCOPY"
)
// NewFeatureGatedModifier creates the modifiers for optional features. // NewFeatureGatedModifier creates the modifiers for optional features.
// These include: // These include:
// //
@ -53,7 +46,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
driverRoot := cfg.NVIDIAContainerCLIConfig.Root driverRoot := cfg.NVIDIAContainerCLIConfig.Root
devRoot := cfg.NVIDIAContainerCLIConfig.Root devRoot := cfg.NVIDIAContainerCLIConfig.Root
if image.Getenv(nvidiaGDSEnvvar) == "enabled" { if cfg.Features.IsEnabled(config.FeatureGDS, image) {
d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot) d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err) return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err)
@ -61,7 +54,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
discoverers = append(discoverers, d) discoverers = append(discoverers, d)
} }
if image.Getenv(nvidiaMOFEDEnvvar) == "enabled" { if cfg.Features.IsEnabled(config.FeatureMOFED, image) {
d, err := discover.NewMOFEDDiscoverer(logger, devRoot) d, err := discover.NewMOFEDDiscoverer(logger, devRoot)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err) return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err)
@ -69,7 +62,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
discoverers = append(discoverers, d) discoverers = append(discoverers, d)
} }
if image.Getenv(nvidiaNVSWITCHEnvvar) == "enabled" { if cfg.Features.IsEnabled(config.FeatureNVSWITCH, image) {
d, err := discover.NewNvSwitchDiscoverer(logger, devRoot) d, err := discover.NewNvSwitchDiscoverer(logger, devRoot)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err) return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err)
@ -77,7 +70,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
discoverers = append(discoverers, d) discoverers = append(discoverers, d)
} }
if image.Getenv(nvidiaGDRCOPYEnvvar) == "enabled" { if cfg.Features.IsEnabled(config.FeatureGDRCopy, image) {
d, err := discover.NewGDRCopyDiscoverer(logger, devRoot) d, err := discover.NewGDRCopyDiscoverer(logger, devRoot)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err) return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)