mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-06-12 09:31:17 +00:00
Load settings from config.toml file during CDI generation
Some checks failed
Some checks failed
Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
This commit is contained in:
parent
fdcd250362
commit
f882eb56da
@ -68,6 +68,14 @@ type options struct {
|
||||
nvmllib nvml.Interface
|
||||
}
|
||||
|
||||
// Add setter methods for csv.files and csv.ignorePatterns
|
||||
func (o *options) SetCSVFiles(files []string) {
|
||||
o.csv.files = *cli.NewStringSlice(files...)
|
||||
}
|
||||
func (o *options) SetCSVIgnorePatterns(patterns []string) {
|
||||
o.csv.ignorePatterns = *cli.NewStringSlice(patterns...)
|
||||
}
|
||||
|
||||
// NewCommand constructs a generate-cdi command with the specified logger
|
||||
func NewCommand(logger logger.Interface) *cli.Command {
|
||||
c := command{
|
||||
@ -192,6 +200,16 @@ func (m command) build() *cli.Command {
|
||||
}
|
||||
|
||||
func (m command) validateFlags(c *cli.Context, opts *options) error {
|
||||
// Load config file as base configuration
|
||||
cfg, err := config.GetConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Use centralized flag resolution (CLI > config file > default)
|
||||
config.ResolveCDIGenerateOptions(c, cfg, opts)
|
||||
|
||||
// Additional validation (format, mode, etc.) can remain here if needed
|
||||
opts.format = strings.ToLower(opts.format)
|
||||
switch opts.format {
|
||||
case spec.FormatJSON:
|
||||
|
@ -18,6 +18,7 @@ package generate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -32,6 +33,29 @@ import (
|
||||
)
|
||||
|
||||
func TestGenerateSpec(t *testing.T) {
|
||||
// Create a temporary directory for config
|
||||
tmpDir, err := os.MkdirTemp("", "nvidia-container-toolkit-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create a temporary config file
|
||||
configContent := `
|
||||
[nvidia-container-runtime]
|
||||
mode = "nvml"
|
||||
[[nvidia-container-runtime.modes.cdi]]
|
||||
spec-dirs = ["/etc/cdi", "/usr/local/cdi"]
|
||||
[nvidia-container-runtime.modes.csv]
|
||||
mount-spec-path = "/etc/nvidia-container-runtime/host-files-for-container.d"
|
||||
`
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
err = os.WriteFile(configPath, []byte(configContent), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set XDG_CONFIG_HOME to point to our temporary directory
|
||||
oldXDGConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", tmpDir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", oldXDGConfigHome)
|
||||
|
||||
t.Setenv("__NVCT_TESTING_DEVICES_ARE_FILES", "true")
|
||||
moduleRoot, err := test.GetModuleRoot()
|
||||
require.NoError(t, err)
|
||||
@ -63,6 +87,13 @@ func TestGenerateSpec(t *testing.T) {
|
||||
class: "device",
|
||||
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
|
||||
driverRoot: driverRoot,
|
||||
csv: struct {
|
||||
files cli.StringSlice
|
||||
ignorePatterns cli.StringSlice
|
||||
}{
|
||||
files: *cli.NewStringSlice("/etc/nvidia-container-runtime/host-files-for-container.d"),
|
||||
ignorePatterns: *cli.NewStringSlice(),
|
||||
},
|
||||
},
|
||||
expectedSpec: `---
|
||||
cdiVersion: 0.5.0
|
||||
@ -140,6 +171,13 @@ containerEdits:
|
||||
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
|
||||
driverRoot: driverRoot,
|
||||
disabledHooks: valueOf(cli.NewStringSlice("enable-cuda-compat")),
|
||||
csv: struct {
|
||||
files cli.StringSlice
|
||||
ignorePatterns cli.StringSlice
|
||||
}{
|
||||
files: *cli.NewStringSlice("/etc/nvidia-container-runtime/host-files-for-container.d"),
|
||||
ignorePatterns: *cli.NewStringSlice(),
|
||||
},
|
||||
},
|
||||
expectedSpec: `---
|
||||
cdiVersion: 0.5.0
|
||||
@ -209,6 +247,13 @@ containerEdits:
|
||||
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
|
||||
driverRoot: driverRoot,
|
||||
disabledHooks: valueOf(cli.NewStringSlice("enable-cuda-compat", "update-ldcache")),
|
||||
csv: struct {
|
||||
files cli.StringSlice
|
||||
ignorePatterns cli.StringSlice
|
||||
}{
|
||||
files: *cli.NewStringSlice("/etc/nvidia-container-runtime/host-files-for-container.d"),
|
||||
ignorePatterns: *cli.NewStringSlice(),
|
||||
},
|
||||
},
|
||||
expectedSpec: `---
|
||||
cdiVersion: 0.5.0
|
||||
@ -269,6 +314,13 @@ containerEdits:
|
||||
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
|
||||
driverRoot: driverRoot,
|
||||
disabledHooks: valueOf(cli.NewStringSlice("all")),
|
||||
csv: struct {
|
||||
files cli.StringSlice
|
||||
ignorePatterns cli.StringSlice
|
||||
}{
|
||||
files: *cli.NewStringSlice("/etc/nvidia-container-runtime/host-files-for-container.d"),
|
||||
ignorePatterns: *cli.NewStringSlice(),
|
||||
},
|
||||
},
|
||||
expectedSpec: `---
|
||||
cdiVersion: 0.5.0
|
||||
@ -311,6 +363,10 @@ containerEdits:
|
||||
|
||||
err := c.validateFlags(nil, &tc.options)
|
||||
require.ErrorIs(t, err, tc.expectedValidateError)
|
||||
// Set the ldconfig path to empty.
|
||||
// This is required during test because config.GetConfig() returns
|
||||
// the default ldconfig path, even if it is not set in the config file.
|
||||
tc.options.ldconfigPath = ""
|
||||
require.EqualValues(t, tc.expectedOptions, tc.options)
|
||||
|
||||
// Set up a mock server, reusing the DGX A100 mock.
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
"github.com/urfave/cli/v2"
|
||||
"tags.cncf.io/container-device-interface/pkg/cdi"
|
||||
|
||||
ctkconfig "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
)
|
||||
|
||||
@ -34,6 +35,11 @@ type config struct {
|
||||
cdiSpecDirs cli.StringSlice
|
||||
}
|
||||
|
||||
// SetCDISpecDirs sets the cdiSpecDirs field from a []string
|
||||
func (c *config) SetCDISpecDirs(dirs []string) {
|
||||
c.cdiSpecDirs = *cli.NewStringSlice(dirs...)
|
||||
}
|
||||
|
||||
// NewCommand constructs a cdi list command with the specified logger
|
||||
func NewCommand(logger logger.Interface) *cli.Command {
|
||||
c := command{
|
||||
@ -64,16 +70,27 @@ func (m command) build() *cli.Command {
|
||||
Usage: "specify the directories to scan for CDI specifications",
|
||||
Value: cli.NewStringSlice(cdi.DefaultSpecDirs...),
|
||||
Destination: &cfg.cdiSpecDirs,
|
||||
EnvVars: []string{"NVIDIA_CTK_CDI_SPEC_DIRS"},
|
||||
},
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
func (m command) validateFlags(c *cli.Context, cfg *config) error {
|
||||
func (m command) validateFlags(ctx *cli.Context, cfg *config) error {
|
||||
// Load config file as base configuration
|
||||
c, err := ctkconfig.GetConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Use centralized normalization
|
||||
ctkconfig.ResolveCDIListConfig(ctx, c, cfg)
|
||||
|
||||
if len(cfg.cdiSpecDirs.Value()) == 0 {
|
||||
return errors.New("at least one CDI specification directory must be specified")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
95
cmd/nvidia-ctk/cdi/list/list_test.go
Normal file
95
cmd/nvidia-ctk/cdi/list/list_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
testlog "github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
func TestValidateFlags(t *testing.T) {
|
||||
logger, _ := testlog.NewNullLogger()
|
||||
// Create a temporary directory for config
|
||||
tmpDir, err := os.MkdirTemp("", "nvidia-container-toolkit-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create a temporary config file
|
||||
configContent := `
|
||||
[nvidia-container-runtime]
|
||||
mode = "cdi"
|
||||
[[nvidia-container-runtime.modes.cdi]]
|
||||
spec-dirs = ["/etc/cdi", "/usr/local/cdi"]
|
||||
`
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
err = os.WriteFile(configPath, []byte(configContent), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set XDG_CONFIG_HOME to point to our temporary directory
|
||||
oldXDGConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", tmpDir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", oldXDGConfigHome)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cliArgs []string
|
||||
envVars map[string]string
|
||||
expectedDirs []string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "command line takes precedence",
|
||||
cliArgs: []string{"--spec-dir=/custom/dir1", "--spec-dir=/custom/dir2"},
|
||||
expectedDirs: []string{"/custom/dir1", "/custom/dir2"},
|
||||
},
|
||||
{
|
||||
name: "environment variable takes precedence over config",
|
||||
envVars: map[string]string{"NVIDIA_CTK_CDI_SPEC_DIRS": "/env/dir1:/env/dir2"},
|
||||
expectedDirs: []string{"/env/dir1", "/env/dir2"},
|
||||
},
|
||||
{
|
||||
name: "config file used as fallback",
|
||||
expectedDirs: []string{"/etc/cdi", "/usr/local/cdi"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up environment variables
|
||||
for k, v := range tt.envVars {
|
||||
old := os.Getenv(k)
|
||||
os.Setenv(k, v)
|
||||
defer os.Setenv(k, old)
|
||||
}
|
||||
|
||||
// Create command
|
||||
cmd := NewCommand(logger)
|
||||
|
||||
// Create a new context with the command
|
||||
app := &cli.App{
|
||||
Commands: []*cli.Command{
|
||||
{
|
||||
Name: "cdi",
|
||||
Subcommands: []*cli.Command{cmd},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Run command
|
||||
args := append([]string{"nvidia-ctk", "cdi", "list"}, tt.cliArgs...)
|
||||
err := app.Run(args)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorContains)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
@ -445,11 +445,6 @@ func setGetDistIDLikeForTest(ids []string) func() {
|
||||
}
|
||||
}
|
||||
|
||||
// prt returns a reference to whatever type is passed into it
|
||||
func ptr[T any](x T) *T {
|
||||
return &x
|
||||
}
|
||||
|
||||
func setGetLdConfigPathForTest() func() {
|
||||
previous := getLdConfigPath
|
||||
getLdConfigPath = func() ldconfigPath {
|
||||
|
154
internal/config/flags.go
Normal file
154
internal/config/flags.go
Normal file
@ -0,0 +1,154 @@
|
||||
/**
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
|
||||
|
||||
cli "github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
// prt returns a reference to whatever type is passed into it
|
||||
func ptr[T any](x T) *T {
|
||||
return &x
|
||||
}
|
||||
|
||||
// ResolveCDIListConfig resolves the config struct for the CDI list subcommand in-place.
|
||||
// It sets cfg.cdiSpecDirs using CLI > config > default priority.
|
||||
// Accepts *list.config as the third argument.
|
||||
func ResolveCDIListConfig(ctx *cli.Context, config *Config, cfg interface{}) {
|
||||
// Use switch statement for type assertion
|
||||
switch v := cfg.(type) {
|
||||
case *struct{ cdiSpecDirs cli.StringSlice }:
|
||||
var dirs []string
|
||||
switch {
|
||||
case ctx.IsSet("spec-dir"):
|
||||
dirs = ctx.StringSlice("spec-dir")
|
||||
case config != nil && len(config.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs) > 0:
|
||||
dirs = config.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs
|
||||
default:
|
||||
dirs = []string{"/etc/cdi", "/var/run/cdi"}
|
||||
}
|
||||
v.cdiSpecDirs = *cli.NewStringSlice(dirs...)
|
||||
case interface{ SetCDISpecDirs([]string) }:
|
||||
var dirs []string
|
||||
switch {
|
||||
case ctx.IsSet("spec-dir"):
|
||||
dirs = ctx.StringSlice("spec-dir")
|
||||
case config != nil && len(config.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs) > 0:
|
||||
dirs = config.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs
|
||||
default:
|
||||
dirs = []string{"/etc/cdi", "/var/run/cdi"}
|
||||
}
|
||||
v.SetCDISpecDirs(dirs)
|
||||
default:
|
||||
panic("ResolveCDIListConfig: unsupported config struct type")
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveCDIGenerateOptions resolves the options struct for the CDI generate subcommand in-place.
|
||||
// It sets all fields using CLI > config > default priority.
|
||||
// Uses reflection to support unexported fields from another package.
|
||||
func ResolveCDIGenerateOptions(ctx *cli.Context, config *Config, opts interface{}) {
|
||||
// Define resolveStringSlice before use
|
||||
resolveStringSlice := func(flagName string, configVal []string, defaultVal []string) []string {
|
||||
if ctx != nil && ctx.IsSet(flagName) {
|
||||
return ctx.StringSlice(flagName)
|
||||
}
|
||||
if len(configVal) > 0 {
|
||||
return configVal
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// Always use csv.DefaultFileList() as the default for csv.file
|
||||
csvFileDefault := csv.DefaultFileList()
|
||||
csvFileConfig := []string{config.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath}
|
||||
if config.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath == "" {
|
||||
csvFileConfig = nil
|
||||
}
|
||||
|
||||
// Use type assertion for setter methods first (like list.go)
|
||||
if setter, ok := opts.(interface{ SetCSVFiles([]string) }); ok {
|
||||
setter.SetCSVFiles(resolveStringSlice("csv.file", csvFileConfig, csvFileDefault))
|
||||
}
|
||||
if setter, ok := opts.(interface{ SetCSVIgnorePatterns([]string) }); ok {
|
||||
setter.SetCSVIgnorePatterns(resolveStringSlice("csv.ignore-pattern", nil, nil))
|
||||
}
|
||||
// ... existing reflection-based logic for other fields ...
|
||||
v := reflect.ValueOf(opts)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
panic("ResolveCDIGenerateOptions: opts must be a non-nil pointer to struct")
|
||||
}
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
panic("ResolveCDIGenerateOptions: opts must be a pointer to struct")
|
||||
}
|
||||
|
||||
setString := func(field, value string) {
|
||||
f := v.FieldByName(field)
|
||||
if f.IsValid() && f.CanSet() {
|
||||
f.SetString(value)
|
||||
}
|
||||
}
|
||||
setStringSlice := func(field string, value []string) {
|
||||
f := v.FieldByName(field)
|
||||
if f.IsValid() && f.CanSet() {
|
||||
f.Set(reflect.ValueOf(*cli.NewStringSlice(value...)))
|
||||
}
|
||||
}
|
||||
|
||||
resolveString := func(flagName, configVal, defaultVal string) string {
|
||||
if ctx.IsSet(flagName) {
|
||||
return ctx.String(flagName)
|
||||
}
|
||||
if configVal != "" {
|
||||
return configVal
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
setString("Format", resolveString("format", "", "yaml"))
|
||||
setString("Mode", resolveString("mode", config.NVIDIAContainerRuntimeConfig.Mode, "auto"))
|
||||
setString("NvidiaCDIHookPath", resolveString("nvidia-cdi-hook-path", config.NVIDIAContainerRuntimeHookConfig.Path, ""))
|
||||
setString("LdconfigPath", resolveString("ldconfig-path", string(config.NVIDIAContainerCLIConfig.Ldconfig), ""))
|
||||
setString("Vendor", resolveString("vendor", "nvidia.com", "nvidia.com"))
|
||||
setString("Class", resolveString("class", "gpu", "gpu"))
|
||||
setString("Output", resolveString("output", "", ""))
|
||||
setString("DriverRoot", resolveString("driver-root", "", ""))
|
||||
setString("DevRoot", resolveString("dev-root", "", ""))
|
||||
|
||||
setStringSlice("DeviceNameStrategies", resolveStringSlice("device-name-strategy", nil, []string{"index", "uuid"}))
|
||||
setStringSlice("ConfigSearchPaths", resolveStringSlice("config-search-path", nil, nil))
|
||||
setStringSlice("LibrarySearchPaths", resolveStringSlice("library-search-path", nil, nil))
|
||||
setStringSlice("DisabledHooks", resolveStringSlice("disable-hook", nil, nil))
|
||||
|
||||
// For reflection-based path, set csv.Files and csv.IgnorePatterns if present
|
||||
csvField := v.FieldByName("Csv")
|
||||
if csvField.IsValid() && csvField.Kind() == reflect.Struct {
|
||||
filesField := csvField.FieldByName("Files")
|
||||
if filesField.IsValid() && filesField.CanSet() {
|
||||
filesField.Set(reflect.ValueOf(*cli.NewStringSlice(resolveStringSlice("csv.file", csvFileConfig, csvFileDefault)...)))
|
||||
}
|
||||
ignorePatternsField := csvField.FieldByName("IgnorePatterns")
|
||||
if ignorePatternsField.IsValid() && ignorePatternsField.CanSet() {
|
||||
ignorePatternsField.Set(reflect.ValueOf(*cli.NewStringSlice(resolveStringSlice("csv.ignore-pattern", nil, nil)...)))
|
||||
}
|
||||
}
|
||||
}
|
357
internal/config/flags_test.go
Normal file
357
internal/config/flags_test.go
Normal file
@ -0,0 +1,357 @@
|
||||
/*
|
||||
*
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*
|
||||
*/
|
||||
package config
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
cli "github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
type mockConfig struct {
|
||||
SpecDirs []string
|
||||
Mode string
|
||||
HookPath string
|
||||
LdconfigPath string
|
||||
CSVSpecPath string
|
||||
}
|
||||
|
||||
func (m *mockConfig) toConfig() *Config {
|
||||
return &Config{
|
||||
NVIDIAContainerRuntimeConfig: RuntimeConfig{
|
||||
Mode: m.Mode,
|
||||
Modes: modesConfig{
|
||||
CDI: cdiModeConfig{
|
||||
SpecDirs: m.SpecDirs,
|
||||
},
|
||||
CSV: csvModeConfig{
|
||||
MountSpecPath: m.CSVSpecPath,
|
||||
},
|
||||
},
|
||||
},
|
||||
NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{
|
||||
Path: m.HookPath,
|
||||
},
|
||||
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
||||
Ldconfig: ldconfigPath(m.LdconfigPath),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCDIListConfig(t *testing.T) {
|
||||
app := cli.NewApp()
|
||||
app.Flags = []cli.Flag{
|
||||
&cli.StringSliceFlag{
|
||||
Name: "spec-dir",
|
||||
},
|
||||
}
|
||||
set := func(args ...string) *cli.Context {
|
||||
set := flagSet(app, args...)
|
||||
return cli.NewContext(app, set, nil)
|
||||
}
|
||||
t.Run("CLI takes precedence", func(t *testing.T) {
|
||||
ctx := set("--spec-dir", "/cli/dir1", "--spec-dir", "/cli/dir2")
|
||||
cfg := (&mockConfig{SpecDirs: []string{"/config/dir"}}).toConfig()
|
||||
var target struct{ cdiSpecDirs cli.StringSlice }
|
||||
ResolveCDIListConfig(ctx, cfg, &target)
|
||||
require.Equal(t, []string{"/cli/dir1", "/cli/dir2"}, getStringSliceFieldValue(reflect.ValueOf(target.cdiSpecDirs)))
|
||||
})
|
||||
t.Run("Config used if CLI not set", func(t *testing.T) {
|
||||
ctx := set()
|
||||
cfg := (&mockConfig{SpecDirs: []string{"/config/dir1", "/config/dir2"}}).toConfig()
|
||||
var target struct{ cdiSpecDirs cli.StringSlice }
|
||||
ResolveCDIListConfig(ctx, cfg, &target)
|
||||
require.Equal(t, []string{"/config/dir1", "/config/dir2"}, getStringSliceFieldValue(reflect.ValueOf(target.cdiSpecDirs)))
|
||||
})
|
||||
t.Run("Default used if neither set", func(t *testing.T) {
|
||||
ctx := set()
|
||||
cfg := (&mockConfig{}).toConfig()
|
||||
var target struct{ cdiSpecDirs cli.StringSlice }
|
||||
ResolveCDIListConfig(ctx, cfg, &target)
|
||||
require.Equal(t, []string{"/etc/cdi", "/var/run/cdi"}, getStringSliceFieldValue(reflect.ValueOf(target.cdiSpecDirs)))
|
||||
})
|
||||
}
|
||||
|
||||
// Helper for safely extracting []string from a reflect.Value of cli.StringSlice or *cli.StringSlice
|
||||
func getStringSliceFieldValue(v reflect.Value) []string {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.Interface {
|
||||
v = v.Elem()
|
||||
}
|
||||
ss, ok := v.Interface().(cli.StringSlice)
|
||||
if ok {
|
||||
return ss.Value()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper for safely extracting string from a reflect.Value of string or *string
|
||||
func getStringFieldValue(v reflect.Value) string {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.Interface {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.String {
|
||||
return v.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// optsWithSetters is used to test setter-based normalization
|
||||
type optsWithSetters struct {
|
||||
csvFiles []string
|
||||
csvIgnorePatterns []string
|
||||
}
|
||||
|
||||
// Implement the setter methods
|
||||
func (o *optsWithSetters) SetCSVFiles(files []string) { o.csvFiles = files }
|
||||
func (o *optsWithSetters) SetCSVIgnorePatterns(patterns []string) { o.csvIgnorePatterns = patterns }
|
||||
|
||||
func TestResolveCDIGenerateOptions(t *testing.T) {
|
||||
app := cli.NewApp()
|
||||
app.Flags = []cli.Flag{
|
||||
&cli.StringSliceFlag{Name: "config-search-path"},
|
||||
&cli.StringFlag{Name: "format"},
|
||||
&cli.StringFlag{Name: "mode"},
|
||||
&cli.StringSliceFlag{Name: "device-name-strategy"},
|
||||
&cli.StringFlag{Name: "nvidia-cdi-hook-path"},
|
||||
&cli.StringFlag{Name: "ldconfig-path"},
|
||||
&cli.StringFlag{Name: "vendor"},
|
||||
&cli.StringFlag{Name: "class"},
|
||||
&cli.StringSliceFlag{Name: "library-search-path"},
|
||||
&cli.StringSliceFlag{Name: "csv.file"},
|
||||
&cli.StringSliceFlag{Name: "csv.ignore-pattern"},
|
||||
&cli.StringSliceFlag{Name: "disable-hook"},
|
||||
&cli.StringFlag{Name: "output"},
|
||||
&cli.StringFlag{Name: "driver-root"},
|
||||
&cli.StringFlag{Name: "dev-root"},
|
||||
}
|
||||
set := func(args ...string) *cli.Context {
|
||||
set := flagSet(app, args...)
|
||||
return cli.NewContext(app, set, nil)
|
||||
}
|
||||
cfg := (&mockConfig{
|
||||
SpecDirs: []string{"/config/dir"},
|
||||
Mode: "configmode",
|
||||
HookPath: "/config/hook",
|
||||
LdconfigPath: "/config/ldconfig",
|
||||
CSVSpecPath: "/config/csv",
|
||||
}).toConfig()
|
||||
|
||||
t.Run("All CLI flags", func(t *testing.T) {
|
||||
ctx := set(
|
||||
"--config-search-path", "/cli/cfg1", "--config-search-path", "/cli/cfg2",
|
||||
"--format", "json",
|
||||
"--mode", "climode",
|
||||
"--device-name-strategy", "uuid",
|
||||
"--nvidia-cdi-hook-path", "/cli/hook",
|
||||
"--ldconfig-path", "/cli/ldconfig",
|
||||
"--vendor", "cli-vendor",
|
||||
"--class", "cli-class",
|
||||
"--library-search-path", "/cli/lib1",
|
||||
"--csv.file", "/cli/csv1",
|
||||
"--csv.ignore-pattern", "pat1",
|
||||
"--disable-hook", "hook1",
|
||||
"--output", "/cli/output",
|
||||
"--driver-root", "/cli/driver",
|
||||
"--dev-root", "/cli/dev",
|
||||
)
|
||||
var opts struct {
|
||||
Output string
|
||||
Format string
|
||||
DeviceNameStrategies cli.StringSlice
|
||||
DriverRoot string
|
||||
DevRoot string
|
||||
NvidiaCDIHookPath string
|
||||
LdconfigPath string
|
||||
Mode string
|
||||
Vendor string
|
||||
Class string
|
||||
ConfigSearchPaths cli.StringSlice
|
||||
LibrarySearchPaths cli.StringSlice
|
||||
DisabledHooks cli.StringSlice
|
||||
Csv struct {
|
||||
Files cli.StringSlice
|
||||
IgnorePatterns cli.StringSlice
|
||||
}
|
||||
}
|
||||
ResolveCDIGenerateOptions(ctx, cfg, &opts)
|
||||
// Use reflection to check values
|
||||
v := reflect.ValueOf(&opts).Elem()
|
||||
field := v.FieldByName("Format")
|
||||
require.Equal(t, "json", getStringFieldValue(field))
|
||||
field = v.FieldByName("Mode")
|
||||
require.Equal(t, "climode", getStringFieldValue(field))
|
||||
field = v.FieldByName("NvidiaCDIHookPath")
|
||||
require.Equal(t, "/cli/hook", getStringFieldValue(field))
|
||||
field = v.FieldByName("LdconfigPath")
|
||||
require.Equal(t, "/cli/ldconfig", getStringFieldValue(field))
|
||||
field = v.FieldByName("Vendor")
|
||||
require.Equal(t, "cli-vendor", getStringFieldValue(field))
|
||||
field = v.FieldByName("Class")
|
||||
require.Equal(t, "cli-class", getStringFieldValue(field))
|
||||
field = v.FieldByName("Output")
|
||||
require.Equal(t, "/cli/output", getStringFieldValue(field))
|
||||
field = v.FieldByName("DriverRoot")
|
||||
require.Equal(t, "/cli/driver", getStringFieldValue(field))
|
||||
field = v.FieldByName("DevRoot")
|
||||
require.Equal(t, "/cli/dev", getStringFieldValue(field))
|
||||
require.Equal(t, []string{"uuid"}, getStringSliceFieldValue(v.FieldByName("DeviceNameStrategies")))
|
||||
require.Equal(t, []string{"/cli/cfg1", "/cli/cfg2"}, getStringSliceFieldValue(v.FieldByName("ConfigSearchPaths")))
|
||||
require.Equal(t, []string{"/cli/lib1"}, getStringSliceFieldValue(v.FieldByName("LibrarySearchPaths")))
|
||||
require.Equal(t, []string{"hook1"}, getStringSliceFieldValue(v.FieldByName("DisabledHooks")))
|
||||
csvField := v.FieldByName("Csv")
|
||||
requireStringSliceEqual(t, []string{"/cli/csv1"}, getStringSliceFieldValue(csvField.FieldByName("Files")))
|
||||
requireStringSliceEqual(t, []string{"pat1"}, getStringSliceFieldValue(csvField.FieldByName("IgnorePatterns")))
|
||||
})
|
||||
|
||||
t.Run("Config fallback", func(t *testing.T) {
|
||||
ctx := set()
|
||||
var opts struct {
|
||||
Output string
|
||||
Format string
|
||||
DeviceNameStrategies cli.StringSlice
|
||||
DriverRoot string
|
||||
DevRoot string
|
||||
NvidiaCDIHookPath string
|
||||
LdconfigPath string
|
||||
Mode string
|
||||
Vendor string
|
||||
Class string
|
||||
ConfigSearchPaths cli.StringSlice
|
||||
LibrarySearchPaths cli.StringSlice
|
||||
DisabledHooks cli.StringSlice
|
||||
Csv struct {
|
||||
Files cli.StringSlice
|
||||
IgnorePatterns cli.StringSlice
|
||||
}
|
||||
}
|
||||
ResolveCDIGenerateOptions(ctx, cfg, &opts)
|
||||
v := reflect.ValueOf(&opts).Elem()
|
||||
require.Equal(t, "configmode", getStringFieldValue(v.FieldByName("Mode")))
|
||||
require.Equal(t, "/config/hook", getStringFieldValue(v.FieldByName("NvidiaCDIHookPath")))
|
||||
require.Equal(t, "/config/ldconfig", getStringFieldValue(v.FieldByName("LdconfigPath")))
|
||||
csvField := v.FieldByName("Csv")
|
||||
requireStringSliceEqual(t, []string{"/config/csv"}, getStringSliceFieldValue(csvField.FieldByName("Files")))
|
||||
requireStringSliceEqual(t, []string{}, getStringSliceFieldValue(csvField.FieldByName("IgnorePatterns")))
|
||||
})
|
||||
|
||||
t.Run("Default fallback", func(t *testing.T) {
|
||||
ctx := set()
|
||||
cfg := (&mockConfig{}).toConfig()
|
||||
var opts struct {
|
||||
Output string
|
||||
Format string
|
||||
DeviceNameStrategies cli.StringSlice
|
||||
DriverRoot string
|
||||
DevRoot string
|
||||
NvidiaCDIHookPath string
|
||||
LdconfigPath string
|
||||
Mode string
|
||||
Vendor string
|
||||
Class string
|
||||
ConfigSearchPaths cli.StringSlice
|
||||
LibrarySearchPaths cli.StringSlice
|
||||
DisabledHooks cli.StringSlice
|
||||
Csv struct {
|
||||
Files cli.StringSlice
|
||||
IgnorePatterns cli.StringSlice
|
||||
}
|
||||
}
|
||||
ResolveCDIGenerateOptions(ctx, cfg, &opts)
|
||||
v := reflect.ValueOf(&opts).Elem()
|
||||
require.Equal(t, "auto", getStringFieldValue(v.FieldByName("Mode")))
|
||||
require.Equal(t, "yaml", getStringFieldValue(v.FieldByName("Format")))
|
||||
require.Equal(t, []string{"index", "uuid"}, getStringSliceFieldValue(v.FieldByName("DeviceNameStrategies")))
|
||||
require.Equal(t, "nvidia.com", getStringFieldValue(v.FieldByName("Vendor")))
|
||||
require.Equal(t, "gpu", getStringFieldValue(v.FieldByName("Class")))
|
||||
csvField := v.FieldByName("Csv")
|
||||
requireStringSliceEqual(t, csv.DefaultFileList(), getStringSliceFieldValue(csvField.FieldByName("Files")))
|
||||
requireStringSliceEqual(t, []string{}, getStringSliceFieldValue(csvField.FieldByName("IgnorePatterns")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveCDIGenerateOptions_SetterMethods(t *testing.T) {
|
||||
app := cli.NewApp()
|
||||
app.Flags = []cli.Flag{
|
||||
&cli.StringSliceFlag{Name: "csv.file"},
|
||||
&cli.StringSliceFlag{Name: "csv.ignore-pattern"},
|
||||
}
|
||||
set := func(args ...string) *cli.Context {
|
||||
set := flagSet(app, args...)
|
||||
return cli.NewContext(app, set, nil)
|
||||
}
|
||||
cfg := (&mockConfig{
|
||||
CSVSpecPath: "/config/csv",
|
||||
}).toConfig()
|
||||
|
||||
t.Run("CLI takes precedence", func(t *testing.T) {
|
||||
ctx := set("--csv.file", "/cli/csv1", "--csv.file", "/cli/csv2", "--csv.ignore-pattern", "pat1")
|
||||
opts := &optsWithSetters{}
|
||||
ResolveCDIGenerateOptions(ctx, cfg, opts)
|
||||
requireStringSliceEqual(t, []string{"/cli/csv1", "/cli/csv2"}, opts.csvFiles)
|
||||
requireStringSliceEqual(t, []string{"pat1"}, opts.csvIgnorePatterns)
|
||||
})
|
||||
|
||||
t.Run("Config fallback", func(t *testing.T) {
|
||||
ctx := set()
|
||||
opts := &optsWithSetters{}
|
||||
ResolveCDIGenerateOptions(ctx, cfg, opts)
|
||||
requireStringSliceEqual(t, []string{"/config/csv"}, opts.csvFiles)
|
||||
requireStringSliceEqual(t, []string{}, opts.csvIgnorePatterns)
|
||||
})
|
||||
|
||||
t.Run("Default fallback", func(t *testing.T) {
|
||||
ctx := set()
|
||||
cfg := (&mockConfig{}).toConfig()
|
||||
opts := &optsWithSetters{}
|
||||
ResolveCDIGenerateOptions(ctx, cfg, opts)
|
||||
requireStringSliceEqual(t, csv.DefaultFileList(), opts.csvFiles)
|
||||
requireStringSliceEqual(t, []string{}, opts.csvIgnorePatterns)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper to create a cli.FlagSet for testing
|
||||
func flagSet(app *cli.App, args ...string) *flag.FlagSet {
|
||||
set := flag.NewFlagSet(app.Name, flag.ContinueOnError)
|
||||
for _, f := range app.Flags {
|
||||
_ = f.Apply(set)
|
||||
}
|
||||
_ = set.Parse(args)
|
||||
return set
|
||||
}
|
||||
|
||||
// Helper to compare two string slices, treating nil and empty as equal
|
||||
func requireStringSliceEqual(t *testing.T, expected, actual []string, msgAndArgs ...interface{}) {
|
||||
if expected == nil {
|
||||
expected = []string{}
|
||||
}
|
||||
if actual == nil {
|
||||
actual = []string{}
|
||||
}
|
||||
require.Equal(t, expected, actual, msgAndArgs...)
|
||||
}
|
Loading…
Reference in New Issue
Block a user