Merge branch 'fix-multiple-driver-roots-wsl' into 'main'

Fix bug with multiple driver store paths

See merge request nvidia/container-toolkit/container-toolkit!425
This commit is contained in:
Evan Lezar 2023-06-27 14:15:38 +00:00 committed by Evan Lezar
parent 9d31bd4cc3
commit f677245d60
3 changed files with 205 additions and 7 deletions

View File

@ -3,6 +3,7 @@
## v1.13.3
* Generate CDI specification files with `644` permissions to allow rootless applications (e.g. podman).
* Fix bug causing incorrect nvidia-smi symlink to be created on WSL2 systems with multiple driver roots.
* [toolkit-container] Allow same envars for all runtime configs

View File

@ -67,7 +67,6 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi
if len(searchPaths) > 1 {
logger.Warnf("Found multiple driver store paths: %v", searchPaths)
}
driverStorePath := searchPaths[0]
searchPaths = append(searchPaths, "/usr/lib/wsl/lib")
libraries := discover.NewMounts(
@ -83,12 +82,11 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi
requiredDriverStoreFiles,
)
// On WSL2 the driver store location is used unchanged.
// For this reason we need to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the driver store.
target := filepath.Join(driverStorePath, "nvidia-smi")
link := "/usr/bin/nvidia-smi"
links := []string{fmt.Sprintf("%s::%s", target, link)}
symlinkHook := discover.CreateCreateSymlinkHook(nvidiaCTKPath, links)
symlinkHook := nvidiaSMISimlinkHook{
logger: logger,
mountsFrom: libraries,
nvidiaCTKPath: nvidiaCTKPath,
}
cfg := &discover.Config{
DriverRoot: driverRoot,
@ -104,3 +102,39 @@ func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidi
return d, nil
}
type nvidiaSMISimlinkHook struct {
discover.None
logger *logrus.Logger
mountsFrom discover.Discover
nvidiaCTKPath string
}
// Hooks returns a hook that creates a symlink to nvidia-smi in the driver store.
// On WSL2 the driver store location is used unchanged, for this reason we need
// to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the
// driver store.
func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) {
mounts, err := m.mountsFrom.Mounts()
if err != nil {
return nil, fmt.Errorf("failed to discover mounts: %w", err)
}
var target string
for _, mount := range mounts {
if filepath.Base(mount.Path) == "nvidia-smi" {
target = mount.Path
break
}
}
if target == "" {
m.logger.Warningf("Failed to find nvidia-smi in mounts: %v", mounts)
return nil, nil
}
link := "/usr/bin/nvidia-smi"
links := []string{fmt.Sprintf("%s::%s", target, link)}
symlinkHook := discover.CreateCreateSymlinkHook(m.nvidiaCTKPath, links)
return symlinkHook.Hooks()
}

View File

@ -0,0 +1,163 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvcdi
import (
"errors"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/stretchr/testify/require"
testlog "github.com/sirupsen/logrus/hooks/test"
)
func TestNvidiaSMISymlinkHook(t *testing.T) {
logger, _ := testlog.NewNullLogger()
errMounts := errors.New("mounts error")
testCases := []struct {
description string
mounts discover.Discover
expectedError error
expectedHooks []discover.Hook
}{
{
description: "mounts error is returned",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
return nil, errMounts
},
},
expectedError: errMounts,
},
{
description: "no mounts returns no hooks",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
return nil, nil
},
},
},
{
description: "no nvidia-smi returns no hooks",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/not-nvidia-smi"},
{Path: "/also-not-nvidia-smi"},
}
return mounts, nil
},
},
},
{
description: "nvidia-smi must be in path",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/not-nvidia-smi", HostPath: "nvidia-smi"},
{Path: "/also-not-nvidia-smi", HostPath: "not-nvidia-smi"},
}
return mounts, nil
},
},
},
{
description: "nvidia-smi returns hook",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "nvidia-smi"},
}
return mounts, nil
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
"--link", "nvidia-smi::/usr/bin/nvidia-smi"},
},
},
},
{
description: "checks basename",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/some/path/nvidia-smi"},
{Path: "/nvidia-smi/but-not"},
}
return mounts, nil
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
},
},
},
{
description: "returns first match",
mounts: &discover.DiscoverMock{
MountsFunc: func() ([]discover.Mount, error) {
mounts := []discover.Mount{
{Path: "/some/path/nvidia-smi"},
{Path: "/another/path/nvidia-smi"},
}
return mounts, nil
},
},
expectedHooks: []discover.Hook{
{
Lifecycle: "createContainer",
Path: "nvidia-ctk",
Args: []string{"nvidia-ctk", "hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
m := nvidiaSMISimlinkHook{
logger: logger,
mountsFrom: tc.mounts,
nvidiaCTKPath: "nvidia-ctk",
}
devices, err := m.Devices()
require.NoError(t, err)
require.Empty(t, devices)
mounts, err := m.Mounts()
require.NoError(t, err)
require.Empty(t, mounts)
hooks, err := m.Hooks()
require.ErrorIs(t, err, tc.expectedError)
require.Equal(t, tc.expectedHooks, hooks)
})
}
}