diff --git a/CHANGELOG.md b/CHANGELOG.md index d9e48600..9aaba665 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ * Add support for generating merged devices (e.g. `all` device) to the nvcdi API. * Use *.* pattern to locate libcuda.so when generating a CDI specification to support platforms where a patch version is not specified. * Update go-nvlib to skip devices that are not MIG capable when generating CDI specifications. +* Add `nvidia-container-runtime-hook.path` config option to specify NVIDIA Container Runtime Hook path explicitly. ## v1.13.1 diff --git a/cmd/nvidia-container-runtime/main_test.go b/cmd/nvidia-container-runtime/main_test.go index 05809aee..b8037751 100644 --- a/cmd/nvidia-container-runtime/main_test.go +++ b/cmd/nvidia-container-runtime/main_test.go @@ -172,7 +172,7 @@ func TestDuplicateHook(t *testing.T) { // addNVIDIAHook is a basic wrapper for an addHookModifier that is used for // testing. func addNVIDIAHook(spec *specs.Spec) error { - m := modifier.NewStableRuntimeModifier(logrus.StandardLogger()) + m := modifier.NewStableRuntimeModifier(logrus.StandardLogger(), nvidiaHook) return m.Modify(spec) } diff --git a/internal/config/config.go b/internal/config/config.go index 7601d0fd..21763918 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,6 +37,9 @@ const ( nvidiaCTKExecutable = "nvidia-ctk" nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk" + + nvidiaContainerRuntimeHookExecutable = "nvidia-container-runtime-hook" + nvidiaContainerRuntimeHookDefaultPath = "/usr/bin/nvidia-container-runtime-hook" ) var ( @@ -189,6 +192,9 @@ func GetDefaultConfigToml() (*toml.Tree, error) { // nvidia-ctk tree.Set("nvidia-ctk.path", nvidiaCTKExecutable) + // nvidia-container-runtime-hook + tree.Set("nvidia-container-runtime-hook.path", nvidiaContainerRuntimeHookExecutable) + return tree, nil } @@ -244,27 +250,48 @@ func getDistIDLike() []string { // If the path is specified as an absolute path, it is used directly // 
without checking for existence of an executable at that path. func ResolveNVIDIACTKPath(logger *logrus.Logger, nvidiaCTKPath string) string { - if filepath.IsAbs(nvidiaCTKPath) { - logger.Debugf("Using specified NVIDIA Container Toolkit CLI path %v", nvidiaCTKPath) - return nvidiaCTKPath - } - - if nvidiaCTKPath == "" { - nvidiaCTKPath = nvidiaCTKExecutable - } - logger.Debugf("Locating NVIDIA Container Toolkit CLI as %v", nvidiaCTKPath) - lookup := lookup.NewExecutableLocator(logger, "") - hookPath := nvidiaCTKDefaultFilePath - targets, err := lookup.Locate(nvidiaCTKPath) - if err != nil { - logger.Warnf("Failed to locate %v: %v", nvidiaCTKPath, err) - } else if len(targets) == 0 { - logger.Warnf("%v not found", nvidiaCTKPath) - } else { - logger.Debugf("Found %v candidates: %v", nvidiaCTKPath, targets) - hookPath = targets[0] - } - logger.Debugf("Using NVIDIA Container Toolkit CLI path %v", hookPath) - - return hookPath + return resolveWithDefault( + logger, + "NVIDIA Container Toolkit CLI", + nvidiaCTKPath, + nvidiaCTKDefaultFilePath, + ) +} + +// ResolveNVIDIAContainerRuntimeHookPath resolves the path to the nvidia-container-runtime-hook binary. +func ResolveNVIDIAContainerRuntimeHookPath(logger *logrus.Logger, nvidiaContainerRuntimeHookPath string) string { + return resolveWithDefault( + logger, + "NVIDIA Container Runtime Hook", + nvidiaContainerRuntimeHookPath, + nvidiaContainerRuntimeHookDefaultPath, + ) +} + +// resolveWithDefault resolves the path to the specified binary. +// If an absolute path is specified, it is used directly without searching for the binary. +// If the binary cannot be found in the path, the specified default is used instead. 
+func resolveWithDefault(logger *logrus.Logger, label string, path string, defaultPath string) string { + if filepath.IsAbs(path) { + logger.Debugf("Using specified %v path %v", label, path) + return path + } + + if path == "" { + path = filepath.Base(defaultPath) + } + logger.Debugf("Locating %v as %v", label, path) + lookup := lookup.NewExecutableLocator(logger, "") + + resolvedPath := defaultPath + targets, err := lookup.Locate(path) + if err != nil { + logger.Warnf("Failed to locate %v: %v", path, err) + } else { + logger.Debugf("Found %v candidates: %v", path, targets) + resolvedPath = targets[0] + } + logger.Debugf("Using %v path %v", label, path) + + return resolvedPath } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 72f03336..867edf79 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -76,6 +76,9 @@ func TestGetConfig(t *testing.T) { }, }, }, + NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{ + Path: "nvidia-container-runtime-hook", + }, NVIDIACTKConfig: CTKConfig{ Path: "nvidia-ctk", }, @@ -95,6 +98,7 @@ func TestGetConfig(t *testing.T) { "nvidia-container-runtime.modes.cdi.default-kind = \"example.vendor.com/device\"", "nvidia-container-runtime.modes.cdi.annotation-prefixes = [\"cdi.k8s.io/\", \"example.vendor.com/\",]", "nvidia-container-runtime.modes.csv.mount-spec-path = \"/not/etc/nvidia-container-runtime/host-files-for-container.d\"", + "nvidia-container-runtime-hook.path = \"/foo/bar/nvidia-container-runtime-hook\"", "nvidia-ctk.path = \"/foo/bar/nvidia-ctk\"", }, expectedConfig: &Config{ @@ -120,6 +124,9 @@ func TestGetConfig(t *testing.T) { }, }, }, + NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{ + Path: "/foo/bar/nvidia-container-runtime-hook", + }, NVIDIACTKConfig: CTKConfig{ Path: "/foo/bar/nvidia-ctk", }, @@ -143,6 +150,8 @@ func TestGetConfig(t *testing.T) { "annotation-prefixes = [\"cdi.k8s.io/\", \"example.vendor.com/\",]", 
"[nvidia-container-runtime.modes.csv]", "mount-spec-path = \"/not/etc/nvidia-container-runtime/host-files-for-container.d\"", + "[nvidia-container-runtime-hook]", + "path = \"/foo/bar/nvidia-container-runtime-hook\"", "[nvidia-ctk]", "path = \"/foo/bar/nvidia-ctk\"", }, @@ -169,6 +178,9 @@ func TestGetConfig(t *testing.T) { }, }, }, + NVIDIAContainerRuntimeHookConfig: RuntimeHookConfig{ + Path: "/foo/bar/nvidia-container-runtime-hook", + }, NVIDIACTKConfig: CTKConfig{ Path: "/foo/bar/nvidia-ctk", }, diff --git a/internal/config/hook.go b/internal/config/hook.go index 5a3b27dc..1222a4bb 100644 --- a/internal/config/hook.go +++ b/internal/config/hook.go @@ -18,6 +18,9 @@ package config // RuntimeHookConfig stores the config options for the NVIDIA Container Runtime type RuntimeHookConfig struct { + // Path specifies the path to the NVIDIA Container Runtime hook binary. + // If an executable name is specified, this will be resolved in the path. + Path string `toml:"path"` // SkipModeDetection disables the mode check for the runtime hook. SkipModeDetection bool `toml:"skip-mode-detection"` } diff --git a/internal/modifier/stable.go b/internal/modifier/stable.go index 1b4d8401..8281ce89 100644 --- a/internal/modifier/stable.go +++ b/internal/modifier/stable.go @@ -17,10 +17,8 @@ package modifier import ( - "fmt" + "path/filepath" - "github.com/NVIDIA/nvidia-container-toolkit/internal/config" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/opencontainers/runtime-spec/specs-go" "github.com/sirupsen/logrus" @@ -28,8 +26,11 @@ import ( // NewStableRuntimeModifier creates an OCI spec modifier that inserts the NVIDIA Container Runtime Hook into an OCI // spec. The specified logger is used to capture log output. 
-func NewStableRuntimeModifier(logger *logrus.Logger) oci.SpecModifier { - m := stableRuntimeModifier{logger: logger} +func NewStableRuntimeModifier(logger *logrus.Logger, nvidiaContainerRuntimeHookPath string) oci.SpecModifier { + m := stableRuntimeModifier{ + logger: logger, + nvidiaContainerRuntimeHookPath: nvidiaContainerRuntimeHookPath, + } return &m } @@ -37,7 +38,8 @@ func NewStableRuntimeModifier(logger *logrus.Logger) oci.SpecModifier { // stableRuntimeModifier modifies an OCI spec inplace, inserting the nvidia-container-runtime-hook as a // prestart hook. If the hook is already present, no modification is made. type stableRuntimeModifier struct { - logger *logrus.Logger + logger *logrus.Logger + nvidiaContainerRuntimeHookPath string } // Modify applies the required modification to the incoming OCI spec, inserting the nvidia-container-runtime-hook @@ -53,18 +55,9 @@ func (m stableRuntimeModifier) Modify(spec *specs.Spec) error { } } - // We create a locator and look for the NVIDIA Container Runtime Hook in the path. 
- candidates, err := lookup.NewExecutableLocator(m.logger, "").Locate(config.NVIDIAContainerRuntimeHookExecutable) - if err != nil { - return fmt.Errorf("failed to locate NVIDIA Container Runtime Hook: %v", err) - } - path := candidates[0] - if len(candidates) > 1 { - m.logger.Debugf("Using %v from multiple NVIDIA Container Runtime Hook candidates: %v", path, candidates) - } - + path := m.nvidiaContainerRuntimeHookPath m.logger.Infof("Using prestart hook path: %v", path) - args := []string{path} + args := []string{filepath.Base(path)} if spec.Hooks == nil { spec.Hooks = &specs.Hooks{} } diff --git a/internal/modifier/stable_test.go b/internal/modifier/stable_test.go index ac05c199..0b8eaff9 100644 --- a/internal/modifier/stable_test.go +++ b/internal/modifier/stable_test.go @@ -79,7 +79,7 @@ func TestAddHookModifier(t *testing.T) { Prestart: []specs.Hook{ { Path: testHookPath, - Args: []string{testHookPath, "prestart"}, + Args: []string{"nvidia-container-runtime-hook", "prestart"}, }, }, }, @@ -95,7 +95,7 @@ func TestAddHookModifier(t *testing.T) { Prestart: []specs.Hook{ { Path: testHookPath, - Args: []string{testHookPath, "prestart"}, + Args: []string{"nvidia-container-runtime-hook", "prestart"}, }, }, }, @@ -141,7 +141,7 @@ func TestAddHookModifier(t *testing.T) { }, { Path: testHookPath, - Args: []string{testHookPath, "prestart"}, + Args: []string{"nvidia-container-runtime-hook", "prestart"}, }, }, }, @@ -154,7 +154,7 @@ func TestAddHookModifier(t *testing.T) { t.Run(tc.description, func(t *testing.T) { - m := NewStableRuntimeModifier(logger) + m := NewStableRuntimeModifier(logger, testHookPath) err := m.Modify(&tc.spec) if tc.expectedError != nil { diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 47a80e29..e659694d 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -65,6 +65,7 @@ func (r rt) Run(argv []string) (rerr error) { cfg.NVIDIAContainerRuntimeConfig.Mode = r.modeOverride } cfg.NVIDIACTKConfig.Path 
= config.ResolveNVIDIACTKPath(r.logger.Logger, cfg.NVIDIACTKConfig.Path) + cfg.NVIDIAContainerRuntimeHookConfig.Path = config.ResolveNVIDIAContainerRuntimeHookPath(r.logger.Logger, cfg.NVIDIAContainerRuntimeHookConfig.Path) // Print the config to the output. configJSON, err := json.MarshalIndent(cfg, "", " ") diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index 36ac4e4e..80f94abc 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -93,7 +93,7 @@ func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec func newModeModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) { switch info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode) { case "legacy": - return modifier.NewStableRuntimeModifier(logger), nil + return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil case "csv": return modifier.NewCSVModifier(logger, cfg, ociSpec) case "cdi":