diff --git a/cmd/nvidia-ctk/cdi/list/list.go b/cmd/nvidia-ctk/cdi/list/list.go index 1f9d39f8..74f4e48a 100644 --- a/cmd/nvidia-ctk/cdi/list/list.go +++ b/cmd/nvidia-ctk/cdi/list/list.go @@ -17,6 +17,7 @@ package list import ( + "errors" "fmt" "github.com/urfave/cli/v2" @@ -29,7 +30,9 @@ type command struct { logger logger.Interface } -type config struct{} +type config struct { + cdiSpecDirs cli.StringSlice +} // NewCommand constructs a cdi list command with the specified logger func NewCommand(logger logger.Interface) *cli.Command { @@ -55,19 +58,29 @@ func (m command) build() *cli.Command { }, } - c.Flags = []cli.Flag{} + c.Flags = []cli.Flag{ + &cli.StringSliceFlag{ + Name: "spec-dir", + Usage: "specify the directories to scan for CDI specifications", + Value: cli.NewStringSlice(cdi.DefaultSpecDirs...), + Destination: &cfg.cdiSpecDirs, + }, + } return &c } func (m command) validateFlags(c *cli.Context, cfg *config) error { + if len(cfg.cdiSpecDirs.Value()) == 0 { + return errors.New("at least one CDI specification directory must be specified") + } return nil } func (m command) run(c *cli.Context, cfg *config) error { registry, err := cdi.NewCache( cdi.WithAutoRefresh(false), - cdi.WithSpecDirs(cdi.DefaultSpecDirs...), + cdi.WithSpecDirs(cfg.cdiSpecDirs.Value()...), ) if err != nil { return fmt.Errorf("failed to create CDI cache: %v", err)