diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 7d27685a..cab3f97a 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" @@ -48,6 +49,10 @@ type options struct { mode string vendor string class string + + csv struct { + files cli.StringSlice + } } // NewCommand constructs a generate-cdi command with the specified logger @@ -123,13 +128,18 @@ func (m command) build() *cli.Command { Value: "gpu", Destination: &opts.class, }, + &cli.StringSliceFlag{ + Name: "csv.file", + Usage: "The path to the list of CSV files to use when generating the CDI specification in CDI mode.", + Value: cli.NewStringSlice(csv.DefaultFileList()...), + Destination: &opts.csv.files, + }, } return &c } func (m command) validateFlags(c *cli.Context, opts *options) error { - opts.format = strings.ToLower(opts.format) switch opts.format { case spec.FormatJSON: @@ -141,6 +151,7 @@ func (m command) validateFlags(c *cli.Context, opts *options) error { opts.mode = strings.ToLower(opts.mode) switch opts.mode { case nvcdi.ModeAuto: + case nvcdi.ModeCSV: case nvcdi.ModeNvml: case nvcdi.ModeWsl: case nvcdi.ModeManagement: @@ -215,6 +226,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) { nvcdi.WithNVIDIACTKPath(opts.nvidiaCTKPath), nvcdi.WithDeviceNamer(deviceNamer), nvcdi.WithMode(string(opts.mode)), + nvcdi.WithCSVFiles(opts.csv.files.Value()), ) if err != nil { return nil, fmt.Errorf("failed to create CDI library: %v", err) diff --git a/internal/discover/csv/csv.go b/internal/discover/csv/csv.go index 4aa50828..64bc34ff 100644 --- a/internal/discover/csv/csv.go +++ b/internal/discover/csv/csv.go @@ -33,6 +33,22 @@ const ( DefaultMountSpecPath = "/etc/nvidia-container-runtime/host-files-for-container.d" ) +// DefaultFileList returns the list of CSV files that are used by default. +func DefaultFileList() []string { + files := []string{ + "devices.csv", + "drivers.csv", + "l4t.csv", + } + + var paths []string + for _, file := range files { + paths = append(paths, filepath.Join(DefaultMountSpecPath, file)) + } + + return paths +} + // GetFileList returns the (non-recursive) list of CSV files in the specified // folder func GetFileList(root string) ([]string, error) { diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 37ca46e7..7d11a8e2 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -19,6 +19,7 @@ package nvcdi import ( "fmt" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "github.com/sirupsen/logrus" @@ -84,11 +85,7 @@ func New(opts ...Option) (Interface, error) { switch l.resolveMode() { case ModeCSV: if len(l.csvFiles) == 0 { - l.csvFiles = []string{ - "/etc/nvidia-container-runtime/host-files-for-container.d/l4t.csv", - "/etc/nvidia-container-runtime/host-files-for-container.d/drivers.csv", - "/etc/nvidia-container-runtime/host-files-for-container.d/devices.csv", - } + l.csvFiles = csv.DefaultFileList() } lib = (*csvlib)(l) case ModeManagement: