Merge branch 'NVIDIA:main' into main

This commit is contained in:
Seungmin Kim 2024-06-02 17:42:21 +09:00 committed by GitHub
commit 0d70052105
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
81 changed files with 1870 additions and 797 deletions

View File

@ -33,7 +33,7 @@ variables:
# On the multi-arch builder we don't need the qemu setup.
SKIP_QEMU_SETUP: "1"
# Define the public staging registry
STAGING_REGISTRY: ghcr.io/nvidia/
STAGING_REGISTRY: ghcr.io/nvidia
STAGING_VERSION: ${CI_COMMIT_SHORT_SHA}
ARTIFACTORY_REPO_BASE: "https://urm.nvidia.com/artifactory/sw-gpu-cloudnative"
KITMAKER_RELEASE_FOLDER: "kitmaker"

View File

@ -1,5 +1,7 @@
# NVIDIA Container Toolkit Changelog
* Move `nvidia-ctk hook` commands to a separate `nvidia-cdi-hook` binary. The same subcommands are supported.
## v1.15.0
* Remove `nvidia-container-runtime` and `nvidia-docker2` packages.

View File

@ -0,0 +1,31 @@
# NVIDIA CDI Hook
The CLI `nvidia-cdi-hook` provides container device runtime hook capabilities when
called by a container runtime, as specific in a
[Container Device Interface](https://tags.cncf.io/container-device-interface/blob/main/SPEC.md)
file.
## Generating a CDI
The CDI itself is created for an NVIDIA-capable device using the
[`nvidia-ctk cdi generate`](../nvidia-ctk/) command.
When `nvidia-ctk cdi generate` is run, the CDI specification is generated as a yaml file.
The CDI specification provides instructions for a container runtime to set up devices, files and
other resources for the container prior to starting it. Those instructions
may include executing command-line tools to prepare the filesystem. The execution
of such command-line tools is called a hook.
`nvidia-cdi-hook` is the CLI tool that is expected to be called by the container runtime,
when specified by the CDI file.
See the [`nvidia-ctk` documentation](../nvidia-ctk/README.md) for more information
on generating a CDI file.
## Functionality
The `nvidia-cdi-hook` CLI provides the following functionality:
* `chmod` - Change the permissions of a file or directory inside the directory path to be mounted into a container.
* `create-symlinks` - Create symlinks inside the directory path to be mounted into a container.
* `update-ldcache` - Update the dynamic linker cache inside the directory path to be mounted into a container.

View File

@ -0,0 +1,36 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package commands
import (
"github.com/urfave/cli/v2"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod"
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-ldcache"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)
// New creates the commands associated with supported CDI hooks.
// These are shared by the nvidia-cdi-hook and nvidia-ctk hook commands.
func New(logger logger.Interface) []*cli.Command {
return []*cli.Command{
ldcache.NewCommand(logger),
symlinks.NewCommand(logger),
chmod.NewCommand(logger),
}
}

View File

@ -0,0 +1,93 @@
/**
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package main
import (
"os"
"github.com/sirupsen/logrus"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
cli "github.com/urfave/cli/v2"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/commands"
)
// options defines the options that can be set for the CLI through config files,
// environment variables, or command line flags
type options struct {
// Debug indicates whether the CLI is started in "debug" mode
Debug bool
// Quiet indicates whether the CLI is started in "quiet" mode
Quiet bool
}
func main() {
logger := logrus.New()
// Create a options struct to hold the parsed environment variables or command line flags
opts := options{}
// Create the top-level CLI
c := cli.NewApp()
c.Name = "NVIDIA CDI Hook"
c.UseShortOptionHandling = true
c.EnableBashCompletion = true
c.Usage = "Command to structure files for usage inside a container, called as hooks from a container runtime, defined in a CDI yaml file"
c.Version = info.GetVersionString()
// Setup the flags for this command
c.Flags = []cli.Flag{
&cli.BoolFlag{
Name: "debug",
Aliases: []string{"d"},
Usage: "Enable debug-level logging",
Destination: &opts.Debug,
EnvVars: []string{"NVIDIA_CDI_DEBUG"},
},
&cli.BoolFlag{
Name: "quiet",
Usage: "Suppress all output except for errors; overrides --debug",
Destination: &opts.Quiet,
EnvVars: []string{"NVIDIA_CDI_QUIET"},
},
}
// Set log-level for all subcommands
c.Before = func(c *cli.Context) error {
logLevel := logrus.InfoLevel
if opts.Debug {
logLevel = logrus.DebugLevel
}
if opts.Quiet {
logLevel = logrus.ErrorLevel
}
logger.SetLevel(logLevel)
return nil
}
// Define the subcommands
c.Commands = commands.New(logger)
// Run the CLI
err := c.Run(os.Args)
if err != nil {
logger.Errorf("%v", err)
os.Exit(1)
}
}

View File

@ -47,7 +47,7 @@ type options struct {
deviceNameStrategies cli.StringSlice
driverRoot string
devRoot string
nvidiaCTKPath string
nvidiaCDIHookPath string
ldconfigPath string
mode string
vendor string
@ -132,9 +132,12 @@ func (m command) build() *cli.Command {
Destination: &opts.librarySearchPaths,
},
&cli.StringFlag{
Name: "nvidia-ctk-path",
Usage: "Specify the path to use for the nvidia-ctk in the generated CDI specification. If this is left empty, the path will be searched.",
Destination: &opts.nvidiaCTKPath,
Name: "nvidia-cdi-hook-path",
Aliases: []string{"nvidia-ctk-path"},
Usage: "Specify the path to use for the nvidia-cdi-hook in the generated CDI specification. " +
"If not specified, the PATH will be searched for `nvidia-cdi-hook`. " +
"NOTE: That if this is specified as `nvidia-ctk`, the PATH will be searched for `nvidia-ctk` instead.",
Destination: &opts.nvidiaCDIHookPath,
},
&cli.StringFlag{
Name: "ldconfig-path",
@ -198,7 +201,7 @@ func (m command) validateFlags(c *cli.Context, opts *options) error {
}
}
opts.nvidiaCTKPath = config.ResolveNVIDIACTKPath(m.logger, opts.nvidiaCTKPath)
opts.nvidiaCDIHookPath = config.ResolveNVIDIACDIHookPath(m.logger, opts.nvidiaCDIHookPath)
if outputFileFormat := formatFromFilename(opts.output); outputFileFormat != "" {
m.logger.Debugf("Inferred output format as %q from output file name", outputFileFormat)
@ -262,7 +265,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
nvcdi.WithLogger(m.logger),
nvcdi.WithDriverRoot(opts.driverRoot),
nvcdi.WithDevRoot(opts.devRoot),
nvcdi.WithNVIDIACTKPath(opts.nvidiaCTKPath),
nvcdi.WithNVIDIACDIHookPath(opts.nvidiaCDIHookPath),
nvcdi.WithLdconfigPath(opts.ldconfigPath),
nvcdi.WithDeviceNamers(deviceNamers...),
nvcdi.WithMode(opts.mode),

View File

@ -38,7 +38,8 @@ type command struct {
// options stores the subcommand options
type options struct {
flags.Options
sets cli.StringSlice
setListSeparator string
sets cli.StringSlice
}
// NewCommand constructs an config command with the specified logger
@ -57,6 +58,9 @@ func (m command) build() *cli.Command {
c := cli.Command{
Name: "config",
Usage: "Interact with the NVIDIA Container Toolkit configuration",
Before: func(ctx *cli.Context) error {
return validateFlags(ctx, &opts)
},
Action: func(ctx *cli.Context) error {
return run(ctx, &opts)
},
@ -71,10 +75,21 @@ func (m command) build() *cli.Command {
Destination: &opts.Config,
},
&cli.StringSliceFlag{
Name: "set",
Usage: "Set a config value using the pattern key=value. If value is empty, this is equivalent to specifying the same key in unset. This flag can be specified multiple times",
Name: "set",
Usage: "Set a config value using the pattern 'key[=value]'. " +
"Specifying only 'key' is equivalent to 'key=true' for boolean settings. " +
"This flag can be specified multiple times, but only the last value for a specific " +
"config option is applied. " +
"If the setting represents a list, the elements are colon-separated.",
Destination: &opts.sets,
},
&cli.StringFlag{
Name: "set-list-separator",
Usage: "Specify a separator for lists applied using the set command.",
Hidden: true,
Value: ":",
Destination: &opts.setListSeparator,
},
&cli.BoolFlag{
Name: "in-place",
Aliases: []string{"i"},
@ -96,6 +111,13 @@ func (m command) build() *cli.Command {
return &c
}
func validateFlags(c *cli.Context, opts *options) error {
if opts.setListSeparator == "" {
return fmt.Errorf("set-list-separator must be set")
}
return nil
}
func run(c *cli.Context, opts *options) error {
cfgToml, err := config.New(
config.WithConfigFile(opts.Config),
@ -105,7 +127,7 @@ func run(c *cli.Context, opts *options) error {
}
for _, set := range opts.sets.Value() {
key, value, err := setFlagToKeyValue(set)
key, value, err := setFlagToKeyValue(set, opts.setListSeparator)
if err != nil {
return fmt.Errorf("invalid --set option %v: %w", set, err)
}
@ -139,7 +161,7 @@ var errInvalidFormat = errors.New("invalid format")
// setFlagToKeyValue converts a --set flag to a key-value pair.
// The set flag is of the form key[=value], with the value being optional if key refers to a
// boolean config option.
func setFlagToKeyValue(setFlag string) (string, interface{}, error) {
func setFlagToKeyValue(setFlag string, setListSeparator string) (string, interface{}, error) {
setParts := strings.SplitN(setFlag, "=", 2)
key := setParts[0]
@ -172,7 +194,7 @@ func setFlagToKeyValue(setFlag string) (string, interface{}, error) {
case reflect.String:
return key, value, nil
case reflect.Slice:
valueParts := strings.Split(value, ",")
valueParts := strings.Split(value, setListSeparator)
switch field.Elem().Kind() {
case reflect.String:
return key, valueParts, nil

View File

@ -25,11 +25,12 @@ import (
func TestSetFlagToKeyValue(t *testing.T) {
// TODO: We need to enable this test again since switching to reflect.
testCases := []struct {
description string
setFlag string
expectedKey string
expectedValue interface{}
expectedError error
description string
setFlag string
setListSeparator string
expectedKey string
expectedValue interface{}
expectedError error
}{
{
description: "option not present returns an error",
@ -106,22 +107,34 @@ func TestSetFlagToKeyValue(t *testing.T) {
expectedValue: []string{"string-value"},
},
{
description: "[]string option returns multiple values",
setFlag: "nvidia-container-cli.environment=first,second",
expectedKey: "nvidia-container-cli.environment",
expectedValue: []string{"first", "second"},
description: "[]string option returns multiple values",
setFlag: "nvidia-container-cli.environment=first,second",
setListSeparator: ",",
expectedKey: "nvidia-container-cli.environment",
expectedValue: []string{"first", "second"},
},
{
description: "[]string option returns values with equals",
setFlag: "nvidia-container-cli.environment=first=1,second=2",
expectedKey: "nvidia-container-cli.environment",
expectedValue: []string{"first=1", "second=2"},
description: "[]string option returns values with equals",
setFlag: "nvidia-container-cli.environment=first=1,second=2",
setListSeparator: ",",
expectedKey: "nvidia-container-cli.environment",
expectedValue: []string{"first=1", "second=2"},
},
{
description: "[]string option returns multiple values semi-colon",
setFlag: "nvidia-container-cli.environment=first;second",
setListSeparator: ";",
expectedKey: "nvidia-container-cli.environment",
expectedValue: []string{"first", "second"},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
k, v, err := setFlagToKeyValue(tc.setFlag)
if tc.setListSeparator == "" {
tc.setListSeparator = ","
}
k, v, err := setFlagToKeyValue(tc.setFlag, tc.setListSeparator)
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, tc.expectedKey, k)
require.EqualValues(t, tc.expectedValue, v)

View File

@ -17,13 +17,10 @@
package hook
import (
chmod "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/hook/chmod"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/commands"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/urfave/cli/v2"
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/hook/create-symlinks"
ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/hook/update-ldcache"
)
type hookCommand struct {
@ -46,11 +43,7 @@ func (m hookCommand) build() *cli.Command {
Usage: "A collection of hooks that may be injected into an OCI spec",
}
hook.Subcommands = []*cli.Command{
ldcache.NewCommand(m.logger),
symlinks.NewCommand(m.logger),
chmod.NewCommand(m.logger),
}
hook.Subcommands = commands.New(m.logger)
return &hook
}

View File

@ -17,6 +17,7 @@
package configure
import (
"encoding/json"
"fmt"
"path/filepath"
@ -66,6 +67,8 @@ type config struct {
mode string
hookFilePath string
runtimeConfigOverrideJSON string
nvidiaRuntime struct {
name string
path string
@ -153,6 +156,13 @@ func (m command) build() *cli.Command {
Usage: "Enable CDI in the configured runtime",
Destination: &config.cdi.enabled,
},
&cli.StringFlag{
Name: "runtime-config-override",
Destination: &config.runtimeConfigOverrideJSON,
Usage: "specify additional runtime options as a JSON string. The paths are relative to the runtime config.",
Value: "{}",
EnvVars: []string{"RUNTIME_CONFIG_OVERRIDE"},
},
}
return &configure
@ -194,6 +204,11 @@ func (m command) validateFlags(c *cli.Context, config *config) error {
config.cdi.enabled = false
}
if config.runtimeConfigOverrideJSON != "" && config.runtime != "containerd" {
m.logger.Warningf("Ignoring runtime-config-override flag for %v", config.runtime)
config.runtimeConfigOverrideJSON = ""
}
return nil
}
@ -237,10 +252,16 @@ func (m command) configureConfigFile(c *cli.Context, config *config) error {
return fmt.Errorf("unable to load config for runtime %v: %v", config.runtime, err)
}
runtimeConfigOverride, err := config.runtimeConfigOverride()
if err != nil {
return fmt.Errorf("unable to parse config overrides: %w", err)
}
err = cfg.AddRuntime(
config.nvidiaRuntime.name,
config.nvidiaRuntime.path,
config.nvidiaRuntime.setAsDefault,
runtimeConfigOverride,
)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
@ -293,6 +314,20 @@ func (c *config) getOuputConfigPath() string {
return c.resolveConfigFilePath()
}
// runtimeConfigOverride converts the specified runtimeConfigOverride JSON string to a map.
func (o *config) runtimeConfigOverride() (map[string]interface{}, error) {
if o.runtimeConfigOverrideJSON == "" {
return nil, nil
}
runtimeOptions := make(map[string]interface{})
if err := json.Unmarshal([]byte(o.runtimeConfigOverrideJSON), &runtimeOptions); err != nil {
return nil, fmt.Errorf("failed to read %v as JSON: %w", o.runtimeConfigOverrideJSON, err)
}
return runtimeOptions, nil
}
// configureOCIHook creates and configures the OCI hook for the NVIDIA runtime
func (m *command) configureOCIHook(c *cli.Context, config *config) error {
err := ocihook.CreateHook(config.hookFilePath, config.nvidiaRuntime.hookPath)

4
go.mod
View File

@ -3,8 +3,8 @@ module github.com/NVIDIA/nvidia-container-toolkit
go 1.20
require (
github.com/NVIDIA/go-nvlib v0.3.0
github.com/NVIDIA/go-nvml v0.12.0-5
github.com/NVIDIA/go-nvlib v0.5.0
github.com/NVIDIA/go-nvml v0.12.0-6
github.com/fsnotify/fsnotify v1.7.0
github.com/opencontainers/runtime-spec v1.2.0
github.com/pelletier/go-toml v1.9.5

8
go.sum
View File

@ -1,7 +1,7 @@
github.com/NVIDIA/go-nvlib v0.3.0 h1:vd7jSOthJTqzqIWZrv317xDr1+Mnjoy5X4N69W9YwQM=
github.com/NVIDIA/go-nvlib v0.3.0/go.mod h1:NasUuId9hYFvwzuOHCu9F2X6oTU2tG0JHTfbJYuDAbA=
github.com/NVIDIA/go-nvml v0.12.0-5 h1:4DYsngBqJEAEj+/RFmBZ43Q3ymoR3tyS0oBuJk12Fag=
github.com/NVIDIA/go-nvml v0.12.0-5/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ=
github.com/NVIDIA/go-nvlib v0.5.0 h1:951KGrfr+p3cs89alO9z/ZxPPWKxwht9tx9rxiADoLI=
github.com/NVIDIA/go-nvlib v0.5.0/go.mod h1:87z49ULPr4GWPSGfSIp3taU4XENRYN/enIg88MzcL4k=
github.com/NVIDIA/go-nvml v0.12.0-6 h1:FJYc2KrpvX+VOC/8QQvMiQMmZ/nPMRpdJO/Ik4xfcr0=
github.com/NVIDIA/go-nvml v0.12.0-6/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ=
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ=
github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4=

View File

@ -33,8 +33,9 @@ const (
configOverride = "XDG_CONFIG_HOME"
configFilePath = "nvidia-container-runtime/config.toml"
nvidiaCTKExecutable = "nvidia-ctk"
nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk"
nvidiaCTKExecutable = "nvidia-ctk"
nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk"
nvidiaCDIHookDefaultFilePath = "/usr/bin/nvidia-cdi-hook"
nvidiaContainerRuntimeHookExecutable = "nvidia-container-runtime-hook"
nvidiaContainerRuntimeHookDefaultPath = "/usr/bin/nvidia-container-runtime-hook"
@ -177,6 +178,8 @@ var getDistIDLike = func() []string {
// This executable is used in hooks and needs to be an absolute path.
// If the path is specified as an absolute path, it is used directly
// without checking for existence of an executable at that path.
//
// Deprecated: Use ResolveNVIDIACDIHookPath directly instead.
func ResolveNVIDIACTKPath(logger logger.Interface, nvidiaCTKPath string) string {
return resolveWithDefault(
logger,
@ -186,6 +189,27 @@ func ResolveNVIDIACTKPath(logger logger.Interface, nvidiaCTKPath string) string
)
}
// ResolveNVIDIACDIHookPath resolves the path to the nvidia-cdi-hook binary.
// This executable is used in hooks and needs to be an absolute path.
// If the path is specified as an absolute path, it is used directly
// without checking for existence of an executable at that path.
func ResolveNVIDIACDIHookPath(logger logger.Interface, nvidiaCDIHookPath string) string {
if filepath.Base(nvidiaCDIHookPath) == "nvidia-ctk" {
return resolveWithDefault(
logger,
"NVIDIA Container Toolkit CLI",
nvidiaCDIHookPath,
nvidiaCTKDefaultFilePath,
)
}
return resolveWithDefault(
logger,
"NVIDIA CDI Hook CLI",
nvidiaCDIHookPath,
nvidiaCDIHookDefaultFilePath,
)
}
// ResolveNVIDIAContainerRuntimeHookPath resolves the path the nvidia-container-runtime-hook binary.
func ResolveNVIDIAContainerRuntimeHookPath(logger logger.Interface, nvidiaContainerRuntimeHookPath string) string {
return resolveWithDefault(

View File

@ -36,20 +36,20 @@ import (
// TODO: The logic for creating DRM devices should be consolidated between this
// and the logic for generating CDI specs for a single device. This is only used
// when applying OCI spec modifications to an incoming spec in "legacy" mode.
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, nvidiaCTKPath string) (Discover, error) {
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, nvidiaCDIHookPath string) (Discover, error) {
drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err)
}
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, nvidiaCTKPath)
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, nvidiaCDIHookPath)
discover := Merge(drmDeviceNodes, drmByPathSymlinks)
return discover, nil
}
// NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan.
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) (Discover, error) {
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) (Discover, error) {
libraries := NewMounts(
logger,
driver.Libraries(),
@ -77,7 +77,7 @@ func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, n
},
)
xorg := optionalXorgDiscoverer(logger, driver, nvidiaCTKPath)
xorg := optionalXorgDiscoverer(logger, driver, nvidiaCDIHookPath)
discover := Merge(
libraries,
@ -90,19 +90,19 @@ func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, n
type drmDevicesByPath struct {
None
logger logger.Interface
nvidiaCTKPath string
devRoot string
devicesFrom Discover
logger logger.Interface
nvidiaCDIHookPath string
devRoot string
devicesFrom Discover
}
// newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, nvidiaCTKPath string) Discover {
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, nvidiaCDIHookPath string) Discover {
d := drmDevicesByPath{
logger: logger,
nvidiaCTKPath: nvidiaCTKPath,
devRoot: devRoot,
devicesFrom: devices,
logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath,
devRoot: devRoot,
devicesFrom: devices,
}
return &d
@ -130,8 +130,8 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) {
args = append(args, "--link", l)
}
hook := CreateNvidiaCTKHook(
d.nvidiaCTKPath,
hook := CreateNvidiaCDIHook(
d.nvidiaCDIHookPath,
"create-symlinks",
args...,
)
@ -236,17 +236,17 @@ func newDRMDeviceFilter(devices image.VisibleDevices, devRoot string) (Filter, e
}
type xorgHooks struct {
libraries Discover
driverVersion string
nvidiaCTKPath string
libraries Discover
driverVersion string
nvidiaCDIHookPath string
}
var _ Discover = (*xorgHooks)(nil)
// optionalXorgDiscoverer creates a discoverer for Xorg libraries.
// If the creation of the discoverer fails, a None discoverer is returned.
func optionalXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) Discover {
xorg, err := newXorgDiscoverer(logger, driver, nvidiaCTKPath)
func optionalXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) Discover {
xorg, err := newXorgDiscoverer(logger, driver, nvidiaCDIHookPath)
if err != nil {
logger.Warningf("Failed to create Xorg discoverer: %v; skipping xorg libraries", err)
return None{}
@ -254,7 +254,7 @@ func optionalXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidia
return xorg
}
func newXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string) (Discover, error) {
func newXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) (Discover, error) {
libCudaPaths, err := cuda.New(
driver.Libraries(),
).Locate(".*.*")
@ -304,9 +304,9 @@ func newXorgDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPa
},
)
xorgHooks := xorgHooks{
libraries: xorgLibs,
driverVersion: version,
nvidiaCTKPath: nvidiaCTKPath,
libraries: xorgLibs,
driverVersion: version,
nvidiaCDIHookPath: nvidiaCDIHookPath,
}
xorgConfig := NewMounts(
@ -358,7 +358,7 @@ func (m xorgHooks) Hooks() ([]Hook, error) {
link := strings.TrimSuffix(target, "."+m.driverVersion)
links := []string{fmt.Sprintf("%s::%s", filepath.Base(target), link)}
symlinkHook := CreateCreateSymlinkHook(
m.nvidiaCTKPath,
m.nvidiaCDIHookPath,
links,
)

View File

@ -41,7 +41,7 @@ func (h Hook) Hooks() ([]Hook, error) {
}
// CreateCreateSymlinkHook creates a hook which creates a symlink from link -> target.
func CreateCreateSymlinkHook(nvidiaCTKPath string, links []string) Discover {
func CreateCreateSymlinkHook(nvidiaCDIHookPath string, links []string) Discover {
if len(links) == 0 {
return None{}
}
@ -50,18 +50,31 @@ func CreateCreateSymlinkHook(nvidiaCTKPath string, links []string) Discover {
for _, link := range links {
args = append(args, "--link", link)
}
return CreateNvidiaCTKHook(
nvidiaCTKPath,
return CreateNvidiaCDIHook(
nvidiaCDIHookPath,
"create-symlinks",
args...,
)
}
// CreateNvidiaCTKHook creates a hook which invokes the NVIDIA Container CLI hook subcommand.
func CreateNvidiaCTKHook(nvidiaCTKPath string, hookName string, additionalArgs ...string) Hook {
// CreateNvidiaCDIHook creates a hook which invokes the NVIDIA Container CLI hook subcommand.
func CreateNvidiaCDIHook(nvidiaCDIHookPath string, hookName string, additionalArgs ...string) Hook {
return cdiHook(nvidiaCDIHookPath).Create(hookName, additionalArgs...)
}
type cdiHook string
func (c cdiHook) Create(name string, args ...string) Hook {
return Hook{
Lifecycle: cdi.CreateContainerHook,
Path: nvidiaCTKPath,
Args: append([]string{filepath.Base(nvidiaCTKPath), "hook", hookName}, additionalArgs...),
Path: string(c),
Args: append(c.requiredArgs(name), args...),
}
}
func (c cdiHook) requiredArgs(name string) []string {
base := filepath.Base(string(c))
if base == "nvidia-ctk" {
return []string{base, "hook", name}
}
return []string{base, name}
}

View File

@ -25,12 +25,12 @@ import (
)
// NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCTKPath, ldconfigPath string) (Discover, error) {
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHookPath, ldconfigPath string) (Discover, error) {
d := ldconfig{
logger: logger,
nvidiaCTKPath: nvidiaCTKPath,
ldconfigPath: ldconfigPath,
mountsFrom: mounts,
logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath,
ldconfigPath: ldconfigPath,
mountsFrom: mounts,
}
return &d, nil
@ -38,10 +38,10 @@ func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCTKPat
type ldconfig struct {
None
logger logger.Interface
nvidiaCTKPath string
ldconfigPath string
mountsFrom Discover
logger logger.Interface
nvidiaCDIHookPath string
ldconfigPath string
mountsFrom Discover
}
// Hooks checks the required mounts for libraries and returns a hook to update the LDcache for the discovered paths.
@ -51,7 +51,7 @@ func (d ldconfig) Hooks() ([]Hook, error) {
return nil, fmt.Errorf("failed to discover mounts for ldcache update: %v", err)
}
h := CreateLDCacheUpdateHook(
d.nvidiaCTKPath,
d.nvidiaCDIHookPath,
d.ldconfigPath,
getLibraryPaths(mounts),
)
@ -70,7 +70,7 @@ func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []str
args = append(args, "--folder", f)
}
hook := CreateNvidiaCTKHook(
hook := CreateNvidiaCDIHook(
executable,
"update-ldcache",
args...,

View File

@ -25,8 +25,8 @@ import (
)
const (
testNvidiaCTKPath = "/foo/bar/nvidia-ctk"
testLdconfigPath = "/bar/baz/ldconfig"
testNvidiaCDIHookPath = "/foo/bar/nvidia-cdi-hook"
testLdconfigPath = "/bar/baz/ldconfig"
)
func TestLDCacheUpdateHook(t *testing.T) {
@ -42,7 +42,7 @@ func TestLDCacheUpdateHook(t *testing.T) {
}{
{
description: "empty mounts",
expectedArgs: []string{"nvidia-ctk", "hook", "update-ldcache"},
expectedArgs: []string{"nvidia-cdi-hook", "update-ldcache"},
},
{
description: "mount error",
@ -65,7 +65,7 @@ func TestLDCacheUpdateHook(t *testing.T) {
Path: "/usr/local/lib/libbar.so",
},
},
expectedArgs: []string{"nvidia-ctk", "hook", "update-ldcache", "--folder", "/usr/local/lib", "--folder", "/usr/local/libother"},
expectedArgs: []string{"nvidia-cdi-hook", "update-ldcache", "--folder", "/usr/local/lib", "--folder", "/usr/local/libother"},
},
{
description: "host paths are ignored",
@ -75,12 +75,12 @@ func TestLDCacheUpdateHook(t *testing.T) {
Path: "/usr/local/lib/libfoo.so",
},
},
expectedArgs: []string{"nvidia-ctk", "hook", "update-ldcache", "--folder", "/usr/local/lib"},
expectedArgs: []string{"nvidia-cdi-hook", "update-ldcache", "--folder", "/usr/local/lib"},
},
{
description: "explicit ldconfig path is passed",
ldconfigPath: testLdconfigPath,
expectedArgs: []string{"nvidia-ctk", "hook", "update-ldcache", "--ldconfig-path", testLdconfigPath},
expectedArgs: []string{"nvidia-cdi-hook", "update-ldcache", "--ldconfig-path", testLdconfigPath},
},
}
@ -92,12 +92,12 @@ func TestLDCacheUpdateHook(t *testing.T) {
},
}
expectedHook := Hook{
Path: testNvidiaCTKPath,
Path: testNvidiaCDIHookPath,
Args: tc.expectedArgs,
Lifecycle: "createContainer",
}
d, err := NewLDCacheUpdateHook(logger, mountMock, testNvidiaCTKPath, tc.ldconfigPath)
d, err := NewLDCacheUpdateHook(logger, mountMock, testNvidiaCDIHookPath, tc.ldconfigPath)
require.NoError(t, err)
hooks, err := d.Hooks()

View File

@ -243,7 +243,7 @@ func TestUsesNVGPUModule(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
sut := additionalInfo{
nvmllib: tc.nvmllib,
devicelib: device.New(device.WithNvml(tc.nvmllib)),
devicelib: device.New(tc.nvmllib),
}
flag, _ := sut.UsesNVGPUModule()

View File

@ -17,75 +17,40 @@
package info
import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)
// infoInterface provides an alias for mocking.
//
//go:generate moq -stub -out info-interface_mock.go . infoInterface
type infoInterface interface {
info.Interface
// UsesNVGPUModule indicates whether the system is using the nvgpu kernel module
UsesNVGPUModule() (bool, string)
}
type resolver struct {
logger logger.Interface
info infoInterface
}
// ResolveAutoMode determines the correct mode for the platform if set to "auto"
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) {
nvinfo := info.New()
nvmllib := nvml.New()
devicelib := device.New(
device.WithNvml(nvmllib),
)
info := additionalInfo{
Interface: nvinfo,
nvmllib: nvmllib,
devicelib: devicelib,
}
r := resolver{
logger: logger,
info: info,
}
return r.resolveMode(mode, image)
return resolveMode(logger, mode, image, nil)
}
// resolveMode determines the correct mode for the platform if set to "auto"
func (r resolver) resolveMode(mode string, image image.CUDA) (rmode string) {
func resolveMode(logger logger.Interface, mode string, image image.CUDA, propertyExtractor info.PropertyExtractor) (rmode string) {
if mode != "auto" {
r.logger.Infof("Using requested mode '%s'", mode)
logger.Infof("Using requested mode '%s'", mode)
return mode
}
defer func() {
r.logger.Infof("Auto-detected mode as '%v'", rmode)
logger.Infof("Auto-detected mode as '%v'", rmode)
}()
if image.OnlyFullyQualifiedCDIDevices() {
return "cdi"
}
isTegra, reason := r.info.IsTegraSystem()
r.logger.Debugf("Is Tegra-based system? %v: %v", isTegra, reason)
nvinfo := info.New(
info.WithLogger(logger),
info.WithPropertyExtractor(propertyExtractor),
)
hasNVML, reason := r.info.HasNvml()
r.logger.Debugf("Has NVML? %v: %v", hasNVML, reason)
usesNVGPUModule, reason := r.info.UsesNVGPUModule()
r.logger.Debugf("Uses nvgpu kernel module? %v: %v", usesNVGPUModule, reason)
if (isTegra && !hasNVML) || usesNVGPUModule {
switch nvinfo.ResolvePlatform() {
case info.PlatformNVML, info.PlatformWSL:
return "legacy"
case info.PlatformTegra:
return "csv"
}
return "legacy"
}

View File

@ -19,6 +19,7 @@ package info
import (
"testing"
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
"github.com/opencontainers/runtime-spec/specs-go"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
@ -202,23 +203,24 @@ func TestResolveAutoMode(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
info := &infoInterfaceMock{
properties := &info.PropertyExtractorMock{
HasNvmlFunc: func() (bool, string) {
return tc.info["nvml"], "nvml"
},
HasDXCoreFunc: func() (bool, string) {
return tc.info["dxcore"], "dxcore"
},
IsTegraSystemFunc: func() (bool, string) {
return tc.info["tegra"], "tegra"
},
UsesNVGPUModuleFunc: func() (bool, string) {
HasTegraFilesFunc: func() (bool, string) {
return tc.info["tegra"], "tegra"
},
UsesOnlyNVGPUModuleFunc: func() (bool, string) {
return tc.info["nvgpu"], "nvgpu"
},
}
r := resolver{
logger: logger,
info: info,
}
var mounts []specs.Mount
for _, d := range tc.mounts {
mount := specs.Mount{
@ -231,7 +233,7 @@ func TestResolveAutoMode(t *testing.T) {
image.WithEnvMap(tc.envmap),
image.WithMounts(mounts),
)
mode := r.resolveMode(tc.mode, image)
mode := resolveMode(logger, tc.mode, image, properties)
require.EqualValues(t, tc.expectedMode, mode)
})
}

View File

@ -1,194 +0,0 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package info
import (
"sync"
)
// Ensure, that infoInterfaceMock does implement infoInterface.
// If this is not the case, regenerate this file with moq.
var _ infoInterface = &infoInterfaceMock{}
// infoInterfaceMock is a mock implementation of infoInterface.
//
// func TestSomethingThatUsesinfoInterface(t *testing.T) {
//
// // make and configure a mocked infoInterface
// mockedinfoInterface := &infoInterfaceMock{
// HasDXCoreFunc: func() (bool, string) {
// panic("mock out the HasDXCore method")
// },
// HasNvmlFunc: func() (bool, string) {
// panic("mock out the HasNvml method")
// },
// IsTegraSystemFunc: func() (bool, string) {
// panic("mock out the IsTegraSystem method")
// },
// UsesNVGPUModuleFunc: func() (bool, string) {
// panic("mock out the UsesNVGPUModule method")
// },
// }
//
// // use mockedinfoInterface in code that requires infoInterface
// // and then make assertions.
//
// }
type infoInterfaceMock struct {
// HasDXCoreFunc mocks the HasDXCore method.
HasDXCoreFunc func() (bool, string)
// HasNvmlFunc mocks the HasNvml method.
HasNvmlFunc func() (bool, string)
// IsTegraSystemFunc mocks the IsTegraSystem method.
IsTegraSystemFunc func() (bool, string)
// UsesNVGPUModuleFunc mocks the UsesNVGPUModule method.
UsesNVGPUModuleFunc func() (bool, string)
// calls tracks calls to the methods.
calls struct {
// HasDXCore holds details about calls to the HasDXCore method.
HasDXCore []struct {
}
// HasNvml holds details about calls to the HasNvml method.
HasNvml []struct {
}
// IsTegraSystem holds details about calls to the IsTegraSystem method.
IsTegraSystem []struct {
}
// UsesNVGPUModule holds details about calls to the UsesNVGPUModule method.
UsesNVGPUModule []struct {
}
}
lockHasDXCore sync.RWMutex
lockHasNvml sync.RWMutex
lockIsTegraSystem sync.RWMutex
lockUsesNVGPUModule sync.RWMutex
}
// HasDXCore calls HasDXCoreFunc.
func (mock *infoInterfaceMock) HasDXCore() (bool, string) {
callInfo := struct {
}{}
mock.lockHasDXCore.Lock()
mock.calls.HasDXCore = append(mock.calls.HasDXCore, callInfo)
mock.lockHasDXCore.Unlock()
if mock.HasDXCoreFunc == nil {
var (
bOut bool
sOut string
)
return bOut, sOut
}
return mock.HasDXCoreFunc()
}
// HasDXCoreCalls gets all the calls that were made to HasDXCore.
// Check the length with:
//
// len(mockedinfoInterface.HasDXCoreCalls())
func (mock *infoInterfaceMock) HasDXCoreCalls() []struct {
} {
var calls []struct {
}
mock.lockHasDXCore.RLock()
calls = mock.calls.HasDXCore
mock.lockHasDXCore.RUnlock()
return calls
}
// HasNvml calls HasNvmlFunc.
func (mock *infoInterfaceMock) HasNvml() (bool, string) {
callInfo := struct {
}{}
mock.lockHasNvml.Lock()
mock.calls.HasNvml = append(mock.calls.HasNvml, callInfo)
mock.lockHasNvml.Unlock()
if mock.HasNvmlFunc == nil {
var (
bOut bool
sOut string
)
return bOut, sOut
}
return mock.HasNvmlFunc()
}
// HasNvmlCalls gets all the calls that were made to HasNvml.
// Check the length with:
//
// len(mockedinfoInterface.HasNvmlCalls())
func (mock *infoInterfaceMock) HasNvmlCalls() []struct {
} {
var calls []struct {
}
mock.lockHasNvml.RLock()
calls = mock.calls.HasNvml
mock.lockHasNvml.RUnlock()
return calls
}
// IsTegraSystem calls IsTegraSystemFunc.
func (mock *infoInterfaceMock) IsTegraSystem() (bool, string) {
callInfo := struct {
}{}
mock.lockIsTegraSystem.Lock()
mock.calls.IsTegraSystem = append(mock.calls.IsTegraSystem, callInfo)
mock.lockIsTegraSystem.Unlock()
if mock.IsTegraSystemFunc == nil {
var (
bOut bool
sOut string
)
return bOut, sOut
}
return mock.IsTegraSystemFunc()
}
// IsTegraSystemCalls gets all the calls that were made to IsTegraSystem.
// Check the length with:
//
// len(mockedinfoInterface.IsTegraSystemCalls())
func (mock *infoInterfaceMock) IsTegraSystemCalls() []struct {
} {
var calls []struct {
}
mock.lockIsTegraSystem.RLock()
calls = mock.calls.IsTegraSystem
mock.lockIsTegraSystem.RUnlock()
return calls
}
// UsesNVGPUModule calls UsesNVGPUModuleFunc.
func (mock *infoInterfaceMock) UsesNVGPUModule() (bool, string) {
callInfo := struct {
}{}
mock.lockUsesNVGPUModule.Lock()
mock.calls.UsesNVGPUModule = append(mock.calls.UsesNVGPUModule, callInfo)
mock.lockUsesNVGPUModule.Unlock()
if mock.UsesNVGPUModuleFunc == nil {
var (
bOut bool
sOut string
)
return bOut, sOut
}
return mock.UsesNVGPUModuleFunc()
}
// UsesNVGPUModuleCalls gets all the calls that were made to UsesNVGPUModule.
// Check the length with:
//
// len(mockedinfoInterface.UsesNVGPUModuleCalls())
func (mock *infoInterfaceMock) UsesNVGPUModuleCalls() []struct {
} {
var calls []struct {
}
mock.lockUsesNVGPUModule.RLock()
calls = mock.calls.UsesNVGPUModule
mock.lockUsesNVGPUModule.RUnlock()
return calls
}

View File

@ -185,7 +185,7 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devices []string) (spec.Interface, error) {
cdilib, err := nvcdi.New(
nvcdi.WithLogger(logger),
nvcdi.WithNVIDIACTKPath(cfg.NVIDIACTKConfig.Path),
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
nvcdi.WithVendor("runtime.nvidia.com"),
nvcdi.WithClass("gpu"),

View File

@ -62,7 +62,7 @@ func NewCSVModifier(logger logger.Interface, cfg *config.Config, image image.CUD
cdilib, err := nvcdi.New(
nvcdi.WithLogger(logger),
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
nvcdi.WithNVIDIACTKPath(cfg.NVIDIACTKConfig.Path),
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
nvcdi.WithMode(nvcdi.ModeCSV),
nvcdi.WithCSVFiles(csvFiles),
)

View File

@ -35,12 +35,12 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image imag
return nil, nil
}
nvidiaCTKPath := cfg.NVIDIACTKConfig.Path
nvidiaCDIHookPath := cfg.NVIDIACTKConfig.Path
mounts, err := discover.NewGraphicsMountsDiscoverer(
logger,
driver,
nvidiaCTKPath,
nvidiaCDIHookPath,
)
if err != nil {
return nil, fmt.Errorf("failed to create mounts discoverer: %v", err)
@ -52,7 +52,7 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, image imag
logger,
image.DevicesFromEnvvars(visibleDevicesEnvvar),
devRoot,
nvidiaCTKPath,
nvidiaCDIHookPath,
)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer: %v", err)

View File

@ -90,10 +90,9 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{
"nvidia-ctk",
"hook",
"nvidia-cdi-hook",
"create-symlinks",
"--link",
"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so::/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so",
@ -147,10 +146,9 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{
"nvidia-ctk",
"hook",
"nvidia-cdi-hook",
"create-symlinks",
"--link",
"/usr/lib/aarch64-linux-gnu/tegra/libv4l2_nvargus.so::/usr/lib/aarch64-linux-gnu/libv4l/plugins/nv/libv4l2_nvargus.so",
@ -189,7 +187,7 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
o := tegraOptions{
logger: logger,
nvidiaCTKPath: "/usr/bin/nvidia-ctk",
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
csvFiles: []string{"dummy"},
ignorePatterns: tc.ignorePatterns,
symlinkLocator: tc.symlinkLocator,

View File

@ -28,10 +28,10 @@ import (
type symlinkHook struct {
discover.None
logger logger.Interface
nvidiaCTKPath string
targets []string
mountsFrom discover.Discover
logger logger.Interface
nvidiaCDIHookPath string
targets []string
mountsFrom discover.Discover
// The following can be overridden for testing
symlinkChainLocator lookup.Locator
@ -42,7 +42,7 @@ type symlinkHook struct {
func (o tegraOptions) createCSVSymlinkHooks(targets []string, mounts discover.Discover) discover.Discover {
return symlinkHook{
logger: o.logger,
nvidiaCTKPath: o.nvidiaCTKPath,
nvidiaCDIHookPath: o.nvidiaCDIHookPath,
targets: targets,
mountsFrom: mounts,
symlinkChainLocator: o.symlinkChainLocator,
@ -60,7 +60,7 @@ func (d symlinkHook) Hooks() ([]discover.Hook, error) {
csvSymlinks := d.getCSVFileSymlinks()
return discover.CreateCreateSymlinkHook(
d.nvidiaCTKPath,
d.nvidiaCDIHookPath,
append(csvSymlinks, specificLinks...),
).Hooks()
}

View File

@ -30,7 +30,7 @@ type tegraOptions struct {
csvFiles []string
driverRoot string
devRoot string
nvidiaCTKPath string
nvidiaCDIHookPath string
ldconfigPath string
librarySearchPaths []string
ignorePatterns ignoreMountSpecPatterns
@ -80,7 +80,7 @@ func New(opts ...Option) (discover.Discover, error) {
return nil, fmt.Errorf("failed to create CSV discoverer: %v", err)
}
ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.nvidiaCTKPath, o.ldconfigPath)
ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.nvidiaCDIHookPath, o.ldconfigPath)
if err != nil {
return nil, fmt.Errorf("failed to create ldcach update hook discoverer: %v", err)
}
@ -133,10 +133,10 @@ func WithCSVFiles(csvFiles []string) Option {
}
}
// WithNVIDIACTKPath sets the path to the nvidia-container-toolkit binary.
func WithNVIDIACTKPath(nvidiaCTKPath string) Option {
// WithNVIDIACDIHookPath sets the path to the nvidia-cdi-hook binary.
func WithNVIDIACDIHookPath(nvidiaCDIHookPath string) Option {
return func(o *tegraOptions) {
o.nvidiaCTKPath = nvidiaCTKPath
o.nvidiaCDIHookPath = nvidiaCDIHookPath
}
}

View File

@ -66,6 +66,7 @@ func (r rt) Run(argv []string) (rerr error) {
if r.modeOverride != "" {
cfg.NVIDIAContainerRuntimeConfig.Mode = r.modeOverride
}
//nolint:staticcheck // TODO(elezar): We should swith the nvidia-container-runtime from using nvidia-ctk to using nvidia-cdi-hook.
cfg.NVIDIACTKConfig.Path = config.ResolveNVIDIACTKPath(r.logger, cfg.NVIDIACTKConfig.Path)
cfg.NVIDIAContainerRuntimeHookConfig.Path = config.ResolveNVIDIAContainerRuntimeHookPath(r.logger, cfg.NVIDIAContainerRuntimeHookConfig.Path)

View File

@ -1,2 +1,3 @@
nvidia-container-runtime /usr/bin
nvidia-ctk /usr/bin
nvidia-cdi-hook /usr/bin

View File

@ -16,6 +16,7 @@ Source2: LICENSE
Source3: nvidia-container-runtime
Source4: nvidia-container-runtime.cdi
Source5: nvidia-container-runtime.legacy
Source6: nvidia-cdi-hook
Obsoletes: nvidia-container-runtime <= 3.5.0-1, nvidia-container-runtime-hook <= 1.4.0-2
Provides: nvidia-container-runtime
@ -27,7 +28,7 @@ Requires: nvidia-container-toolkit-base == %{version}-%{release}
Provides tools and utilities to enable GPU support in containers.
%prep
cp %{SOURCE0} %{SOURCE1} %{SOURCE2} %{SOURCE3} %{SOURCE4} %{SOURCE5} .
cp %{SOURCE0} %{SOURCE1} %{SOURCE2} %{SOURCE3} %{SOURCE4} %{SOURCE5} %{SOURCE6} .
%install
mkdir -p %{buildroot}%{_bindir}
@ -36,6 +37,7 @@ install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime
install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime.cdi
install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime.legacy
install -m 755 -t %{buildroot}%{_bindir} nvidia-ctk
install -m 755 -t %{buildroot}%{_bindir} nvidia-cdi-hook
%post
if [ $1 -gt 1 ]; then # only on package upgrade
@ -86,6 +88,7 @@ Provides tools such as the NVIDIA Container Runtime and NVIDIA Container Toolkit
%license LICENSE
%{_bindir}/nvidia-container-runtime
%{_bindir}/nvidia-ctk
%{_bindir}/nvidia-cdi-hook
# The OPERATOR EXTENSIONS package consists of components that are required to enable GPU support in Kubernetes.
# This package is not distributed as part of the NVIDIA Container Toolkit RPMs.

View File

@ -19,7 +19,7 @@ package engine
// Interface defines the API for a runtime config updater.
type Interface interface {
DefaultRuntime() string
AddRuntime(string, string, bool) error
AddRuntime(string, string, bool, ...map[string]interface{}) error
Set(string, interface{})
RemoveRuntime(string) error
Save(string) (int64, error)

View File

@ -30,7 +30,7 @@ type ConfigV1 Config
var _ engine.Interface = (*ConfigV1)(nil)
// AddRuntime adds a runtime to the containerd config
func (c *ConfigV1) AddRuntime(name string, path string, setAsDefault bool) error {
func (c *ConfigV1) AddRuntime(name string, path string, setAsDefault bool, configOverrides ...map[string]interface{}) error {
if c == nil || c.Tree == nil {
return fmt.Errorf("config is nil")
}
@ -75,6 +75,16 @@ func (c *ConfigV1) AddRuntime(name string, path string, setAsDefault bool) error
}
config.SetPath([]string{"plugins", "cri", "containerd", "default_runtime", "options", "BinaryName"}, path)
config.SetPath([]string{"plugins", "cri", "containerd", "default_runtime", "options", "Runtime"}, path)
defaultRuntimeSubtree := subtreeAtPath(config, "plugins", "cri", "containerd", "default_runtime")
if err := defaultRuntimeSubtree.applyOverrides(configOverrides...); err != nil {
return fmt.Errorf("failed to apply config overrides to default_runtime: %w", err)
}
}
runtimeSubtree := subtreeAtPath(config, "plugins", "cri", "containerd", "runtimes", name)
if err := runtimeSubtree.applyOverrides(configOverrides...); err != nil {
return fmt.Errorf("failed to apply config overrides: %w", err)
}
*c.Tree = config

View File

@ -25,7 +25,7 @@ import (
)
// AddRuntime adds a runtime to the containerd config
func (c *Config) AddRuntime(name string, path string, setAsDefault bool) error {
func (c *Config) AddRuntime(name string, path string, setAsDefault bool, configOverrides ...map[string]interface{}) error {
if c == nil || c.Tree == nil {
return fmt.Errorf("config is nil")
}
@ -60,6 +60,11 @@ func (c *Config) AddRuntime(name string, path string, setAsDefault bool) error {
config.SetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "containerd", "default_runtime_name"}, name)
}
runtimeSubtree := subtreeAtPath(config, "plugins", "io.containerd.grpc.v1.cri", "containerd", "runtimes", name)
if err := runtimeSubtree.applyOverrides(configOverrides...); err != nil {
return fmt.Errorf("failed to apply config overrides: %w", err)
}
*c.Tree = config
return nil
}

View File

@ -0,0 +1,97 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package containerd
import (
"testing"
"github.com/pelletier/go-toml"
"github.com/stretchr/testify/require"
)
func TestAddRuntime(t *testing.T) {
testCases := []struct {
description string
config string
setAsDefault bool
configOverrides []map[string]interface{}
expectedConfig string
expectedError error
}{
{
description: "empty config not default runtime",
expectedConfig: `
version = 2
[plugins]
[plugins."io.containerd.grpc.v1.cri"]
[plugins."io.containerd.grpc.v1.cri".containerd]
[plugins."io.containerd.grpc.v1.cri".containerd.runtimes]
[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.test]
privileged_without_host_devices = false
runtime_engine = ""
runtime_root = ""
runtime_type = ""
[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.test.options]
BinaryName = "/usr/bin/test"
`,
expectedError: nil,
},
{
description: "empty config not default runtime with overrides",
configOverrides: []map[string]interface{}{
{
"options": map[string]interface{}{
"SystemdCgroup": true,
},
},
},
expectedConfig: `
version = 2
[plugins]
[plugins."io.containerd.grpc.v1.cri"]
[plugins."io.containerd.grpc.v1.cri".containerd]
[plugins."io.containerd.grpc.v1.cri".containerd.runtimes]
[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.test]
privileged_without_host_devices = false
runtime_engine = ""
runtime_root = ""
runtime_type = ""
[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.test.options]
BinaryName = "/usr/bin/test"
SystemdCgroup = true
`,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
config, err := toml.Load(tc.config)
require.NoError(t, err)
expectedConfig, err := toml.Load(tc.expectedConfig)
require.NoError(t, err)
c := &Config{
Tree: config,
}
err = c.AddRuntime("test", "/usr/bin/test", tc.setAsDefault, tc.configOverrides...)
require.NoError(t, err)
require.EqualValues(t, expectedConfig.String(), config.String())
})
}
}

View File

@ -0,0 +1,56 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package containerd
import (
"fmt"
"github.com/pelletier/go-toml"
)
// tomlTree is an alias for toml.Tree that allows for extensions.
type tomlTree toml.Tree
func subtreeAtPath(c toml.Tree, path ...string) *tomlTree {
tree := c.GetPath(path).(*toml.Tree)
return (*tomlTree)(tree)
}
func (t *tomlTree) insert(other map[string]interface{}) error {
for key, value := range other {
if insertsubtree, ok := value.(map[string]interface{}); ok {
subtree := (*toml.Tree)(t).Get(key).(*toml.Tree)
return (*tomlTree)(subtree).insert(insertsubtree)
}
(*toml.Tree)(t).Set(key, value)
}
return nil
}
func (t *tomlTree) applyOverrides(overrides ...map[string]interface{}) error {
for _, override := range overrides {
subconfig, err := toml.TreeFromMap(override)
if err != nil {
return fmt.Errorf("invalid toml config: %w", err)
}
if err := t.insert(subconfig.ToMap()); err != nil {
return err
}
}
return nil
}

View File

@ -40,7 +40,7 @@ func New(opts ...Option) (engine.Interface, error) {
}
// AddRuntime adds a new runtime to the crio config
func (c *Config) AddRuntime(name string, path string, setAsDefault bool) error {
func (c *Config) AddRuntime(name string, path string, setAsDefault bool, _ ...map[string]interface{}) error {
if c == nil {
return fmt.Errorf("config is nil")
}

View File

@ -49,7 +49,7 @@ func New(opts ...Option) (engine.Interface, error) {
}
// AddRuntime adds a new runtime to the docker config
func (c *Config) AddRuntime(name string, path string, setAsDefault bool) error {
func (c *Config) AddRuntime(name string, path string, setAsDefault bool, _ ...map[string]interface{}) error {
if c == nil {
return fmt.Errorf("config is nil")
}

View File

@ -36,12 +36,12 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) {
},
)
graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.nvidiaCTKPath)
graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.nvidiaCDIHookPath)
if err != nil {
l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err)
}
driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCTKPath, l.ldconfigPath, l.nvmllib)
driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCDIHookPath, l.ldconfigPath, l.nvmllib)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err)
}

View File

@ -34,7 +34,7 @@ import (
// NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation.
// The supplied NVML Library is used to query the expected driver version.
func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, ldconfigPath string, nvmllib nvml.Interface) (discover.Discover, error) {
func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string, ldconfigPath string, nvmllib nvml.Interface) (discover.Discover, error) {
if r := nvmllib.Init(); r != nvml.SUCCESS {
return nil, fmt.Errorf("failed to initialize NVML: %v", r)
}
@ -49,11 +49,11 @@ func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTK
return nil, fmt.Errorf("failed to determine driver version: %v", r)
}
return newDriverVersionDiscoverer(logger, driver, nvidiaCTKPath, ldconfigPath, version)
return newDriverVersionDiscoverer(logger, driver, nvidiaCDIHookPath, ldconfigPath, version)
}
func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath, ldconfigPath, version string) (discover.Discover, error) {
libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCTKPath, ldconfigPath, version)
func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath, ldconfigPath, version string) (discover.Discover, error) {
libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCDIHookPath, ldconfigPath, version)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err)
}
@ -81,7 +81,7 @@ func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nv
}
// NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version.
func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath, ldconfigPath, version string) (discover.Discover, error) {
func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath, ldconfigPath, version string) (discover.Discover, error) {
libraryPaths, err := getVersionLibs(logger, driver, version)
if err != nil {
return nil, fmt.Errorf("failed to get libraries for driver version: %v", err)
@ -97,7 +97,7 @@ func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nv
libraryPaths,
)
hooks, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCTKPath, ldconfigPath)
hooks, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCDIHookPath, ldconfigPath)
d := discover.Merge(
libraries,

View File

@ -39,7 +39,7 @@ var requiredDriverStoreFiles = []string{
}
// newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers.
func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath, ldconfigPath string) (discover.Discover, error) {
func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCDIHookPath, ldconfigPath string) (discover.Discover, error) {
err := dxcore.Init()
if err != nil {
return nil, fmt.Errorf("failed to initialize dxcore: %v", err)
@ -56,11 +56,11 @@ func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, nvidiaCT
}
logger.Infof("Using WSL driver store paths: %v", driverStorePaths)
return newWSLDriverStoreDiscoverer(logger, driverRoot, nvidiaCTKPath, ldconfigPath, driverStorePaths)
return newWSLDriverStoreDiscoverer(logger, driverRoot, nvidiaCDIHookPath, ldconfigPath, driverStorePaths)
}
// newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter.
func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvidiaCTKPath string, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) {
func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvidiaCDIHookPath string, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) {
var searchPaths []string
seen := make(map[string]bool)
for _, path := range driverStorePaths {
@ -88,12 +88,12 @@ func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvi
)
symlinkHook := nvidiaSMISimlinkHook{
logger: logger,
mountsFrom: libraries,
nvidiaCTKPath: nvidiaCTKPath,
logger: logger,
mountsFrom: libraries,
nvidiaCDIHookPath: nvidiaCDIHookPath,
}
ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCTKPath, ldconfigPath)
ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCDIHookPath, ldconfigPath)
d := discover.Merge(
libraries,
@ -106,9 +106,9 @@ func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvi
type nvidiaSMISimlinkHook struct {
discover.None
logger logger.Interface
mountsFrom discover.Discover
nvidiaCTKPath string
logger logger.Interface
mountsFrom discover.Discover
nvidiaCDIHookPath string
}
// Hooks returns a hook that creates a symlink to nvidia-smi in the driver store.
@ -135,7 +135,7 @@ func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) {
}
link := "/usr/bin/nvidia-smi"
links := []string{fmt.Sprintf("%s::%s", target, link)}
symlinkHook := discover.CreateCreateSymlinkHook(m.nvidiaCTKPath, links)
symlinkHook := discover.CreateCreateSymlinkHook(m.nvidiaCDIHookPath, links)
return symlinkHook.Hooks()
}

View File

@ -92,8 +92,8 @@ func TestNvidiaSMISymlinkHook(t *testing.T) {
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
Path: "nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks",
"--link", "nvidia-smi::/usr/bin/nvidia-smi"},
},
},
@ -112,8 +112,8 @@ func TestNvidiaSMISymlinkHook(t *testing.T) {
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
Path: "nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
},
},
@ -132,8 +132,8 @@ func TestNvidiaSMISymlinkHook(t *testing.T) {
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
Path: "nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
},
},
@ -143,9 +143,9 @@ func TestNvidiaSMISymlinkHook(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
m := nvidiaSMISimlinkHook{
logger: logger,
mountsFrom: tc.mounts,
nvidiaCTKPath: "nvidia-ctk",
logger: logger,
mountsFrom: tc.mounts,
nvidiaCDIHookPath: "nvidia-cdi-hook",
}
devices, err := m.Devices()

View File

@ -58,7 +58,7 @@ func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) ([]specs.Device, err
// GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'.
func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) {
device, err := newFullGPUDiscoverer(l.logger, l.devRoot, l.nvidiaCTKPath, d)
device, err := newFullGPUDiscoverer(l.logger, l.devRoot, l.nvidiaCDIHookPath, d)
if err != nil {
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
}
@ -73,17 +73,17 @@ func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error
// byPathHookDiscoverer discovers the entities required for injecting by-path DRM device links
type byPathHookDiscoverer struct {
logger logger.Interface
devRoot string
nvidiaCTKPath string
pciBusID string
deviceNodes discover.Discover
logger logger.Interface
devRoot string
nvidiaCDIHookPath string
pciBusID string
deviceNodes discover.Discover
}
var _ discover.Discover = (*byPathHookDiscoverer)(nil)
// newFullGPUDiscoverer creates a discoverer for the full GPU defined by the specified device.
func newFullGPUDiscoverer(logger logger.Interface, devRoot string, nvidiaCTKPath string, d device.Device) (discover.Discover, error) {
func newFullGPUDiscoverer(logger logger.Interface, devRoot string, nvidiaCDIHookPath string, d device.Device) (discover.Discover, error) {
// TODO: The functionality to get device paths should be integrated into the go-nvlib/pkg/device.Device interface.
// This will allow reuse here and in other code where the paths are queried such as the NVIDIA device plugin.
minor, ret := d.GetMinorNumber()
@ -112,17 +112,17 @@ func newFullGPUDiscoverer(logger logger.Interface, devRoot string, nvidiaCTKPath
)
byPathHooks := &byPathHookDiscoverer{
logger: logger,
devRoot: devRoot,
nvidiaCTKPath: nvidiaCTKPath,
pciBusID: pciBusID,
deviceNodes: deviceNodes,
logger: logger,
devRoot: devRoot,
nvidiaCDIHookPath: nvidiaCDIHookPath,
pciBusID: pciBusID,
deviceNodes: deviceNodes,
}
deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer(
logger,
devRoot,
nvidiaCTKPath,
nvidiaCDIHookPath,
deviceNodes,
)
@ -157,8 +157,8 @@ func (d *byPathHookDiscoverer) Hooks() ([]discover.Hook, error) {
args = append(args, "--link", l)
}
hook := discover.CreateNvidiaCTKHook(
d.nvidiaCTKPath,
hook := discover.CreateNvidiaCDIHook(
d.nvidiaCDIHookPath,
"create-symlinks",
args...,
)

View File

@ -44,7 +44,7 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {
tegra.WithLogger(l.logger),
tegra.WithDriverRoot(l.driverRoot),
tegra.WithDevRoot(l.devRoot),
tegra.WithNVIDIACTKPath(l.nvidiaCTKPath),
tegra.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath),
tegra.WithLdconfigPath(l.ldconfigPath),
tegra.WithCSVFiles(l.csvFiles),
tegra.WithLibrarySearchPaths(l.librarySearchPaths...),

View File

@ -54,7 +54,7 @@ func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) {
// GetCommonEdits generates a CDI specification that can be used for ANY devices
func (l *wsllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
driver, err := newWSLDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.ldconfigPath)
driver, err := newWSLDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCDIHookPath, l.ldconfigPath)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for WSL driver: %v", err)
}

View File

@ -48,7 +48,7 @@ type nvcdilib struct {
deviceNamers DeviceNamers
driverRoot string
devRoot string
nvidiaCTKPath string
nvidiaCDIHookPath string
ldconfigPath string
configSearchPaths []string
librarySearchPaths []string
@ -81,24 +81,34 @@ func New(opts ...Option) (Interface, error) {
indexNamer, _ := NewDeviceNamer(DeviceNameStrategyIndex)
l.deviceNamers = []DeviceNamer{indexNamer}
}
if l.nvidiaCDIHookPath == "" {
l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
}
if l.driverRoot == "" {
l.driverRoot = "/"
}
if l.devRoot == "" {
l.devRoot = l.driverRoot
}
if l.nvidiaCTKPath == "" {
l.nvidiaCTKPath = "/usr/bin/nvidia-ctk"
}
if l.infolib == nil {
l.infolib = info.New()
}
l.driver = root.New(
root.WithLogger(l.logger),
root.WithDriverRoot(l.driverRoot),
root.WithLibrarySearchPaths(l.librarySearchPaths...),
)
if l.nvmllib == nil {
l.nvmllib = nvml.New()
}
if l.devicelib == nil {
l.devicelib = device.New(l.nvmllib)
}
if l.infolib == nil {
l.infolib = info.New(
info.WithRoot(l.driverRoot),
info.WithLogger(l.logger),
info.WithNvmlLib(l.nvmllib),
info.WithDeviceLib(l.devicelib),
)
}
var lib Interface
switch l.resolveMode() {
@ -113,13 +123,6 @@ func New(opts ...Option) (Interface, error) {
}
lib = (*managementlib)(l)
case ModeNvml:
if l.nvmllib == nil {
l.nvmllib = nvml.New()
}
if l.devicelib == nil {
l.devicelib = device.New(device.WithNvml(l.nvmllib))
}
lib = (*nvmllib)(l)
case ModeWsl:
lib = (*wsllib)(l)
@ -184,26 +187,19 @@ func (l *nvcdilib) resolveMode() (rmode string) {
return l.mode
}
defer func() {
l.logger.Infof("Auto-detected mode as %q", rmode)
l.logger.Infof("Auto-detected mode as '%v'", rmode)
}()
isWSL, reason := l.infolib.HasDXCore()
l.logger.Debugf("Is WSL-based system? %v: %v", isWSL, reason)
if isWSL {
platform := l.infolib.ResolvePlatform()
switch platform {
case info.PlatformNVML:
return ModeNvml
case info.PlatformTegra:
return ModeCSV
case info.PlatformWSL:
return ModeWsl
}
isNvml, reason := l.infolib.HasNvml()
l.logger.Debugf("Is NVML-based system? %v: %v", isNvml, reason)
isTegra, reason := l.infolib.IsTegraSystem()
l.logger.Debugf("Is Tegra-based system? %v: %v", isTegra, reason)
if isTegra && !isNvml {
return ModeCSV
}
l.logger.Warningf("Unsupported platform detected: %v; assuming %v", platform, ModeNvml)
return ModeNvml
}

View File

@ -1,116 +0,0 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvcdi
import (
"fmt"
"testing"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
func TestResolveMode(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {
mode string
isTegra bool
hasDXCore bool
hasNVML bool
expected string
}{
{
mode: "auto",
hasDXCore: true,
expected: "wsl",
},
{
mode: "auto",
hasDXCore: false,
isTegra: true,
hasNVML: false,
expected: "csv",
},
{
mode: "auto",
hasDXCore: false,
isTegra: false,
hasNVML: false,
expected: "nvml",
},
{
mode: "auto",
hasDXCore: false,
isTegra: true,
hasNVML: true,
expected: "nvml",
},
{
mode: "auto",
hasDXCore: false,
isTegra: false,
expected: "nvml",
},
{
mode: "nvml",
hasDXCore: true,
isTegra: true,
expected: "nvml",
},
{
mode: "wsl",
hasDXCore: false,
expected: "wsl",
},
{
mode: "not-auto",
hasDXCore: true,
expected: "not-auto",
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
l := nvcdilib{
logger: logger,
mode: tc.mode,
infolib: infoMock{hasDXCore: tc.hasDXCore, isTegra: tc.isTegra, hasNVML: tc.hasNVML},
}
require.Equal(t, tc.expected, l.resolveMode())
})
}
}
type infoMock struct {
hasDXCore bool
isTegra bool
hasNVML bool
}
func (i infoMock) HasDXCore() (bool, string) {
return i.hasDXCore, ""
}
func (i infoMock) HasNvml() (bool, string) {
return i.hasNVML, ""
}
func (i infoMock) IsTegraSystem() (bool, string) {
return i.isTegra, ""
}

View File

@ -66,7 +66,7 @@ func (m *managementlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("failed to get CUDA version: %v", err)
}
driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCTKPath, m.ldconfigPath, version)
driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCDIHookPath, m.ldconfigPath, version)
if err != nil {
return nil, fmt.Errorf("failed to create driver library discoverer: %v", err)
}
@ -123,7 +123,7 @@ func (m *managementlib) newManagementDeviceDiscoverer() (discover.Discover, erro
deviceFolderPermissionHooks := newDeviceFolderPermissionHookDiscoverer(
m.logger,
m.devRoot,
m.nvidiaCTKPath,
m.nvidiaCDIHookPath,
deviceNodes,
)

View File

@ -18,6 +18,7 @@ package nvcdi
import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
@ -34,6 +35,13 @@ func WithDeviceLib(devicelib device.Interface) Option {
}
}
// WithInfoLib sets the info library for CDI spec generation.
func WithInfoLib(infolib info.Interface) Option {
return func(l *nvcdilib) {
l.infolib = infolib
}
}
// WithDeviceNamers sets the device namer for the library
func WithDeviceNamers(namers ...DeviceNamer) Option {
return func(l *nvcdilib) {
@ -63,9 +71,16 @@ func WithLogger(logger logger.Interface) Option {
}
// WithNVIDIACTKPath sets the path to the NVIDIA Container Toolkit CLI path for the library
//
// Deprecated: Use WithNVIDIACDIHookPath instead.
func WithNVIDIACTKPath(path string) Option {
return WithNVIDIACDIHookPath(path)
}
// WithNVIDIACDIHookPath sets the path to the NVIDIA Container Toolkit CLI path for the library
func WithNVIDIACDIHookPath(path string) Option {
return func(l *nvcdilib) {
l.nvidiaCTKPath = path
l.nvidiaCDIHookPath = path
}
}

View File

@ -24,6 +24,8 @@ import (
const (
// DetectMinimumVersion is a constant that triggers a spec to detect the minimum required version.
//
// Deprecated: DetectMinimumVersion is deprecated and will be removed.
DetectMinimumVersion = "DETECT_MINIMUM_VERSION"
// FormatJSON indicates a JSON output format

View File

@ -39,6 +39,8 @@ type builder struct {
mergedDeviceOptions []transform.MergedDeviceOption
noSimplify bool
permissions os.FileMode
transformOnSave transform.Transformer
}
// newBuilder creates a new spec builder with the supplied options
@ -47,15 +49,23 @@ func newBuilder(opts ...Option) *builder {
for _, opt := range opts {
opt(s)
}
if s.raw != nil {
s.noSimplify = true
vendor, class := parser.ParseQualifier(s.raw.Kind)
s.vendor = vendor
s.class = class
if s.vendor == "" {
s.vendor = vendor
}
if s.class == "" {
s.class = class
}
if s.version == "" || s.version == DetectMinimumVersion {
s.version = s.raw.Version
}
}
if s.version == "" {
s.version = DetectMinimumVersion
if s.version == "" || s.version == DetectMinimumVersion {
s.transformOnSave = &setMinimumRequiredVersion{}
s.version = cdi.CurrentVersion
}
if s.vendor == "" {
s.vendor = "nvidia.com"
@ -83,13 +93,8 @@ func (o *builder) Build() (*spec, error) {
ContainerEdits: o.edits,
}
}
if raw.Version == DetectMinimumVersion {
minVersion, err := cdi.MinimumRequiredVersion(raw)
if err != nil {
return nil, fmt.Errorf("failed to get minimum required CDI spec version: %v", err)
}
raw.Version = minVersion
if raw.Version == "" {
raw.Version = o.version
}
if !o.noSimplify {
@ -110,11 +115,11 @@ func (o *builder) Build() (*spec, error) {
}
s := spec{
Spec: raw,
format: o.format,
permissions: o.permissions,
Spec: raw,
format: o.format,
permissions: o.permissions,
transformOnSave: o.transformOnSave,
}
return &s, nil
}

View File

@ -0,0 +1,35 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package spec
import (
"fmt"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"
)
type setMinimumRequiredVersion struct{}
func (d setMinimumRequiredVersion) Transform(spec *specs.Spec) error {
minVersion, err := cdi.MinimumRequiredVersion(spec)
if err != nil {
return fmt.Errorf("failed to get minimum required CDI spec version: %v", err)
}
spec.Version = minVersion
return nil
}

View File

@ -24,12 +24,15 @@ import (
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
)
type spec struct {
*specs.Spec
format string
permissions os.FileMode
format string
permissions os.FileMode
transformOnSave transform.Transformer
}
var _ Interface = (*spec)(nil)
@ -41,6 +44,12 @@ func New(opts ...Option) (Interface, error) {
// Save writes the spec to the specified path and overwrites the file if it exists.
func (s *spec) Save(path string) error {
if s.transformOnSave != nil {
err := s.transformOnSave.Transform(s.Raw())
if err != nil {
return fmt.Errorf("error applying transform: %w", err)
}
}
path, err := s.normalizePath(path)
if err != nil {
return fmt.Errorf("failed to normalize path: %w", err)

181
pkg/nvcdi/spec/spec_test.go Normal file
View File

@ -0,0 +1,181 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package spec
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
"tags.cncf.io/container-device-interface/specs-go"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform/root"
)
func TestSpec(t *testing.T) {
testCases := []struct {
description string
options []Option
expectedNewError error
transform transform.Transformer
expectedSpec string
}{
{
description: "default options return empty spec",
expectedSpec: `---
cdiVersion: 0.3.0
containerEdits: {}
devices: null
kind: nvidia.com/gpu
`,
},
{
description: "version is overridden",
options: []Option{WithVersion("0.5.0")},
expectedSpec: `---
cdiVersion: 0.5.0
containerEdits: {}
devices: null
kind: nvidia.com/gpu
`,
},
{
description: "raw spec is used as is",
options: []Option{WithRawSpec(
&specs.Spec{
Version: "0.5.0",
Kind: "nvidia.com/gpu",
ContainerEdits: specs.ContainerEdits{
Env: []string{"FOO=bar"},
},
},
)},
expectedSpec: `---
cdiVersion: 0.5.0
containerEdits:
env:
- FOO=bar
devices: null
kind: nvidia.com/gpu
`,
},
{
description: "raw spec with no version uses minimum version",
options: []Option{WithRawSpec(
&specs.Spec{
Kind: "nvidia.com/gpu",
ContainerEdits: specs.ContainerEdits{
Env: []string{"FOO=bar"},
},
},
)},
expectedSpec: `---
cdiVersion: 0.3.0
containerEdits:
env:
- FOO=bar
devices: null
kind: nvidia.com/gpu
`,
},
{
description: "spec with host dev path uses 0.5.0 version",
options: []Option{WithRawSpec(
&specs.Spec{
Kind: "nvidia.com/gpu",
ContainerEdits: specs.ContainerEdits{
Env: []string{"FOO=bar"},
DeviceNodes: []*specs.DeviceNode{
{
HostPath: "/some/dev/dev0",
Path: "/dev/dev0",
},
},
},
},
)},
expectedSpec: `---
cdiVersion: 0.5.0
containerEdits:
deviceNodes:
- hostPath: /some/dev/dev0
path: /dev/dev0
env:
- FOO=bar
devices: null
kind: nvidia.com/gpu
`,
},
{
description: "transformed spec uses minimum version",
options: []Option{WithRawSpec(
&specs.Spec{
Kind: "nvidia.com/gpu",
ContainerEdits: specs.ContainerEdits{
Env: []string{"FOO=bar"},
DeviceNodes: []*specs.DeviceNode{
{
HostPath: "/some/dev/dev0",
Path: "/dev/dev0",
},
},
},
},
)},
transform: transform.Merge(
root.New(
root.WithRoot("/some/dev/"),
root.WithTargetRoot("/dev/"),
root.WithRelativeTo("host"),
),
transform.NewSimplifier(),
),
expectedSpec: `---
cdiVersion: 0.5.0
containerEdits:
deviceNodes:
- hostPath: /dev/dev0
path: /dev/dev0
env:
- FOO=bar
devices: null
kind: nvidia.com/gpu
`,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
s, err := New(tc.options...)
require.ErrorIs(t, err, tc.expectedNewError)
if tc.transform != nil {
err := tc.transform.Transform(s.Raw())
require.NoError(t, err)
}
buf := new(bytes.Buffer)
_, err = s.WriteTo(buf)
require.NoError(t, err)
require.EqualValues(t, tc.expectedSpec, buf.String())
})
}
}

View File

@ -98,13 +98,13 @@ func TestDeduplicate(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -114,8 +114,8 @@ func TestDeduplicate(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},

View File

@ -141,8 +141,8 @@ func TestMergedDevice(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -153,8 +153,8 @@ func TestMergedDevice(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -169,8 +169,8 @@ func TestMergedDevice(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -181,8 +181,8 @@ func TestMergedDevice(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -193,8 +193,8 @@ func TestMergedDevice(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},

View File

@ -165,27 +165,27 @@ func TestContainerRootTransformer(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/root/usr/bin/nvidia-ctk",
Path: "/root/usr/bin/nvidia-cdi-hook",
Args: []string{
"nvidia-ctk", "hook", "update-ldcache",
"nvidia-cdi-hook", "update-ldcache",
"--folder",
"/root/path/to/target",
},
},
{
HookName: "createContainer",
Path: "/target-root/usr/bin/nvidia-ctk",
Path: "/target-root/usr/bin/nvidia-cdi-hook",
Args: []string{
"nvidia-ctk", "hook", "update-ldcache",
"nvidia-cdi-hook", "update-ldcache",
"--folder",
"/target-root/path/to/target",
},
},
{
HookName: "createContainer",
Path: "/different-root/usr/bin/nvidia-ctk",
Path: "/different-root/usr/bin/nvidia-cdi-hook",
Args: []string{
"nvidia-ctk", "hook", "update-ldcache",
"nvidia-cdi-hook", "update-ldcache",
"--folder",
"/different-root/path/to/target",
},
@ -198,27 +198,27 @@ func TestContainerRootTransformer(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/root/usr/bin/nvidia-ctk",
Path: "/root/usr/bin/nvidia-cdi-hook",
Args: []string{
"nvidia-ctk", "hook", "update-ldcache",
"nvidia-cdi-hook", "update-ldcache",
"--folder",
"/target-root/path/to/target",
},
},
{
HookName: "createContainer",
Path: "/target-root/usr/bin/nvidia-ctk",
Path: "/target-root/usr/bin/nvidia-cdi-hook",
Args: []string{
"nvidia-ctk", "hook", "update-ldcache",
"nvidia-cdi-hook", "update-ldcache",
"--folder",
"/target-root/path/to/target",
},
},
{
HookName: "createContainer",
Path: "/different-root/usr/bin/nvidia-ctk",
Path: "/different-root/usr/bin/nvidia-cdi-hook",
Args: []string{
"nvidia-ctk", "hook", "update-ldcache",
"nvidia-cdi-hook", "update-ldcache",
"--folder",
"/different-root/path/to/target",
},
@ -236,7 +236,7 @@ func TestContainerRootTransformer(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "startContainer",
Path: "/root/usr/bin/nvidia-ctk",
Path: "/root/usr/bin/nvidia-cdi-hook",
Args: []string{
"--link",
"/root/path/to/target::/root/path/to/link",
@ -250,7 +250,7 @@ func TestContainerRootTransformer(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "startContainer",
Path: "/target-root/usr/bin/nvidia-ctk",
Path: "/target-root/usr/bin/nvidia-cdi-hook",
Args: []string{
"--link",
"/target-root/path/to/target::/target-root/path/to/link",
@ -269,7 +269,7 @@ func TestContainerRootTransformer(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/root/usr/bin/nvidia-ctk",
Path: "/root/usr/bin/nvidia-cdi-hook",
Args: []string{
"--link",
"/root/path/to/target::/root/path/to/link",
@ -283,7 +283,7 @@ func TestContainerRootTransformer(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/root/usr/bin/nvidia-ctk",
Path: "/root/usr/bin/nvidia-cdi-hook",
Args: []string{
"--link",
"/target-root/path/to/target::/target-root/path/to/link",
@ -302,7 +302,7 @@ func TestContainerRootTransformer(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createRuntime",
Path: "/root/usr/bin/nvidia-ctk",
Path: "/root/usr/bin/nvidia-cdi-hook",
Args: []string{
"--link",
"/root/path/to/target::/root/path/to/link",
@ -316,7 +316,7 @@ func TestContainerRootTransformer(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createRuntime",
Path: "/root/usr/bin/nvidia-ctk",
Path: "/root/usr/bin/nvidia-cdi-hook",
Args: []string{
"--link",
"/root/path/to/target::/root/path/to/link",

View File

@ -115,8 +115,8 @@ func TestSimplify(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -127,8 +127,8 @@ func TestSimplify(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -139,13 +139,13 @@ func TestSimplify(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -160,8 +160,8 @@ func TestSimplify(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -172,8 +172,8 @@ func TestSimplify(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},
@ -184,8 +184,8 @@ func TestSimplify(t *testing.T) {
Hooks: []*specs.Hook{
{
HookName: "createContainer",
Path: "/usr/bin/nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
Path: "/usr/bin/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
},
},
},

View File

@ -25,10 +25,10 @@ import (
)
type deviceFolderPermissions struct {
logger logger.Interface
devRoot string
nvidiaCTKPath string
devices discover.Discover
logger logger.Interface
devRoot string
nvidiaCDIHookPath string
devices discover.Discover
}
var _ discover.Discover = (*deviceFolderPermissions)(nil)
@ -39,12 +39,12 @@ var _ discover.Discover = (*deviceFolderPermissions)(nil)
// The nested devices that are applicable to the NVIDIA GPU devices are:
// - DRM devices at /dev/dri/*
// - NVIDIA Caps devices at /dev/nvidia-caps/*
func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, nvidiaCTKPath string, devices discover.Discover) discover.Discover {
func newDeviceFolderPermissionHookDiscoverer(logger logger.Interface, devRoot string, nvidiaCDIHookPath string, devices discover.Discover) discover.Discover {
d := &deviceFolderPermissions{
logger: logger,
devRoot: devRoot,
nvidiaCTKPath: nvidiaCTKPath,
devices: devices,
logger: logger,
devRoot: devRoot,
nvidiaCDIHookPath: nvidiaCDIHookPath,
devices: devices,
}
return d
@ -70,8 +70,8 @@ func (d *deviceFolderPermissions) Hooks() ([]discover.Hook, error) {
args = append(args, "--path", folder)
}
hook := discover.CreateNvidiaCTKHook(
d.nvidiaCTKPath,
hook := discover.CreateNvidiaCDIHook(
d.nvidiaCDIHookPath,
"chmod",
args...,
)

View File

@ -67,8 +67,8 @@ func ParseArgs(c *cli.Context, o *Options) error {
}
// Configure applies the options to the specified config
func (o Options) Configure(cfg engine.Interface) error {
err := o.UpdateConfig(cfg)
func (o Options) Configure(cfg engine.Interface, configOverrides ...map[string]interface{}) error {
err := o.UpdateConfig(cfg, configOverrides...)
if err != nil {
return fmt.Errorf("unable to update config: %v", err)
}
@ -98,14 +98,14 @@ func (o Options) flush(cfg engine.Interface) error {
}
// UpdateConfig updates the specified config to include the nvidia runtimes
func (o Options) UpdateConfig(cfg engine.Interface) error {
func (o Options) UpdateConfig(cfg engine.Interface, configOverrides ...map[string]interface{}) error {
runtimes := operator.GetRuntimes(
operator.WithNvidiaRuntimeName(o.RuntimeName),
operator.WithSetAsDefault(o.SetAsDefault),
operator.WithRoot(o.RuntimeDir),
)
for name, runtime := range runtimes {
err := cfg.AddRuntime(name, runtime.Path, runtime.SetAsDefault)
err := cfg.AddRuntime(name, runtime.Path, runtime.SetAsDefault, configOverrides...)
if err != nil {
return fmt.Errorf("failed to update runtime %q: %v", name, err)
}

View File

@ -17,6 +17,7 @@
package main
import (
"encoding/json"
"fmt"
"os"
@ -47,6 +48,8 @@ type options struct {
runtimeType string
ContainerRuntimeModesCDIAnnotationPrefixes cli.StringSlice
runtimeConfigOverrideJSON string
}
func main() {
@ -162,6 +165,13 @@ func main() {
Destination: &options.ContainerRuntimeModesCDIAnnotationPrefixes,
EnvVars: []string{"NVIDIA_CONTAINER_RUNTIME_MODES_CDI_ANNOTATION_PREFIXES"},
},
&cli.StringFlag{
Name: "runtime-config-override",
Destination: &options.runtimeConfigOverrideJSON,
Usage: "specify additional runtime options as a JSON string. The paths are relative to the runtime config.",
Value: "{}",
EnvVars: []string{"RUNTIME_CONFIG_OVERRIDE", "CONTAINERD_RUNTIME_CONFIG_OVERRIDE"},
},
}
// Update the subcommand flags with the common subcommand flags
@ -170,7 +180,7 @@ func main() {
// Run the top-level CLI
if err := c.Run(os.Args); err != nil {
log.Fatal(fmt.Errorf("Error: %v", err))
log.Fatal(fmt.Errorf("error: %v", err))
}
}
@ -188,7 +198,11 @@ func Setup(c *cli.Context, o *options) error {
return fmt.Errorf("unable to load config: %v", err)
}
err = o.Configure(cfg)
runtimeConfigOverride, err := o.runtimeConfigOverride()
if err != nil {
return fmt.Errorf("unable to parse config overrides: %w", err)
}
err = o.Configure(cfg, runtimeConfigOverride)
if err != nil {
return fmt.Errorf("unable to configure containerd: %v", err)
}
@ -246,3 +260,16 @@ func (o *options) containerAnnotationsFromCDIPrefixes() []string {
return annotations
}
func (o *options) runtimeConfigOverride() (map[string]interface{}, error) {
if o.runtimeConfigOverrideJSON == "" {
return nil, nil
}
runtimeOptions := make(map[string]interface{})
if err := json.Unmarshal([]byte(o.runtimeConfigOverrideJSON), &runtimeOptions); err != nil {
return nil, fmt.Errorf("failed to read %v as JSON: %w", o.runtimeConfigOverrideJSON, err)
}
return runtimeOptions, nil
}

View File

@ -0,0 +1,72 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package main
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestRuntimeOptions(t *testing.T) {
testCases := []struct {
description string
options options
expected map[string]interface{}
expectedError error
}{
{
description: "empty is nil",
},
{
description: "empty json",
options: options{
runtimeConfigOverrideJSON: "{}",
},
expected: map[string]interface{}{},
expectedError: nil,
},
{
description: "SystemdCgroup is true",
options: options{
runtimeConfigOverrideJSON: "{\"SystemdCgroup\": true}",
},
expected: map[string]interface{}{
"SystemdCgroup": true,
},
expectedError: nil,
},
{
description: "SystemdCgroup is false",
options: options{
runtimeConfigOverrideJSON: "{\"SystemdCgroup\": false}",
},
expected: map[string]interface{}{
"SystemdCgroup": false,
},
expectedError: nil,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
runtimeOptions, err := tc.options.runtimeConfigOverride()
require.ErrorIs(t, tc.expectedError, err)
require.EqualValues(t, tc.expected, runtimeOptions)
})
}
}

View File

@ -355,6 +355,13 @@ func Install(cli *cli.Context, opts *options) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA Container Toolkit CLI: %v", err))
}
nvidiaCDIHookPath, err := installContainerCDIHookCLI(opts.toolkitRoot)
if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error installing NVIDIA Container CDI Hook CLI: %v", err)
} else if err != nil {
log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA Container CDI Hook CLI: %v", err))
}
err = installToolkitConfig(cli, toolkitConfigPath, nvidiaContainerCliExecutable, nvidiaCTKPath, nvidiaContainerRuntimeHookPath, opts)
if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error installing NVIDIA container toolkit config: %v", err)
@ -369,7 +376,7 @@ func Install(cli *cli.Context, opts *options) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error creating device nodes: %v", err))
}
err = generateCDISpec(opts, nvidiaCTKPath)
err = generateCDISpec(opts, nvidiaCDIHookPath)
if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error generating CDI specification: %v", err)
} else if err != nil {
@ -539,6 +546,19 @@ func installContainerToolkitCLI(toolkitDir string) (string, error) {
return e.install(toolkitDir)
}
// installContainerCDIHookCLI installs the nvidia-cdi-hook CLI executable and wrapper.
func installContainerCDIHookCLI(toolkitDir string) (string, error) {
e := executable{
source: "/usr/bin/nvidia-cdi-hook",
target: executableTarget{
dotfileName: "nvidia-cdi-hook.real",
wrapperName: "nvidia-cdi-hook",
},
}
return e.install(toolkitDir)
}
// installContainerCLI sets up the NVIDIA container CLI executable, copying the executable
// and implementing the required wrapper
func installContainerCLI(toolkitRoot string) (string, error) {
@ -749,8 +769,8 @@ func createDeviceNodes(opts *options) error {
return nil
}
// generateCDISpec generates a CDI spec for use in managemnt containers
func generateCDISpec(opts *options, nvidiaCTKPath string) error {
// generateCDISpec generates a CDI spec for use in management containers
func generateCDISpec(opts *options, nvidiaCDIHookPath string) error {
if !opts.cdiEnabled {
return nil
}
@ -758,7 +778,7 @@ func generateCDISpec(opts *options, nvidiaCTKPath string) error {
cdilib, err := nvcdi.New(
nvcdi.WithMode(nvcdi.ModeManagement),
nvcdi.WithDriverRoot(opts.DriverRootCtrPath),
nvcdi.WithNVIDIACTKPath(nvidiaCTKPath),
nvcdi.WithNVIDIACDIHookPath(nvidiaCDIHookPath),
nvcdi.WithVendor(opts.cdiVendor),
nvcdi.WithClass(opts.cdiClass),
)

View File

@ -38,7 +38,7 @@ type Interface interface {
}
type devicelib struct {
nvml nvml.Interface
nvmllib nvml.Interface
skippedDevices map[string]struct{}
verifySymbols *bool
migProfiles []MigProfile
@ -47,14 +47,13 @@ type devicelib struct {
var _ Interface = &devicelib{}
// New creates a new instance of the 'device' interface.
func New(opts ...Option) Interface {
d := &devicelib{}
func New(nvmllib nvml.Interface, opts ...Option) Interface {
d := &devicelib{
nvmllib: nvmllib,
}
for _, opt := range opts {
opt(d)
}
if d.nvml == nil {
d.nvml = nvml.New()
}
if d.verifySymbols == nil {
verify := true
d.verifySymbols = &verify
@ -68,13 +67,6 @@ func New(opts ...Option) Interface {
return d
}
// WithNvml provides an Option to set the NVML library used by the 'device' interface.
func WithNvml(nvml nvml.Interface) Option {
return func(d *devicelib) {
d.nvml = nvml
}
}
// WithVerifySymbols provides an option to toggle whether to verify select symbols exist in dynamic libraries before calling them.
func WithVerifySymbols(verify bool) Option {
return func(d *devicelib) {

View File

@ -51,7 +51,7 @@ func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
// NewDeviceByUUID builds a new Device from a UUID.
func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) {
dev, ret := d.nvml.DeviceGetHandleByUUID(uuid)
dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret)
}
@ -334,13 +334,13 @@ func (d *device) isSkipped() (bool, error) {
// VisitDevices visits each top-level device and invokes a callback function for it.
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
count, ret := d.nvml.DeviceGetCount()
count, ret := d.nvmllib.DeviceGetCount()
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device count: %v", ret)
}
for i := 0; i < count; i++ {
device, ret := d.nvml.DeviceGetHandleByIndex(i)
device, ret := d.nvmllib.DeviceGetHandleByIndex(i)
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
}
@ -469,5 +469,5 @@ func (d *devicelib) hasSymbol(symbol string) bool {
return true
}
return d.nvml.Extensions().LookupSymbol(symbol) == nil
return d.nvmllib.Extensions().LookupSymbol(symbol) == nil
}

View File

@ -50,7 +50,7 @@ func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {
// NewMigDeviceByUUID builds a new MigDevice from a UUID.
func (d *devicelib) NewMigDeviceByUUID(uuid string) (MigDevice, error) {
dev, ret := d.nvml.DeviceGetHandleByUUID(uuid)
dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret)
}

View File

@ -0,0 +1,41 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package info
// Interface provides the API to the info package.
type Interface interface {
PlatformResolver
PropertyExtractor
}
// PlatformResolver defines a function to resolve the current platform.
type PlatformResolver interface {
ResolvePlatform() Platform
}
// PropertyExtractor provides a set of functions to query capabilities of the
// system.
//
//go:generate moq -rm -out property-extractor_mock.go . PropertyExtractor
type PropertyExtractor interface {
HasDXCore() (bool, string)
HasNvml() (bool, string)
HasTegraFiles() (bool, string)
// Deprecated: Use HasTegraFiles instead.
IsTegraSystem() (bool, string)
UsesOnlyNVGPUModule() (bool, string)
}

View File

@ -0,0 +1,78 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package info
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
)
type infolib struct {
PropertyExtractor
PlatformResolver
}
type options struct {
logger basicLogger
root root
nvmllib nvml.Interface
devicelib device.Interface
platform Platform
propertyExtractor PropertyExtractor
}
// New creates a new instance of the 'info' interface.
func New(opts ...Option) Interface {
o := &options{}
for _, opt := range opts {
opt(o)
}
if o.logger == nil {
o.logger = &nullLogger{}
}
if o.root == "" {
o.root = "/"
}
if o.nvmllib == nil {
o.nvmllib = nvml.New(
nvml.WithLibraryPath(o.root.tryResolveLibrary("libnvidia-ml.so.1")),
)
}
if o.devicelib == nil {
o.devicelib = device.New(o.nvmllib)
}
if o.platform == "" {
o.platform = PlatformAuto
}
if o.propertyExtractor == nil {
o.propertyExtractor = &propertyExtractor{
root: o.root,
nvmllib: o.nvmllib,
devicelib: o.devicelib,
}
}
return &infolib{
PlatformResolver: &platformResolver{
logger: o.logger,
platform: o.platform,
propertyExtractor: o.propertyExtractor,
},
PropertyExtractor: o.propertyExtractor,
}
}

View File

@ -1,102 +0,0 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package info
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/NVIDIA/go-nvml/pkg/dl"
)
// Interface provides the API to the info package.
type Interface interface {
HasDXCore() (bool, string)
HasNvml() (bool, string)
IsTegraSystem() (bool, string)
}
type infolib struct {
root string
}
var _ Interface = &infolib{}
// HasDXCore returns true if DXCore is detected on the system.
func (i *infolib) HasDXCore() (bool, string) {
const (
libraryName = "libdxcore.so"
)
if err := assertHasLibrary(libraryName); err != nil {
return false, fmt.Sprintf("could not load DXCore library: %v", err)
}
return true, "found DXCore library"
}
// HasNvml returns true if NVML is detected on the system.
func (i *infolib) HasNvml() (bool, string) {
const (
libraryName = "libnvidia-ml.so.1"
)
if err := assertHasLibrary(libraryName); err != nil {
return false, fmt.Sprintf("could not load NVML library: %v", err)
}
return true, "found NVML library"
}
// IsTegraSystem returns true if the system is detected as a Tegra-based system.
func (i *infolib) IsTegraSystem() (bool, string) {
tegraReleaseFile := filepath.Join(i.root, "/etc/nv_tegra_release")
tegraFamilyFile := filepath.Join(i.root, "/sys/devices/soc0/family")
if info, err := os.Stat(tegraReleaseFile); err == nil && !info.IsDir() {
return true, fmt.Sprintf("%v found", tegraReleaseFile)
}
if info, err := os.Stat(tegraFamilyFile); err != nil || info.IsDir() {
return false, fmt.Sprintf("%v file not found", tegraFamilyFile)
}
contents, err := os.ReadFile(tegraFamilyFile)
if err != nil {
return false, fmt.Sprintf("could not read %v", tegraFamilyFile)
}
if strings.HasPrefix(strings.ToLower(string(contents)), "tegra") {
return true, fmt.Sprintf("%v has 'tegra' prefix", tegraFamilyFile)
}
return false, fmt.Sprintf("%v has no 'tegra' prefix", tegraFamilyFile)
}
// assertHasLibrary returns an error if the specified library cannot be loaded.
func assertHasLibrary(libraryName string) error {
const (
libraryLoadFlags = dl.RTLD_LAZY
)
lib := dl.New(libraryName, libraryLoadFlags)
if err := lib.Open(); err != nil {
return err
}
defer lib.Close()
return nil
}

View File

@ -0,0 +1,28 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package info
type basicLogger interface {
Debugf(string, ...interface{})
Infof(string, ...interface{})
}
type nullLogger struct{}
func (n *nullLogger) Debugf(string, ...interface{}) {}
func (n *nullLogger) Infof(string, ...interface{}) {}

View File

@ -16,24 +16,55 @@
package info
// Option defines a function for passing options to the New() call.
type Option func(*infolib)
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
// New creates a new instance of the 'info' interface.
func New(opts ...Option) Interface {
i := &infolib{}
for _, opt := range opts {
opt(i)
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
)
// Option defines a function for passing options to the New() call.
type Option func(*options)
// WithDeviceLib sets the device library for the library.
func WithDeviceLib(devicelib device.Interface) Option {
return func(i *options) {
i.devicelib = devicelib
}
if i.root == "" {
i.root = "/"
}
// WithLogger sets the logger for the library.
func WithLogger(logger basicLogger) Option {
return func(i *options) {
i.logger = logger
}
}
// WithNvmlLib sets the nvml library for the library.
func WithNvmlLib(nvmllib nvml.Interface) Option {
return func(i *options) {
i.nvmllib = nvmllib
}
return i
}
// WithRoot provides a Option to set the root of the 'info' interface.
func WithRoot(root string) Option {
return func(i *infolib) {
i.root = root
func WithRoot(r string) Option {
return func(i *options) {
i.root = root(r)
}
}
// WithPropertyExtractor provides an Option to set the PropertyExtractor
// interface implementation.
// This is predominantly used for testing.
func WithPropertyExtractor(propertyExtractor PropertyExtractor) Option {
return func(i *options) {
i.propertyExtractor = propertyExtractor
}
}
// WithPlatform provides an option to set the platform explicitly.
func WithPlatform(platform Platform) Option {
return func(i *options) {
i.platform = platform
}
}

View File

@ -0,0 +1,143 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package info
import (
"fmt"
"os"
"strings"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
)
type propertyExtractor struct {
root root
nvmllib nvml.Interface
devicelib device.Interface
}
var _ PropertyExtractor = &propertyExtractor{}
// HasDXCore returns true if DXCore is detected on the system.
func (i *propertyExtractor) HasDXCore() (bool, string) {
const (
libraryName = "libdxcore.so"
)
if err := i.root.assertHasLibrary(libraryName); err != nil {
return false, fmt.Sprintf("could not load DXCore library: %v", err)
}
return true, "found DXCore library"
}
// HasNvml returns true if NVML is detected on the system.
func (i *propertyExtractor) HasNvml() (bool, string) {
const (
libraryName = "libnvidia-ml.so.1"
)
if err := i.root.assertHasLibrary(libraryName); err != nil {
return false, fmt.Sprintf("could not load NVML library: %v", err)
}
return true, "found NVML library"
}
// IsTegraSystem returns true if the system is detected as a Tegra-based system.
// Deprecated: Use HasTegraFiles instead.
func (i *propertyExtractor) IsTegraSystem() (bool, string) {
return i.HasTegraFiles()
}
// HasTegraFiles returns true if tegra-based files are detected on the system.
func (i *propertyExtractor) HasTegraFiles() (bool, string) {
tegraReleaseFile := i.root.join("/etc/nv_tegra_release")
tegraFamilyFile := i.root.join("/sys/devices/soc0/family")
if info, err := os.Stat(tegraReleaseFile); err == nil && !info.IsDir() {
return true, fmt.Sprintf("%v found", tegraReleaseFile)
}
if info, err := os.Stat(tegraFamilyFile); err != nil || info.IsDir() {
return false, fmt.Sprintf("%v file not found", tegraFamilyFile)
}
contents, err := os.ReadFile(tegraFamilyFile)
if err != nil {
return false, fmt.Sprintf("could not read %v", tegraFamilyFile)
}
if strings.HasPrefix(strings.ToLower(string(contents)), "tegra") {
return true, fmt.Sprintf("%v has 'tegra' prefix", tegraFamilyFile)
}
return false, fmt.Sprintf("%v has no 'tegra' prefix", tegraFamilyFile)
}
// UsesOnlyNVGPUModule checks whether the only the nvgpu module is used.
// This kernel module is used on Tegra-based systems when using the iGPU.
// Since some of these systems also support NVML, we use the device name
// reported by NVML to determine whether the system is an iGPU system.
//
// Devices that use the nvgpu module have their device names as:
//
// GPU 0: Orin (nvgpu) (UUID: 54d0709b-558d-5a59-9c65-0c5fc14a21a4)
//
// This function returns true if ALL devices use the nvgpu module.
func (i *propertyExtractor) UsesOnlyNVGPUModule() (uses bool, reason string) {
// We ensure that this function never panics
defer func() {
if err := recover(); err != nil {
uses = false
reason = fmt.Sprintf("panic: %v", err)
}
}()
ret := i.nvmllib.Init()
if ret != nvml.SUCCESS {
return false, fmt.Sprintf("failed to initialize nvml: %v", ret)
}
defer func() {
_ = i.nvmllib.Shutdown()
}()
var names []string
err := i.devicelib.VisitDevices(func(i int, d device.Device) error {
name, ret := d.GetName()
if ret != nvml.SUCCESS {
return fmt.Errorf("device %v: %v", i, ret)
}
names = append(names, name)
return nil
})
if err != nil {
return false, fmt.Sprintf("failed to get device names: %v", err)
}
if len(names) == 0 {
return false, "no devices found"
}
for _, name := range names {
if !strings.Contains(name, "(nvgpu)") {
return false, fmt.Sprintf("device %q does not use nvgpu module", name)
}
}
return true, "all devices use nvgpu module"
}

View File

@ -0,0 +1,215 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package info
import (
"sync"
)
// Ensure, that PropertyExtractorMock does implement PropertyExtractor.
// If this is not the case, regenerate this file with moq.
var _ PropertyExtractor = &PropertyExtractorMock{}
// PropertyExtractorMock is a mock implementation of PropertyExtractor.
//
// func TestSomethingThatUsesPropertyExtractor(t *testing.T) {
//
// // make and configure a mocked PropertyExtractor
// mockedPropertyExtractor := &PropertyExtractorMock{
// HasDXCoreFunc: func() (bool, string) {
// panic("mock out the HasDXCore method")
// },
// HasNvmlFunc: func() (bool, string) {
// panic("mock out the HasNvml method")
// },
// HasTegraFilesFunc: func() (bool, string) {
// panic("mock out the HasTegraFiles method")
// },
// IsTegraSystemFunc: func() (bool, string) {
// panic("mock out the IsTegraSystem method")
// },
// UsesOnlyNVGPUModuleFunc: func() (bool, string) {
// panic("mock out the UsesOnlyNVGPUModule method")
// },
// }
//
// // use mockedPropertyExtractor in code that requires PropertyExtractor
// // and then make assertions.
//
// }
type PropertyExtractorMock struct {
// HasDXCoreFunc mocks the HasDXCore method.
HasDXCoreFunc func() (bool, string)
// HasNvmlFunc mocks the HasNvml method.
HasNvmlFunc func() (bool, string)
// HasTegraFilesFunc mocks the HasTegraFiles method.
HasTegraFilesFunc func() (bool, string)
// IsTegraSystemFunc mocks the IsTegraSystem method.
IsTegraSystemFunc func() (bool, string)
// UsesOnlyNVGPUModuleFunc mocks the UsesOnlyNVGPUModule method.
UsesOnlyNVGPUModuleFunc func() (bool, string)
// calls tracks calls to the methods.
calls struct {
// HasDXCore holds details about calls to the HasDXCore method.
HasDXCore []struct {
}
// HasNvml holds details about calls to the HasNvml method.
HasNvml []struct {
}
// HasTegraFiles holds details about calls to the HasTegraFiles method.
HasTegraFiles []struct {
}
// IsTegraSystem holds details about calls to the IsTegraSystem method.
IsTegraSystem []struct {
}
// UsesOnlyNVGPUModule holds details about calls to the UsesOnlyNVGPUModule method.
UsesOnlyNVGPUModule []struct {
}
}
lockHasDXCore sync.RWMutex
lockHasNvml sync.RWMutex
lockHasTegraFiles sync.RWMutex
lockIsTegraSystem sync.RWMutex
lockUsesOnlyNVGPUModule sync.RWMutex
}
// HasDXCore calls HasDXCoreFunc.
func (mock *PropertyExtractorMock) HasDXCore() (bool, string) {
if mock.HasDXCoreFunc == nil {
panic("PropertyExtractorMock.HasDXCoreFunc: method is nil but PropertyExtractor.HasDXCore was just called")
}
callInfo := struct {
}{}
mock.lockHasDXCore.Lock()
mock.calls.HasDXCore = append(mock.calls.HasDXCore, callInfo)
mock.lockHasDXCore.Unlock()
return mock.HasDXCoreFunc()
}
// HasDXCoreCalls gets all the calls that were made to HasDXCore.
// Check the length with:
//
// len(mockedPropertyExtractor.HasDXCoreCalls())
func (mock *PropertyExtractorMock) HasDXCoreCalls() []struct {
} {
var calls []struct {
}
mock.lockHasDXCore.RLock()
calls = mock.calls.HasDXCore
mock.lockHasDXCore.RUnlock()
return calls
}
// HasNvml calls HasNvmlFunc.
func (mock *PropertyExtractorMock) HasNvml() (bool, string) {
if mock.HasNvmlFunc == nil {
panic("PropertyExtractorMock.HasNvmlFunc: method is nil but PropertyExtractor.HasNvml was just called")
}
callInfo := struct {
}{}
mock.lockHasNvml.Lock()
mock.calls.HasNvml = append(mock.calls.HasNvml, callInfo)
mock.lockHasNvml.Unlock()
return mock.HasNvmlFunc()
}
// HasNvmlCalls gets all the calls that were made to HasNvml.
// Check the length with:
//
// len(mockedPropertyExtractor.HasNvmlCalls())
func (mock *PropertyExtractorMock) HasNvmlCalls() []struct {
} {
var calls []struct {
}
mock.lockHasNvml.RLock()
calls = mock.calls.HasNvml
mock.lockHasNvml.RUnlock()
return calls
}
// HasTegraFiles calls HasTegraFilesFunc.
func (mock *PropertyExtractorMock) HasTegraFiles() (bool, string) {
if mock.HasTegraFilesFunc == nil {
panic("PropertyExtractorMock.HasTegraFilesFunc: method is nil but PropertyExtractor.HasTegraFiles was just called")
}
callInfo := struct {
}{}
mock.lockHasTegraFiles.Lock()
mock.calls.HasTegraFiles = append(mock.calls.HasTegraFiles, callInfo)
mock.lockHasTegraFiles.Unlock()
return mock.HasTegraFilesFunc()
}
// HasTegraFilesCalls gets all the calls that were made to HasTegraFiles.
// Check the length with:
//
// len(mockedPropertyExtractor.HasTegraFilesCalls())
func (mock *PropertyExtractorMock) HasTegraFilesCalls() []struct {
} {
var calls []struct {
}
mock.lockHasTegraFiles.RLock()
calls = mock.calls.HasTegraFiles
mock.lockHasTegraFiles.RUnlock()
return calls
}
// IsTegraSystem calls IsTegraSystemFunc.
func (mock *PropertyExtractorMock) IsTegraSystem() (bool, string) {
if mock.IsTegraSystemFunc == nil {
panic("PropertyExtractorMock.IsTegraSystemFunc: method is nil but PropertyExtractor.IsTegraSystem was just called")
}
callInfo := struct {
}{}
mock.lockIsTegraSystem.Lock()
mock.calls.IsTegraSystem = append(mock.calls.IsTegraSystem, callInfo)
mock.lockIsTegraSystem.Unlock()
return mock.IsTegraSystemFunc()
}
// IsTegraSystemCalls gets all the calls that were made to IsTegraSystem.
// Check the length with:
//
// len(mockedPropertyExtractor.IsTegraSystemCalls())
func (mock *PropertyExtractorMock) IsTegraSystemCalls() []struct {
} {
var calls []struct {
}
mock.lockIsTegraSystem.RLock()
calls = mock.calls.IsTegraSystem
mock.lockIsTegraSystem.RUnlock()
return calls
}
// UsesOnlyNVGPUModule calls UsesOnlyNVGPUModuleFunc.
func (mock *PropertyExtractorMock) UsesOnlyNVGPUModule() (bool, string) {
if mock.UsesOnlyNVGPUModuleFunc == nil {
panic("PropertyExtractorMock.UsesOnlyNVGPUModuleFunc: method is nil but PropertyExtractor.UsesOnlyNVGPUModule was just called")
}
callInfo := struct {
}{}
mock.lockUsesOnlyNVGPUModule.Lock()
mock.calls.UsesOnlyNVGPUModule = append(mock.calls.UsesOnlyNVGPUModule, callInfo)
mock.lockUsesOnlyNVGPUModule.Unlock()
return mock.UsesOnlyNVGPUModuleFunc()
}
// UsesOnlyNVGPUModuleCalls gets all the calls that were made to UsesOnlyNVGPUModule.
// Check the length with:
//
// len(mockedPropertyExtractor.UsesOnlyNVGPUModuleCalls())
func (mock *PropertyExtractorMock) UsesOnlyNVGPUModuleCalls() []struct {
} {
var calls []struct {
}
mock.lockUsesOnlyNVGPUModule.RLock()
calls = mock.calls.UsesOnlyNVGPUModule
mock.lockUsesOnlyNVGPUModule.RUnlock()
return calls
}

View File

@ -0,0 +1,64 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package info
// Platform represents a supported plaform.
type Platform string
const (
PlatformAuto = Platform("auto")
PlatformNVML = Platform("nvml")
PlatformTegra = Platform("tegra")
PlatformWSL = Platform("wsl")
PlatformUnknown = Platform("unknown")
)
type platformResolver struct {
logger basicLogger
platform Platform
propertyExtractor PropertyExtractor
}
func (p platformResolver) ResolvePlatform() Platform {
if p.platform != PlatformAuto {
p.logger.Infof("Using requested platform '%s'", p.platform)
return p.platform
}
hasDXCore, reason := p.propertyExtractor.HasDXCore()
p.logger.Debugf("Is WSL-based system? %v: %v", hasDXCore, reason)
hasTegraFiles, reason := p.propertyExtractor.HasTegraFiles()
p.logger.Debugf("Is Tegra-based system? %v: %v", hasTegraFiles, reason)
hasNVML, reason := p.propertyExtractor.HasNvml()
p.logger.Debugf("Is NVML-based system? %v: %v", hasNVML, reason)
usesOnlyNVGPUModule, reason := p.propertyExtractor.UsesOnlyNVGPUModule()
p.logger.Debugf("Uses nvgpu kernel module? %v: %v", usesOnlyNVGPUModule, reason)
switch {
case hasDXCore:
return PlatformWSL
case (hasTegraFiles && !hasNVML), usesOnlyNVGPUModule:
return PlatformTegra
case hasNVML:
return PlatformNVML
default:
return PlatformUnknown
}
}

View File

@ -0,0 +1,86 @@
/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package info
import (
"fmt"
"path/filepath"
"github.com/NVIDIA/go-nvml/pkg/dl"
)
// root represents a directory on the filesystem relative to which libraries
// such as the NVIDIA driver libraries can be found.
type root string
func (r root) join(parts ...string) string {
return filepath.Join(append([]string{string(r)}, parts...)...)
}
// assertHasLibrary returns an error if the specified library cannot be loaded.
func (r root) assertHasLibrary(libraryName string) error {
const (
libraryLoadFlags = dl.RTLD_LAZY
)
lib := dl.New(r.tryResolveLibrary(libraryName), libraryLoadFlags)
if err := lib.Open(); err != nil {
return err
}
defer lib.Close()
return nil
}
// tryResolveLibrary attempts to locate the specified library in the root.
// If the root is not specified, is "/", or the library cannot be found in the
// set of predefined paths, the input is returned as is.
func (r root) tryResolveLibrary(libraryName string) string {
if r == "" || r == "/" {
return libraryName
}
librarySearchPaths := []string{
"/usr/lib64",
"/usr/lib/x86_64-linux-gnu",
"/usr/lib/aarch64-linux-gnu",
"/lib64",
"/lib/x86_64-linux-gnu",
"/lib/aarch64-linux-gnu",
}
for _, d := range librarySearchPaths {
l := r.join(d, libraryName)
resolved, err := resolveLink(l)
if err != nil {
continue
}
return resolved
}
return libraryName
}
// resolveLink finds the target of a symlink or the file itself in the
// case of a regular file.
// This is equivalent to running `readlink -f ${l}`.
func resolveLink(l string) (string, error) {
resolved, err := filepath.EvalSymlinks(l)
if err != nil {
return "", fmt.Errorf("error resolving link '%v': %w", l, err)
}
return resolved, nil
}

View File

@ -15,9 +15,48 @@
package nvml
import (
"fmt"
"reflect"
"unsafe"
)
// nvmlDeviceHandle attempts to convert a device d to an nvmlDevice.
// This is required for functions such as GetTopologyCommonAncestor which
// accept Device arguments that need to be passed to internal nvml* functions
// as nvmlDevice parameters.
func nvmlDeviceHandle(d Device) nvmlDevice {
var helper func(val reflect.Value) nvmlDevice
helper = func(val reflect.Value) nvmlDevice {
if val.Kind() == reflect.Interface {
val = val.Elem()
}
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if val.Type() == reflect.TypeOf(nvmlDevice{}) {
return val.Interface().(nvmlDevice)
}
if val.Kind() != reflect.Struct {
panic(fmt.Errorf("unable to convert non-struct type %v to nvmlDevice", val.Kind()))
}
for i := 0; i < val.Type().NumField(); i++ {
if !val.Type().Field(i).Anonymous {
continue
}
if !val.Field(i).Type().Implements(reflect.TypeOf((*Device)(nil)).Elem()) {
continue
}
return helper(val.Field(i))
}
panic(fmt.Errorf("unable to convert %T to nvmlDevice", d))
}
return helper(reflect.ValueOf(d))
}
// EccBitType
type EccBitType = MemoryErrorType
@ -220,10 +259,13 @@ func (l *library) DeviceGetTopologyCommonAncestor(device1 Device, device2 Device
func (device1 nvmlDevice) GetTopologyCommonAncestor(device2 Device) (GpuTopologyLevel, Return) {
var pathInfo GpuTopologyLevel
ret := nvmlDeviceGetTopologyCommonAncestor(device1, device2.(nvmlDevice), &pathInfo)
ret := nvmlDeviceGetTopologyCommonAncestorStub(device1, nvmlDeviceHandle(device2), &pathInfo)
return pathInfo, ret
}
// nvmlDeviceGetTopologyCommonAncestorStub allows us to override this for testing.
var nvmlDeviceGetTopologyCommonAncestorStub = nvmlDeviceGetTopologyCommonAncestor
// nvml.DeviceGetTopologyNearestGpus()
func (l *library) DeviceGetTopologyNearestGpus(device Device, level GpuTopologyLevel) ([]Device, Return) {
return device.GetTopologyNearestGpus(level)
@ -250,7 +292,7 @@ func (l *library) DeviceGetP2PStatus(device1 Device, device2 Device, p2pIndex Gp
func (device1 nvmlDevice) GetP2PStatus(device2 Device, p2pIndex GpuP2PCapsIndex) (GpuP2PStatus, Return) {
var p2pStatus GpuP2PStatus
ret := nvmlDeviceGetP2PStatus(device1, device2.(nvmlDevice), p2pIndex, &p2pStatus)
ret := nvmlDeviceGetP2PStatus(device1, nvmlDeviceHandle(device2), p2pIndex, &p2pStatus)
return p2pStatus, ret
}
@ -1182,7 +1224,7 @@ func (l *library) DeviceOnSameBoard(device1 Device, device2 Device) (int, Return
func (device1 nvmlDevice) OnSameBoard(device2 Device) (int, Return) {
var onSameBoard int32
ret := nvmlDeviceOnSameBoard(device1, device2.(nvmlDevice), &onSameBoard)
ret := nvmlDeviceOnSameBoard(device1, nvmlDeviceHandle(device2), &onSameBoard)
return int(onSameBoard), ret
}

View File

@ -93,7 +93,7 @@ func (device nvmlDevice) GpmSampleGet(gpmSample GpmSample) Return {
}
func (gpmSample nvmlGpmSample) Get(device Device) Return {
return nvmlGpmSampleGet(device.(nvmlDevice), gpmSample)
return nvmlGpmSampleGet(nvmlDeviceHandle(device), gpmSample)
}
// nvml.GpmQueryDeviceSupport()
@ -137,5 +137,5 @@ func (device nvmlDevice) GpmMigSampleGet(gpuInstanceId int, gpmSample GpmSample)
}
func (gpmSample nvmlGpmSample) MigGet(device Device, gpuInstanceId int) Return {
return nvmlGpmMigSampleGet(device.(nvmlDevice), uint32(gpuInstanceId), gpmSample)
return nvmlGpmMigSampleGet(nvmlDeviceHandle(device), uint32(gpuInstanceId), gpmSample)
}

View File

@ -142,7 +142,7 @@ func (device nvmlDevice) VgpuTypeGetMaxInstances(vgpuTypeId VgpuTypeId) (int, Re
func (vgpuTypeId nvmlVgpuTypeId) GetMaxInstances(device Device) (int, Return) {
var vgpuInstanceCount uint32
ret := nvmlVgpuTypeGetMaxInstances(device.(nvmlDevice), vgpuTypeId, &vgpuInstanceCount)
ret := nvmlVgpuTypeGetMaxInstances(nvmlDeviceHandle(device), vgpuTypeId, &vgpuInstanceCount)
return int(vgpuInstanceCount), ret
}

4
vendor/modules.txt vendored
View File

@ -1,4 +1,4 @@
# github.com/NVIDIA/go-nvlib v0.3.0
# github.com/NVIDIA/go-nvlib v0.5.0
## explicit; go 1.20
github.com/NVIDIA/go-nvlib/pkg/nvlib/device
github.com/NVIDIA/go-nvlib/pkg/nvlib/info
@ -6,7 +6,7 @@ github.com/NVIDIA/go-nvlib/pkg/nvpci
github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes
github.com/NVIDIA/go-nvlib/pkg/nvpci/mmio
github.com/NVIDIA/go-nvlib/pkg/pciids
# github.com/NVIDIA/go-nvml v0.12.0-5
# github.com/NVIDIA/go-nvml v0.12.0-6
## explicit; go 1.20
github.com/NVIDIA/go-nvml/pkg/dl
github.com/NVIDIA/go-nvml/pkg/nvml