diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 9daf97ca..7aef47b8 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -36,6 +36,7 @@ import ( ) const ( + discoveryModeAuto = "auto" discoveryModeNVML = "nvml" discoveryModeWSL = "wsl" @@ -96,8 +97,8 @@ func (m command) build() *cli.Command { }, &cli.StringFlag{ Name: "discovery-mode", - Usage: "The mode to use when discovering the available entities. One of [nvml | wsl]", - Value: discoveryModeNVML, + Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. I mode is set to 'auto' the mode will be determined based on the system configuration.", + Value: discoveryModeAuto, Destination: &cfg.discoveryMode, }, &cli.StringFlag{ @@ -132,6 +133,7 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { cfg.discoveryMode = strings.ToLower(cfg.discoveryMode) switch cfg.discoveryMode { + case discoveryModeAuto: case discoveryModeNVML: case discoveryModeWSL: default: diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 4081e524..f3d7dba0 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -19,6 +19,7 @@ package nvcdi import ( "github.com/sirupsen/logrus" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/info" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) @@ -39,7 +40,7 @@ func New(opts ...Option) Interface { opt(l) } if l.mode == "" { - l.mode = "nvml" + l.mode = "auto" } if l.logger == nil { l.logger = logrus.StandardLogger() @@ -54,7 +55,7 @@ func New(opts ...Option) Interface { l.nvidiaCTKPath = "/usr/bin/nvidia-ctk" } - switch l.mode { + switch l.resolveMode() { case "nvml": if l.nvmllib == nil { l.nvmllib = nvml.New() @@ -71,3 +72,24 @@ func New(opts ...Option) Interface { // TODO: We want an error here. return nil } + +// resolveMode resolves the mode for CDI spec generation based on the current system. +func (l *nvcdilib) resolveMode() (rmode string) { + if l.mode != "auto" { + return l.mode + } + defer func() { + l.logger.Infof("Auto-detected mode as %q", rmode) + }() + + nvinfo := info.New() + + isWSL, reason := nvinfo.HasDXCore() + l.logger.Debugf("Is WSL-based system? %v: %v", isWSL, reason) + + if isWSL { + return "wsl" + } + + return "nvml" +}