diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index b187335b..b74bd690 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -182,6 +182,31 @@ func (m command) build() *cli.Command { } func (m command) validateFlags(c *cli.Context, opts *options) error { + // Load config file as base configuration + cfg, err := config.GetConfig() + if err != nil { + return fmt.Errorf("failed to load config: %v", err) + } + + // Apply config file values if command line or environment variables are not set. + // order (1) command line, (2) environment variable, (3) config file + if opts.nvidiaCDIHookPath == "" && cfg.NVIDIAContainerRuntimeHookConfig.Path != "" { + opts.nvidiaCDIHookPath = cfg.NVIDIAContainerRuntimeHookConfig.Path + } + + if opts.ldconfigPath == "" && string(cfg.NVIDIAContainerCLIConfig.Ldconfig) != "" { + opts.ldconfigPath = string(cfg.NVIDIAContainerCLIConfig.Ldconfig) + } + + if opts.mode == "" && cfg.NVIDIAContainerRuntimeConfig.Mode != "" { + opts.mode = cfg.NVIDIAContainerRuntimeConfig.Mode + } + + if opts.csv.files.Value() == nil && len(cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath) > 0 { + opts.csv.files = *cli.NewStringSlice(cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath) + } + + // Continue with existing validation opts.format = strings.ToLower(opts.format) switch opts.format { case spec.FormatJSON: diff --git a/cmd/nvidia-ctk/cdi/generate/generate_test.go b/cmd/nvidia-ctk/cdi/generate/generate_test.go index d6aae4d7..17f53d0b 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate_test.go +++ b/cmd/nvidia-ctk/cdi/generate/generate_test.go @@ -18,6 +18,7 @@ package generate import ( "bytes" + "os" "path/filepath" "strings" "testing" @@ -26,11 +27,35 @@ import ( "github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" "github.com/NVIDIA/nvidia-container-toolkit/internal/test" ) func TestGenerateSpec(t *testing.T) { + // Create a temporary directory for config + tmpDir, err := os.MkdirTemp("", "nvidia-container-toolkit-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create a temporary config file + configContent := ` +[nvidia-container-runtime] +mode = "nvml" +[[nvidia-container-runtime.modes.cdi]] +spec-dirs = ["/etc/cdi", "/usr/local/cdi"] +[nvidia-container-runtime.modes.csv] +mount-spec-path = "/etc/nvidia-container-runtime/host-files-for-container.d" + ` + configPath := filepath.Join(tmpDir, "config.toml") + err = os.WriteFile(configPath, []byte(configContent), 0600) + require.NoError(t, err) + + // Set XDG_CONFIG_HOME to point to our temporary directory + oldXDGConfigHome := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", tmpDir) + defer os.Setenv("XDG_CONFIG_HOME", oldXDGConfigHome) + t.Setenv("__NVCT_TESTING_DEVICES_ARE_FILES", "true") moduleRoot, err := test.GetModuleRoot() require.NoError(t, err) @@ -62,6 +87,13 @@ func TestGenerateSpec(t *testing.T) { class: "device", nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook", driverRoot: driverRoot, + csv: struct { + files cli.StringSlice + ignorePatterns cli.StringSlice + }{ + files: *cli.NewStringSlice("/etc/nvidia-container-runtime/host-files-for-container.d"), + ignorePatterns: cli.StringSlice{}, + }, }, expectedSpec: `--- cdiVersion: 0.5.0 @@ -125,6 +157,10 @@ containerEdits: err := c.validateFlags(nil, &tc.options) require.ErrorIs(t, err, tc.expectedValidateError) + // Set the ldconfig path to empty. + // This is required during test because config.GetConfig() returns + // the default ldconfig path, even if it is not set in the config file. + tc.options.ldconfigPath = "" require.EqualValues(t, tc.expectedOptions, tc.options) // Set up a mock server, reusing the DGX A100 mock. diff --git a/cmd/nvidia-ctk/cdi/list/list.go b/cmd/nvidia-ctk/cdi/list/list.go index 886da6e9..281009c0 100644 --- a/cmd/nvidia-ctk/cdi/list/list.go +++ b/cmd/nvidia-ctk/cdi/list/list.go @@ -23,6 +23,7 @@ import ( "github.com/urfave/cli/v2" "tags.cncf.io/container-device-interface/pkg/cdi" + ctkconfig "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) @@ -64,6 +65,7 @@ func (m command) build() *cli.Command { Usage: "specify the directories to scan for CDI specifications", Value: cli.NewStringSlice(cdi.DefaultSpecDirs...), Destination: &cfg.cdiSpecDirs, + EnvVars: []string{"NVIDIA_CTK_CDI_SPEC_DIRS"}, }, } @@ -71,9 +73,22 @@ func (m command) build() *cli.Command { } func (m command) validateFlags(c *cli.Context, cfg *config) error { + // Load config file as base configuration + config, err := ctkconfig.GetConfig() + if err != nil { + return fmt.Errorf("failed to load config: %v", err) + } + + // Apply config file values if command line or environment variables are not set. + // order (1) command line, (2) environment variable, (3) config file + if !c.IsSet("spec-dir") && len(config.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs) > 0 { + cfg.cdiSpecDirs = *cli.NewStringSlice(config.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs...) + } + if len(cfg.cdiSpecDirs.Value()) == 0 { return errors.New("at least one CDI specification directory must be specified") } + return nil } diff --git a/cmd/nvidia-ctk/cdi/list/list_test.go b/cmd/nvidia-ctk/cdi/list/list_test.go new file mode 100644 index 00000000..907c1661 --- /dev/null +++ b/cmd/nvidia-ctk/cdi/list/list_test.go @@ -0,0 +1,95 @@ +package list + +import ( + "os" + "path/filepath" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" +) + +func TestValidateFlags(t *testing.T) { + logger, _ := testlog.NewNullLogger() + // Create a temporary directory for config + tmpDir, err := os.MkdirTemp("", "nvidia-container-toolkit-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create a temporary config file + configContent := ` +[nvidia-container-runtime] +mode = "cdi" +[[nvidia-container-runtime.modes.cdi]] +spec-dirs = ["/etc/cdi", "/usr/local/cdi"] +` + configPath := filepath.Join(tmpDir, "config.toml") + err = os.WriteFile(configPath, []byte(configContent), 0600) + require.NoError(t, err) + + // Set XDG_CONFIG_HOME to point to our temporary directory + oldXDGConfigHome := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", tmpDir) + defer os.Setenv("XDG_CONFIG_HOME", oldXDGConfigHome) + + tests := []struct { + name string + cliArgs []string + envVars map[string]string + expectedDirs []string + expectError bool + errorContains string + }{ + { + name: "command line takes precedence", + cliArgs: []string{"--spec-dir=/custom/dir1", "--spec-dir=/custom/dir2"}, + expectedDirs: []string{"/custom/dir1", "/custom/dir2"}, + }, + { + name: "environment variable takes precedence over config", + envVars: map[string]string{"NVIDIA_CTK_CDI_SPEC_DIRS": "/env/dir1:/env/dir2"}, + expectedDirs: []string{"/env/dir1", "/env/dir2"}, + }, + { + name: "config file used as fallback", + expectedDirs: []string{"/etc/cdi", "/usr/local/cdi"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up environment variables + for k, v := range tt.envVars { + old := os.Getenv(k) + os.Setenv(k, v) + defer os.Setenv(k, old) + } + + // Create command + cmd := NewCommand(logger) + + // Create a new context with the command + app := &cli.App{ + Commands: []*cli.Command{ + { + Name: "cdi", + Subcommands: []*cli.Command{cmd}, + }, + }, + } + + // Run command + args := append([]string{"nvidia-ctk", "cdi", "list"}, tt.cliArgs...) + err := app.Run(args) + + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorContains) + return + } + + require.NoError(t, err) + }) + } +}