diff --git a/pkg/nvcdi/driver-wsl.go b/pkg/nvcdi/driver-wsl.go index f5d138cf..1aa02b8e 100644 --- a/pkg/nvcdi/driver-wsl.go +++ b/pkg/nvcdi/driver-wsl.go @@ -67,7 +67,6 @@ func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvi if len(searchPaths) > 1 { logger.Warningf("Found multiple driver store paths: %v", searchPaths) } - driverStorePath := searchPaths[0] searchPaths = append(searchPaths, "/usr/lib/wsl/lib") libraries := discover.NewMounts( @@ -83,12 +82,11 @@ func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvi requiredDriverStoreFiles, ) - // On WSL2 the driver store location is used unchanged. - // For this reason we need to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the driver store. - target := filepath.Join(driverStorePath, "nvidia-smi") - link := "/usr/bin/nvidia-smi" - links := []string{fmt.Sprintf("%s::%s", target, link)} - symlinkHook := discover.CreateCreateSymlinkHook(nvidiaCTKPath, links) + symlinkHook := nvidiaSMISimlinkHook{ + logger: logger, + mountsFrom: libraries, + nvidiaCTKPath: nvidiaCTKPath, + } ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCTKPath) @@ -100,3 +98,39 @@ func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, nvi return d, nil } + +type nvidiaSMISimlinkHook struct { + discover.None + logger logger.Interface + mountsFrom discover.Discover + nvidiaCTKPath string +} + +// Hooks returns a hook that creates a symlink to nvidia-smi in the driver store. +// On WSL2 the driver store location is used unchanged, for this reason we need +// to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the +// driver store. +func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) { + mounts, err := m.mountsFrom.Mounts() + if err != nil { + return nil, fmt.Errorf("failed to discover mounts: %w", err) + } + + var target string + for _, mount := range mounts { + if filepath.Base(mount.Path) == "nvidia-smi" { + target = mount.Path + break + } + } + + if target == "" { + m.logger.Warningf("Failed to find nvidia-smi in mounts: %v", mounts) + return nil, nil + } + link := "/usr/bin/nvidia-smi" + links := []string{fmt.Sprintf("%s::%s", target, link)} + symlinkHook := discover.CreateCreateSymlinkHook(m.nvidiaCTKPath, links) + + return symlinkHook.Hooks() +} diff --git a/pkg/nvcdi/driver-wsl_test.go b/pkg/nvcdi/driver-wsl_test.go new file mode 100644 index 00000000..39729130 --- /dev/null +++ b/pkg/nvcdi/driver-wsl_test.go @@ -0,0 +1,163 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "errors" + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/stretchr/testify/require" + + testlog "github.com/sirupsen/logrus/hooks/test" +) + +func TestNvidiaSMISymlinkHook(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + errMounts := errors.New("mounts error") + + testCases := []struct { + description string + mounts discover.Discover + expectedError error + expectedHooks []discover.Hook + }{ + { + description: "mounts error is returned", + mounts: &discover.DiscoverMock{ + MountsFunc: func() ([]discover.Mount, error) { + return nil, errMounts + }, + }, + expectedError: errMounts, + }, + { + description: "no mounts returns no hooks", + mounts: &discover.DiscoverMock{ + MountsFunc: func() ([]discover.Mount, error) { + return nil, nil + }, + }, + }, + { + description: "no nvidia-smi returns no hooks", + mounts: &discover.DiscoverMock{ + MountsFunc: func() ([]discover.Mount, error) { + mounts := []discover.Mount{ + {Path: "/not-nvidia-smi"}, + {Path: "/also-not-nvidia-smi"}, + } + return mounts, nil + }, + }, + }, + { + description: "nvidia-smi must be in path", + mounts: &discover.DiscoverMock{ + MountsFunc: func() ([]discover.Mount, error) { + mounts := []discover.Mount{ + {Path: "/not-nvidia-smi", HostPath: "nvidia-smi"}, + {Path: "/also-not-nvidia-smi", HostPath: "not-nvidia-smi"}, + } + return mounts, nil + }, + }, + }, + { + description: "nvidia-smi returns hook", + mounts: &discover.DiscoverMock{ + MountsFunc: func() ([]discover.Mount, error) { + mounts := []discover.Mount{ + {Path: "nvidia-smi"}, + } + return mounts, nil + }, + }, + expectedHooks: []discover.Hook{ + { + Lifecycle: "createContainer", + Path: "nvidia-ctk", + Args: []string{"nvidia-ctk", "hook", "create-symlinks", + "--link", "nvidia-smi::/usr/bin/nvidia-smi"}, + }, + }, + }, + { + description: "checks basename", + mounts: &discover.DiscoverMock{ + MountsFunc: func() ([]discover.Mount, error) { + mounts := []discover.Mount{ + {Path: "/some/path/nvidia-smi"}, + {Path: "/nvidia-smi/but-not"}, + } + return mounts, nil + }, + }, + expectedHooks: []discover.Hook{ + { + Lifecycle: "createContainer", + Path: "nvidia-ctk", + Args: []string{"nvidia-ctk", "hook", "create-symlinks", + "--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"}, + }, + }, + }, + { + description: "returns first match", + mounts: &discover.DiscoverMock{ + MountsFunc: func() ([]discover.Mount, error) { + mounts := []discover.Mount{ + {Path: "/some/path/nvidia-smi"}, + {Path: "/another/path/nvidia-smi"}, + } + return mounts, nil + }, + }, + expectedHooks: []discover.Hook{ + { + Lifecycle: "createContainer", + Path: "nvidia-ctk", + Args: []string{"nvidia-ctk", "hook", "create-symlinks", + "--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"}, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + m := nvidiaSMISimlinkHook{ + logger: logger, + mountsFrom: tc.mounts, + nvidiaCTKPath: "nvidia-ctk", + } + + devices, err := m.Devices() + require.NoError(t, err) + require.Empty(t, devices) + + mounts, err := m.Mounts() + require.NoError(t, err) + require.Empty(t, mounts) + + hooks, err := m.Hooks() + require.ErrorIs(t, err, tc.expectedError) + require.Equal(t, tc.expectedHooks, hooks) + }) + } +}