diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index a81f41a5..b88c9234 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -22,7 +22,7 @@ import ( "path/filepath" "strings" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" @@ -40,7 +40,7 @@ type command struct { logger *logrus.Logger } -type config struct { +type options struct { output string format string deviceNameStrategy string @@ -61,17 +61,17 @@ func NewCommand(logger *logrus.Logger) *cli.Command { // build creates the CLI command func (m command) build() *cli.Command { - cfg := config{} + opts := options{} // Create the 'generate-cdi' command c := cli.Command{ Name: "generate", Usage: "Generate CDI specifications for use with CDI-enabled runtimes", Before: func(c *cli.Context) error { - return m.validateFlags(c, &cfg) + return m.validateFlags(c, &opts) }, Action: func(c *cli.Context) error { - return m.run(c, &cfg) + return m.run(c, &opts) }, } @@ -79,109 +79,109 @@ func (m command) build() *cli.Command { &cli.StringFlag{ Name: "output", Usage: "Specify the file to output the generated CDI specification to. If this is '' the specification is output to STDOUT", - Destination: &cfg.output, + Destination: &opts.output, }, &cli.StringFlag{ Name: "format", Usage: "The output format for the generated spec [json | yaml]. This overrides the format defined by the output file extension (if specified).", Value: spec.FormatYAML, - Destination: &cfg.format, + Destination: &opts.format, }, &cli.StringFlag{ Name: "mode", Aliases: []string{"discovery-mode"}, Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. If mode is set to 'auto' the mode will be determined based on the system configuration.", Value: nvcdi.ModeAuto, - Destination: &cfg.mode, + Destination: &opts.mode, }, &cli.StringFlag{ Name: "device-name-strategy", Usage: "Specify the strategy for generating device names. One of [index | uuid | type-index]", Value: nvcdi.DeviceNameStrategyIndex, - Destination: &cfg.deviceNameStrategy, + Destination: &opts.deviceNameStrategy, }, &cli.StringFlag{ Name: "driver-root", Usage: "Specify the NVIDIA GPU driver root to use when discovering the entities that should be included in the CDI specification.", - Destination: &cfg.driverRoot, + Destination: &opts.driverRoot, }, &cli.StringFlag{ Name: "nvidia-ctk-path", Usage: "Specify the path to use for the nvidia-ctk in the generated CDI specification. If this is left empty, the path will be searched.", - Destination: &cfg.nvidiaCTKPath, + Destination: &opts.nvidiaCTKPath, }, &cli.StringFlag{ Name: "vendor", Aliases: []string{"cdi-vendor"}, Usage: "the vendor string to use for the generated CDI specification.", Value: "nvidia.com", - Destination: &cfg.vendor, + Destination: &opts.vendor, }, &cli.StringFlag{ Name: "class", Aliases: []string{"cdi-class"}, Usage: "the class string to use for the generated CDI specification.", Value: "gpu", - Destination: &cfg.class, + Destination: &opts.class, }, } return &c } -func (m command) validateFlags(c *cli.Context, cfg *config) error { +func (m command) validateFlags(c *cli.Context, opts *options) error { - cfg.format = strings.ToLower(cfg.format) - switch cfg.format { + opts.format = strings.ToLower(opts.format) + switch opts.format { case spec.FormatJSON: case spec.FormatYAML: default: - return fmt.Errorf("invalid output format: %v", cfg.format) + return fmt.Errorf("invalid output format: %v", opts.format) } - cfg.mode = strings.ToLower(cfg.mode) - switch cfg.mode { + opts.mode = strings.ToLower(opts.mode) + switch opts.mode { case nvcdi.ModeAuto: case nvcdi.ModeNvml: case nvcdi.ModeWsl: case nvcdi.ModeManagement: default: - return fmt.Errorf("invalid discovery mode: %v", cfg.mode) + return fmt.Errorf("invalid discovery mode: %v", opts.mode) } - _, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy) + _, err := nvcdi.NewDeviceNamer(opts.deviceNameStrategy) if err != nil { return err } - cfg.nvidiaCTKPath = discover.FindNvidiaCTK(m.logger, cfg.nvidiaCTKPath) + opts.nvidiaCTKPath = config.ResolveNVIDIACTKPath(m.logger, opts.nvidiaCTKPath) - if outputFileFormat := formatFromFilename(cfg.output); outputFileFormat != "" { + if outputFileFormat := formatFromFilename(opts.output); outputFileFormat != "" { m.logger.Debugf("Inferred output format as %q from output file name", outputFileFormat) if !c.IsSet("format") { - cfg.format = outputFileFormat - } else if outputFileFormat != cfg.format { - m.logger.Warningf("Requested output format %q does not match format implied by output file name: %q", cfg.format, outputFileFormat) + opts.format = outputFileFormat + } else if outputFileFormat != opts.format { + m.logger.Warningf("Requested output format %q does not match format implied by output file name: %q", opts.format, outputFileFormat) } } - if err := cdi.ValidateVendorName(cfg.vendor); err != nil { + if err := cdi.ValidateVendorName(opts.vendor); err != nil { return fmt.Errorf("invalid CDI vendor name: %v", err) } - if err := cdi.ValidateClassName(cfg.class); err != nil { + if err := cdi.ValidateClassName(opts.class); err != nil { return fmt.Errorf("invalid CDI class name: %v", err) } return nil } -func (m command) run(c *cli.Context, cfg *config) error { - spec, err := m.generateSpec(cfg) +func (m command) run(c *cli.Context, opts *options) error { + spec, err := m.generateSpec(opts) if err != nil { return fmt.Errorf("failed to generate CDI spec: %v", err) } m.logger.Infof("Generated CDI spec with version %v", spec.Raw().Version) - if cfg.output == "" { + if opts.output == "" { _, err := spec.WriteTo(os.Stdout) if err != nil { return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err) @@ -189,7 +189,7 @@ func (m command) run(c *cli.Context, cfg *config) error { return nil } - return spec.Save(cfg.output) + return spec.Save(opts.output) } func formatFromFilename(filename string) string { @@ -204,18 +204,18 @@ func formatFromFilename(filename string) string { return "" } -func (m command) generateSpec(cfg *config) (spec.Interface, error) { - deviceNamer, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy) +func (m command) generateSpec(opts *options) (spec.Interface, error) { + deviceNamer, err := nvcdi.NewDeviceNamer(opts.deviceNameStrategy) if err != nil { return nil, fmt.Errorf("failed to create device namer: %v", err) } cdilib, err := nvcdi.New( nvcdi.WithLogger(m.logger), - nvcdi.WithDriverRoot(cfg.driverRoot), - nvcdi.WithNVIDIACTKPath(cfg.nvidiaCTKPath), + nvcdi.WithDriverRoot(opts.driverRoot), + nvcdi.WithNVIDIACTKPath(opts.nvidiaCTKPath), nvcdi.WithDeviceNamer(deviceNamer), - nvcdi.WithMode(string(cfg.mode)), + nvcdi.WithMode(string(opts.mode)), ) if err != nil { return nil, fmt.Errorf("failed to create CDI library: %v", err) @@ -246,11 +246,11 @@ func (m command) generateSpec(cfg *config) (spec.Interface, error) { } return spec.New( - spec.WithVendor(cfg.vendor), - spec.WithClass(cfg.class), + spec.WithVendor(opts.vendor), + spec.WithClass(opts.class), spec.WithDeviceSpecs(deviceSpecs), spec.WithEdits(*commonEdits.ContainerEdits), - spec.WithFormat(cfg.format), + spec.WithFormat(opts.format), spec.WithPermissions(0644), ) } diff --git a/internal/config/config.go b/internal/config/config.go index 2ecfa120..7601d0fd 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,15 +22,21 @@ import ( "io" "os" "path" + "path/filepath" "strings" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/pelletier/go-toml" + "github.com/sirupsen/logrus" ) const ( configOverride = "XDG_CONFIG_HOME" configFilePath = "nvidia-container-runtime/config.toml" + + nvidiaCTKExecutable = "nvidia-ctk" + nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk" ) var ( @@ -181,7 +187,7 @@ func GetDefaultConfigToml() (*toml.Tree, error) { tree.Set("nvidia-container-runtime.modes.cdi.annotation-prefixes", []string{cdi.AnnotationPrefix}) // nvidia-ctk - tree.Set("nvidia-ctk.path", "nvidia-ctk") + tree.Set("nvidia-ctk.path", nvidiaCTKExecutable) return tree, nil } @@ -232,3 +238,33 @@ func getDistIDLike() []string { } return nil } + +// ResolveNVIDIACTKPath resolves the path to the nvidia-ctk binary. +// This executable is used in hooks and needs to be an absolute path. +// If the path is specified as an absolute path, it is used directly +// without checking for existence of an executable at that path. +func ResolveNVIDIACTKPath(logger *logrus.Logger, nvidiaCTKPath string) string { + if filepath.IsAbs(nvidiaCTKPath) { + logger.Debugf("Using specified NVIDIA Container Toolkit CLI path %v", nvidiaCTKPath) + return nvidiaCTKPath + } + + if nvidiaCTKPath == "" { + nvidiaCTKPath = nvidiaCTKExecutable + } + logger.Debugf("Locating NVIDIA Container Toolkit CLI as %v", nvidiaCTKPath) + lookup := lookup.NewExecutableLocator(logger, "") + hookPath := nvidiaCTKDefaultFilePath + targets, err := lookup.Locate(nvidiaCTKPath) + if err != nil { + logger.Warnf("Failed to locate %v: %v", nvidiaCTKPath, err) + } else if len(targets) == 0 { + logger.Warnf("%v not found", nvidiaCTKPath) + } else { + logger.Debugf("Found %v candidates: %v", nvidiaCTKPath, targets) + hookPath = targets[0] + } + logger.Debugf("Using NVIDIA Container Toolkit CLI path %v", hookPath) + + return hookPath +} diff --git a/internal/discover/discover.go b/internal/discover/discover.go index 0687055f..7ff9f042 100644 --- a/internal/discover/discover.go +++ b/internal/discover/discover.go @@ -16,12 +16,6 @@ package discover -// Config represents the configuration options for discovery -type Config struct { - DriverRoot string - NvidiaCTKPath string -} - // Device represents a discovered character device. type Device struct { HostPath string diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index dedbfbe7..527f7c1d 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -31,9 +31,7 @@ import ( ) // NewGraphicsDiscoverer returns the discoverer for graphics tools such as Vulkan. -func NewGraphicsDiscoverer(logger *logrus.Logger, devices image.VisibleDevices, cfg *Config) (Discover, error) { - driverRoot := cfg.DriverRoot - +func NewGraphicsDiscoverer(logger *logrus.Logger, devices image.VisibleDevices, driverRoot string, nvidiaCTKPath string) (Discover, error) { mounts, err := NewGraphicsMountsDiscoverer(logger, driverRoot) if err != nil { return nil, fmt.Errorf("failed to create mounts discoverer: %v", err) @@ -44,9 +42,9 @@ func NewGraphicsDiscoverer(logger *logrus.Logger, devices image.VisibleDevices, return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err) } - drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, cfg) + drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, driverRoot, nvidiaCTKPath) - xorg := optionalXorgDiscoverer(logger, driverRoot, cfg.NvidiaCTKPath) + xorg := optionalXorgDiscoverer(logger, driverRoot, nvidiaCTKPath) discover := Merge( Merge(drmDeviceNodes, drmByPathSymlinks), @@ -106,11 +104,11 @@ type drmDevicesByPath struct { } // newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer -func newCreateDRMByPathSymlinks(logger *logrus.Logger, devices Discover, cfg *Config) Discover { +func newCreateDRMByPathSymlinks(logger *logrus.Logger, devices Discover, driverRoot string, nvidiaCTKPath string) Discover { d := drmDevicesByPath{ logger: logger, - nvidiaCTKPath: FindNvidiaCTK(logger, cfg.NvidiaCTKPath), - driverRoot: cfg.DriverRoot, + nvidiaCTKPath: nvidiaCTKPath, + driverRoot: driverRoot, devicesFrom: devices, } @@ -300,7 +298,7 @@ func newXorgDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath s xorgHooks := xorgHooks{ libraries: xorgLibs, driverVersion: version, - nvidiaCTKPath: FindNvidiaCTK(logger, nvidiaCTKPath), + nvidiaCTKPath: nvidiaCTKPath, } xorgConfg := NewMounts( diff --git a/internal/discover/hooks.go b/internal/discover/hooks.go index 87202e7c..3135ea06 100644 --- a/internal/discover/hooks.go +++ b/internal/discover/hooks.go @@ -19,14 +19,7 @@ package discover import ( "path/filepath" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" - "github.com/sirupsen/logrus" -) - -const ( - nvidiaCTKExecutable = "nvidia-ctk" - nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk" ) var _ Discover = (*Hook)(nil) @@ -72,32 +65,3 @@ func CreateNvidiaCTKHook(nvidiaCTKPath string, hookName string, additionalArgs . Args: append([]string{filepath.Base(nvidiaCTKPath), "hook", hookName}, additionalArgs...), } } - -// FindNvidiaCTK locates the nvidia-ctk executable to be used in hooks. -// If an nvidia-ctk path is specified as an absolute path, it is used directly -// without checking for existence of an executable at that path. -func FindNvidiaCTK(logger *logrus.Logger, nvidiaCTKPath string) string { - if filepath.IsAbs(nvidiaCTKPath) { - logger.Debugf("Using specified NVIDIA Container Toolkit CLI path %v", nvidiaCTKPath) - return nvidiaCTKPath - } - - if nvidiaCTKPath == "" { - nvidiaCTKPath = nvidiaCTKExecutable - } - logger.Debugf("Locating NVIDIA Container Toolkit CLI as %v", nvidiaCTKPath) - lookup := lookup.NewExecutableLocator(logger, "") - hookPath := nvidiaCTKDefaultFilePath - targets, err := lookup.Locate(nvidiaCTKPath) - if err != nil { - logger.Warnf("Failed to locate %v: %v", nvidiaCTKPath, err) - } else if len(targets) == 0 { - logger.Warnf("%v not found", nvidiaCTKPath) - } else { - logger.Debugf("Found %v candidates: %v", nvidiaCTKPath, targets) - hookPath = targets[0] - } - logger.Debugf("Using NVIDIA Container Toolkit CLI path %v", hookPath) - - return hookPath -} diff --git a/internal/discover/ldconfig.go b/internal/discover/ldconfig.go index e56f605a..a1c237e0 100644 --- a/internal/discover/ldconfig.go +++ b/internal/discover/ldconfig.go @@ -25,10 +25,10 @@ import ( ) // NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified -func NewLDCacheUpdateHook(logger *logrus.Logger, mounts Discover, cfg *Config) (Discover, error) { +func NewLDCacheUpdateHook(logger *logrus.Logger, mounts Discover, nvidiaCTKPath string) (Discover, error) { d := ldconfig{ logger: logger, - nvidiaCTKPath: FindNvidiaCTK(logger, cfg.NvidiaCTKPath), + nvidiaCTKPath: nvidiaCTKPath, mountsFrom: mounts, } diff --git a/internal/discover/ldconfig_test.go b/internal/discover/ldconfig_test.go index 4b3c11a1..8d72dde6 100644 --- a/internal/discover/ldconfig_test.go +++ b/internal/discover/ldconfig_test.go @@ -31,11 +31,6 @@ const ( func TestLDCacheUpdateHook(t *testing.T) { logger, _ := testlog.NewNullLogger() - cfg := Config{ - DriverRoot: "/", - NvidiaCTKPath: testNvidiaCTKPath, - } - testCases := []struct { description string mounts []Mount @@ -95,7 +90,7 @@ func TestLDCacheUpdateHook(t *testing.T) { Lifecycle: "createContainer", } - d, err := NewLDCacheUpdateHook(logger, mountMock, &cfg) + d, err := NewLDCacheUpdateHook(logger, mountMock, testNvidiaCTKPath) require.NoError(t, err) hooks, err := d.Hooks() diff --git a/internal/discover/symlinks.go b/internal/discover/symlinks.go index b5d344e3..31e8e64e 100644 --- a/internal/discover/symlinks.go +++ b/internal/discover/symlinks.go @@ -33,10 +33,10 @@ type symlinks struct { } // NewCreateSymlinksHook creates a discoverer for a hook that creates required symlinks in the container -func NewCreateSymlinksHook(logger *logrus.Logger, csvFiles []string, mounts Discover, cfg *Config) (Discover, error) { +func NewCreateSymlinksHook(logger *logrus.Logger, csvFiles []string, mounts Discover, nvidiaCTKPath string) (Discover, error) { d := symlinks{ logger: logger, - nvidiaCTKPath: FindNvidiaCTK(logger, cfg.NvidiaCTKPath), + nvidiaCTKPath: nvidiaCTKPath, csvFiles: csvFiles, mountsFrom: mounts, } diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index e0f8582e..0f59b2f4 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -61,11 +61,6 @@ func NewCSVModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec) } logger.Infof("Constructing modifier from config: %+v", *cfg) - config := &discover.Config{ - DriverRoot: cfg.NVIDIAContainerCLIConfig.Root, - NvidiaCTKPath: cfg.NVIDIACTKConfig.Path, - } - if err := checkRequirements(logger, image); err != nil { return nil, fmt.Errorf("requirements not met: %v", err) } @@ -79,17 +74,17 @@ func NewCSVModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec) csvFiles = csv.BaseFilesOnly(csvFiles) } - csvDiscoverer, err := discover.NewFromCSVFiles(logger, csvFiles, config.DriverRoot) + csvDiscoverer, err := discover.NewFromCSVFiles(logger, csvFiles, cfg.NVIDIAContainerCLIConfig.Root) if err != nil { return nil, fmt.Errorf("failed to create CSV discoverer: %v", err) } - createSymlinksHook, err := discover.NewCreateSymlinksHook(logger, csvFiles, csvDiscoverer, config) + createSymlinksHook, err := discover.NewCreateSymlinksHook(logger, csvFiles, csvDiscoverer, cfg.NVIDIACTKConfig.Path) if err != nil { return nil, fmt.Errorf("failed to create symlink hook discoverer: %v", err) } - ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(logger, csvDiscoverer, config) + ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(logger, csvDiscoverer, cfg.NVIDIACTKConfig.Path) if err != nil { return nil, fmt.Errorf("failed to create ldcach update hook discoverer: %v", err) } diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index 7e4fa8d1..ec4750d0 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -44,14 +44,11 @@ func NewGraphicsModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci. return nil, nil } - config := &discover.Config{ - DriverRoot: cfg.NVIDIAContainerCLIConfig.Root, - NvidiaCTKPath: cfg.NVIDIACTKConfig.Path, - } d, err := discover.NewGraphicsDiscoverer( logger, image.DevicesFromEnvvars(visibleDevicesEnvvar), - config, + cfg.NVIDIAContainerCLIConfig.Root, + cfg.NVIDIACTKConfig.Path, ) if err != nil { return nil, fmt.Errorf("failed to construct discoverer: %v", err) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 5634e75a..47a80e29 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -44,10 +44,6 @@ func (r rt) Run(argv []string) (rerr error) { if err != nil { return fmt.Errorf("error loading config: %v", err) } - if r.modeOverride != "" { - cfg.NVIDIAContainerRuntimeConfig.Mode = r.modeOverride - } - err = r.logger.Update( cfg.NVIDIAContainerRuntimeConfig.DebugFilePath, cfg.NVIDIAContainerRuntimeConfig.LogLevel, @@ -63,6 +59,13 @@ func (r rt) Run(argv []string) (rerr error) { r.logger.Reset() }() + // We apply some config updates here to ensure that the config is valid in + // all cases. + if r.modeOverride != "" { + cfg.NVIDIAContainerRuntimeConfig.Mode = r.modeOverride + } + cfg.NVIDIACTKConfig.Path = config.ResolveNVIDIACTKPath(r.logger.Logger, cfg.NVIDIACTKConfig.Path) + // Print the config to the output. configJSON, err := json.MarshalIndent(cfg, "", " ") if err == nil { diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index 408da55a..c0dc1445 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -86,11 +86,7 @@ func NewDriverLibraryDiscoverer(logger *logrus.Logger, driverRoot string, nvidia libraryPaths, ) - cfg := &discover.Config{ - DriverRoot: driverRoot, - NvidiaCTKPath: nvidiaCTKPath, - } - hooks, _ := discover.NewLDCacheUpdateHook(logger, libraries, cfg) + hooks, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCTKPath) d := discover.Merge( libraries, diff --git a/pkg/nvcdi/driver-wsl.go b/pkg/nvcdi/driver-wsl.go index cca4d52c..fbe0572f 100644 --- a/pkg/nvcdi/driver-wsl.go +++ b/pkg/nvcdi/driver-wsl.go @@ -90,11 +90,7 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi links := []string{fmt.Sprintf("%s::%s", target, link)} symlinkHook := discover.CreateCreateSymlinkHook(nvidiaCTKPath, links) - cfg := &discover.Config{ - DriverRoot: driverRoot, - NvidiaCTKPath: nvidiaCTKPath, - } - ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, cfg) + ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCTKPath) d := discover.Merge( libraries,