mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-28 23:17:24 +00:00
Merge pull request #362 from elezar/add-feature-flags
Add support for feature flags
This commit is contained in:
commit
413da20838
@ -109,8 +109,12 @@ func run(c *cli.Context, opts *options) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid --set option %v: %w", set, err)
|
return fmt.Errorf("invalid --set option %v: %w", set, err)
|
||||||
}
|
}
|
||||||
|
if value == nil {
|
||||||
|
_ = cfgToml.Delete(key)
|
||||||
|
} else {
|
||||||
cfgToml.Set(key, value)
|
cfgToml.Set(key, value)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := opts.EnsureOutputFolder(); err != nil {
|
if err := opts.EnsureOutputFolder(); err != nil {
|
||||||
return fmt.Errorf("failed to create output directory: %v", err)
|
return fmt.Errorf("failed to create output directory: %v", err)
|
||||||
@ -146,20 +150,25 @@ func setFlagToKeyValue(setFlag string) (string, interface{}, error) {
|
|||||||
|
|
||||||
kind := field.Kind()
|
kind := field.Kind()
|
||||||
if len(setParts) != 2 {
|
if len(setParts) != 2 {
|
||||||
if kind == reflect.Bool {
|
if kind == reflect.Bool || (kind == reflect.Pointer && field.Elem().Kind() == reflect.Bool) {
|
||||||
return key, true, nil
|
return key, true, nil
|
||||||
}
|
}
|
||||||
return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag)
|
return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag)
|
||||||
}
|
}
|
||||||
|
|
||||||
value := setParts[1]
|
value := setParts[1]
|
||||||
|
if kind == reflect.Pointer && value != "nil" {
|
||||||
|
kind = field.Elem().Kind()
|
||||||
|
}
|
||||||
switch kind {
|
switch kind {
|
||||||
|
case reflect.Pointer:
|
||||||
|
return key, nil, nil
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
b, err := strconv.ParseBool(value)
|
b, err := strconv.ParseBool(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err)
|
return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err)
|
||||||
}
|
}
|
||||||
return key, b, err
|
return key, b, nil
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
return key, value, nil
|
return key, value, nil
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
@ -201,7 +210,7 @@ func getStruct(current reflect.Type, paths ...string) (reflect.StructField, erro
|
|||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if v != tomlField {
|
if strings.SplitN(v, ",", 2)[0] != tomlField {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if len(paths) == 1 {
|
if len(paths) == 1 {
|
||||||
|
@ -63,6 +63,9 @@ type Config struct {
|
|||||||
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
|
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
|
||||||
NVIDIAContainerRuntimeConfig RuntimeConfig `toml:"nvidia-container-runtime"`
|
NVIDIAContainerRuntimeConfig RuntimeConfig `toml:"nvidia-container-runtime"`
|
||||||
NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"`
|
NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"`
|
||||||
|
|
||||||
|
// Features allows for finer control over optional features.
|
||||||
|
Features features `toml:"features,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConfigFilePath returns the path to the config file for the configured system
|
// GetConfigFilePath returns the path to the config file for the configured system
|
||||||
|
85
internal/config/features.go
Normal file
85
internal/config/features.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
/**
|
||||||
|
# Copyright 2024 NVIDIA CORPORATION
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
**/
|
||||||
|
|
||||||
|
package config
|
||||||
|
|
||||||
|
type featureName string
|
||||||
|
|
||||||
|
const (
|
||||||
|
FeatureGDS = featureName("gds")
|
||||||
|
FeatureMOFED = featureName("mofed")
|
||||||
|
FeatureNVSWITCH = featureName("nvswitch")
|
||||||
|
FeatureGDRCopy = featureName("gdrcopy")
|
||||||
|
)
|
||||||
|
|
||||||
|
// features specifies a set of named features.
|
||||||
|
type features struct {
|
||||||
|
GDS *feature `toml:"gds,omitempty"`
|
||||||
|
MOFED *feature `toml:"mofed,omitempty"`
|
||||||
|
NVSWITCH *feature `toml:"nvswitch,omitempty"`
|
||||||
|
GDRCopy *feature `toml:"gdrcopy,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type feature bool
|
||||||
|
|
||||||
|
// IsEnabled checks whether a specified named feature is enabled.
|
||||||
|
// An optional list of environments to check for feature-specific environment
|
||||||
|
// variables can also be supplied.
|
||||||
|
func (fs features) IsEnabled(n featureName, in ...getenver) bool {
|
||||||
|
featureEnvvars := map[featureName]string{
|
||||||
|
FeatureGDS: "NVIDIA_GDS",
|
||||||
|
FeatureMOFED: "NVIDIA_MOFED",
|
||||||
|
FeatureNVSWITCH: "NVIDIA_NVSWITCH",
|
||||||
|
FeatureGDRCopy: "NVIDIA_GDRCOPY",
|
||||||
|
}
|
||||||
|
|
||||||
|
envvar := featureEnvvars[n]
|
||||||
|
switch n {
|
||||||
|
case FeatureGDS:
|
||||||
|
return fs.GDS.isEnabled(envvar, in...)
|
||||||
|
case FeatureMOFED:
|
||||||
|
return fs.MOFED.isEnabled(envvar, in...)
|
||||||
|
case FeatureNVSWITCH:
|
||||||
|
return fs.NVSWITCH.isEnabled(envvar, in...)
|
||||||
|
case FeatureGDRCopy:
|
||||||
|
return fs.GDRCopy.isEnabled(envvar, in...)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isEnabled checks whether a feature is enabled.
|
||||||
|
// If the enabled value is explicitly set, this is returned, otherwise the
|
||||||
|
// associated envvar is checked in the specified getenver for the string "enabled"
|
||||||
|
// A CUDA container / image can be passed here.
|
||||||
|
func (f *feature) isEnabled(envvar string, ins ...getenver) bool {
|
||||||
|
if f != nil {
|
||||||
|
return bool(*f)
|
||||||
|
}
|
||||||
|
if envvar == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, in := range ins {
|
||||||
|
if in.Getenv(envvar) == "enabled" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type getenver interface {
|
||||||
|
Getenv(string) string
|
||||||
|
}
|
@ -26,13 +26,6 @@ import (
|
|||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
nvidiaGDSEnvvar = "NVIDIA_GDS"
|
|
||||||
nvidiaMOFEDEnvvar = "NVIDIA_MOFED"
|
|
||||||
nvidiaNVSWITCHEnvvar = "NVIDIA_NVSWITCH"
|
|
||||||
nvidiaGDRCOPYEnvvar = "NVIDIA_GDRCOPY"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewFeatureGatedModifier creates the modifiers for optional features.
|
// NewFeatureGatedModifier creates the modifiers for optional features.
|
||||||
// These include:
|
// These include:
|
||||||
//
|
//
|
||||||
@ -53,7 +46,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
|
|||||||
driverRoot := cfg.NVIDIAContainerCLIConfig.Root
|
driverRoot := cfg.NVIDIAContainerCLIConfig.Root
|
||||||
devRoot := cfg.NVIDIAContainerCLIConfig.Root
|
devRoot := cfg.NVIDIAContainerCLIConfig.Root
|
||||||
|
|
||||||
if image.Getenv(nvidiaGDSEnvvar) == "enabled" {
|
if cfg.Features.IsEnabled(config.FeatureGDS, image) {
|
||||||
d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot)
|
d, err := discover.NewGDSDiscoverer(logger, driverRoot, devRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err)
|
return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err)
|
||||||
@ -61,7 +54,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
|
|||||||
discoverers = append(discoverers, d)
|
discoverers = append(discoverers, d)
|
||||||
}
|
}
|
||||||
|
|
||||||
if image.Getenv(nvidiaMOFEDEnvvar) == "enabled" {
|
if cfg.Features.IsEnabled(config.FeatureMOFED, image) {
|
||||||
d, err := discover.NewMOFEDDiscoverer(logger, devRoot)
|
d, err := discover.NewMOFEDDiscoverer(logger, devRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err)
|
return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err)
|
||||||
@ -69,7 +62,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
|
|||||||
discoverers = append(discoverers, d)
|
discoverers = append(discoverers, d)
|
||||||
}
|
}
|
||||||
|
|
||||||
if image.Getenv(nvidiaNVSWITCHEnvvar) == "enabled" {
|
if cfg.Features.IsEnabled(config.FeatureNVSWITCH, image) {
|
||||||
d, err := discover.NewNvSwitchDiscoverer(logger, devRoot)
|
d, err := discover.NewNvSwitchDiscoverer(logger, devRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err)
|
return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err)
|
||||||
@ -77,7 +70,7 @@ func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image
|
|||||||
discoverers = append(discoverers, d)
|
discoverers = append(discoverers, d)
|
||||||
}
|
}
|
||||||
|
|
||||||
if image.Getenv(nvidiaGDRCOPYEnvvar) == "enabled" {
|
if cfg.Features.IsEnabled(config.FeatureGDRCopy, image) {
|
||||||
d, err := discover.NewGDRCopyDiscoverer(logger, devRoot)
|
d, err := discover.NewGDRCopyDiscoverer(logger, devRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)
|
return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)
|
||||||
|
Loading…
Reference in New Issue
Block a user