From 04b28d116c6d0715cc8ae1a7c7055265ba5bd1ce Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Fri, 3 Nov 2023 22:16:16 +0100 Subject: [PATCH] Make library lookups more robust These changes make library lookups more robust. The core change is that library lookups now first look a set of predefined locations before checking the ldcache. This also handles cases where an ldcache is not available more gracefully. Signed-off-by: Evan Lezar --- internal/lookup/cuda/cuda.go | 47 +++----------- internal/lookup/cuda/cuda_test.go | 102 ++++++++++++++++++++++++++++++ internal/lookup/file.go | 6 +- internal/lookup/library.go | 56 ++++++++++------ internal/lookup/merge.go | 53 ++++++++++++++++ 5 files changed, 207 insertions(+), 57 deletions(-) create mode 100644 internal/lookup/cuda/cuda_test.go create mode 100644 internal/lookup/merge.go diff --git a/internal/lookup/cuda/cuda.go b/internal/lookup/cuda/cuda.go index 100dbdf2..b95e81b4 100644 --- a/internal/lookup/cuda/cuda.go +++ b/internal/lookup/cuda/cuda.go @@ -17,13 +17,12 @@ package cuda import ( - "path/filepath" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" ) type cudaLocator struct { + lookup.Locator logger logger.Interface driverRoot string } @@ -59,46 +58,18 @@ func New(opts ...Options) lookup.Locator { c.driverRoot = "/" } + // TODO: Do we want to set the Count to 1 here? + l, _ := lookup.NewLibraryLocator( + c.logger, + c.driverRoot, + ) + + c.Locator = l return c } // Locate returns the path to the libcuda.so.RMVERSION file. // libcuda.so is prefixed to the specified pattern. func (l *cudaLocator) Locate(pattern string) ([]string, error) { - ldcacheLocator, err := lookup.NewLibraryLocator( - l.logger, - l.driverRoot, - ) - if err != nil { - l.logger.Debugf("Failed to create LDCache locator: %v", err) - } - - fullPattern := "libcuda.so" + pattern - - candidates, err := ldcacheLocator.Locate("libcuda.so") - if err == nil { - for _, c := range candidates { - if match, err := filepath.Match(fullPattern, filepath.Base(c)); err != nil || !match { - l.logger.Debugf("Skipping non-matching candidate %v: %v", c, err) - continue - } - return []string{c}, nil - } - } - l.logger.Debugf("Could not locate %q in LDCache: Checking predefined library paths.", pattern) - - pathLocator := lookup.NewFileLocator( - lookup.WithLogger(l.logger), - lookup.WithRoot(l.driverRoot), - lookup.WithSearchPaths( - "/usr/lib64", - "/usr/lib/x86_64-linux-gnu", - "/usr/lib/aarch64-linux-gnu", - "/usr/lib/x86_64-linux-gnu/nvidia/current", - "/usr/lib/aarch64-linux-gnu/nvidia/current", - ), - lookup.WithCount(1), - ) - - return pathLocator.Locate(fullPattern) + return l.Locator.Locate("libcuda.so" + pattern) } diff --git a/internal/lookup/cuda/cuda_test.go b/internal/lookup/cuda/cuda_test.go new file mode 100644 index 00000000..7151acdb --- /dev/null +++ b/internal/lookup/cuda/cuda_test.go @@ -0,0 +1,102 @@ +/** +# Copyright 2023 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package cuda + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestLocate(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + testCases := []struct { + description string + libcudaPath string + expected []string + expectedError error + }{ + { + description: "no libcuda does not resolve library", + libcudaPath: "", + expected: []string{}, + expectedError: lookup.ErrNotFound, + }, + { + description: "no-ldcache searches /usr/lib64", + libcudaPath: "/usr/lib64/libcuda.so.123.34", + expected: []string{"/usr/lib64/libcuda.so.123.34"}, + expectedError: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + driverRoot, err := setupDriverRoot(t, tc.libcudaPath) + require.NoError(t, err) + + l := New( + WithLogger(logger), + WithDriverRoot(driverRoot), + ) + + candidates, err := l.Locate(".*") + require.ErrorIs(t, err, tc.expectedError) + + var strippedCandidates []string + for _, c := range candidates { + // NOTE: We need to strip `/private` on MacOs due to symlink resolution + strippedCandidates = append(strippedCandidates, strings.TrimPrefix(c, "/private")) + } + var expectedWithRoot []string + for _, e := range tc.expected { + expectedWithRoot = append(expectedWithRoot, filepath.Join(driverRoot, e)) + } + + require.EqualValues(t, expectedWithRoot, strippedCandidates) + }) + } +} + +// setupDriverRoot creates a folder that can be used to represent a driver root. +// The path to libcuda can be specified and an empty file is created at this location in the driver root. +func setupDriverRoot(t *testing.T, libCudaPath string) (string, error) { + driverRoot := t.TempDir() + + if libCudaPath == "" { + return driverRoot, nil + } + + if err := os.MkdirAll(filepath.Join(driverRoot, filepath.Dir(libCudaPath)), 0755); err != nil { + return "", fmt.Errorf("falied to create required driver root folder: %w", err) + } + + libCuda, err := os.Create(filepath.Join(driverRoot, libCudaPath)) + if err != nil { + return "", fmt.Errorf("failed to create dummy libcuda.so: %w", err) + } + defer libCuda.Close() + + return driverRoot, nil +} diff --git a/internal/lookup/file.go b/internal/lookup/file.go index d6fb5825..b5d2fa76 100644 --- a/internal/lookup/file.go +++ b/internal/lookup/file.go @@ -17,6 +17,7 @@ package lookup import ( + "errors" "fmt" "os" "path/filepath" @@ -24,6 +25,9 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) +// ErrNotFound indicates that a specified pattern or file could not be found. +var ErrNotFound = errors.New("not found") + // file can be used to locate file (or file-like elements) at a specified set of // prefixes. The validity of a file is determined by a filter function. type file struct { @@ -168,7 +172,7 @@ visit: } if !p.isOptional && len(filenames) == 0 { - return nil, fmt.Errorf("pattern %v not found", pattern) + return nil, fmt.Errorf("pattern %v %w", pattern, ErrNotFound) } return filenames, nil } diff --git a/internal/lookup/library.go b/internal/lookup/library.go index 0b5b7937..7bb62f68 100644 --- a/internal/lookup/library.go +++ b/internal/lookup/library.go @@ -18,44 +18,64 @@ package lookup import ( "fmt" - "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/ldcache" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) -type library struct { - logger logger.Interface - symlink Locator - cache ldcache.LDCache +type ldcacheLocator struct { + logger logger.Interface + cache ldcache.LDCache } -var _ Locator = (*library)(nil) +var _ Locator = (*ldcacheLocator)(nil) // NewLibraryLocator creates a library locator using the specified logger. func NewLibraryLocator(logger logger.Interface, root string) (Locator, error) { + // We construct a symlink locator for expected library locations. + symlinkLocator := NewSymlinkLocator( + WithLogger(logger), + WithRoot(root), + WithSearchPaths([]string{ + "/", + "/usr/lib64", + "/usr/lib/x86_64-linux-gnu", + "/usr/lib/aarch64-linux-gnu", + "/usr/lib/x86_64-linux-gnu/nvidia/current", + "/usr/lib/aarch64-linux-gnu/nvidia/current", + "/lib64", + "/lib/x86_64-linux-gnu", + "/lib/aarch64-linux-gnu", + "/lib/x86_64-linux-gnu/nvidia/current", + "/lib/aarch64-linux-gnu/nvidia/current", + }...), + ) + + l := First( + symlinkLocator, + newLdcacheLocator(logger, root), + ) + return l, nil +} + +func newLdcacheLocator(logger logger.Interface, root string) Locator { cache, err := ldcache.New(logger, root) if err != nil { - return nil, fmt.Errorf("error loading ldcache: %v", err) + // If we failed to open the LDCache, we default to a symlink locator. + logger.Warningf("Failed to load ldcache: %v", err) + return nil } - l := library{ - logger: logger, - symlink: NewSymlinkLocator(WithLogger(logger), WithRoot(root)), - cache: cache, + return ldcacheLocator{ + logger: logger, + cache: cache, } - - return &l, nil } // Locate finds the specified libraryname. // If the input is a library name, the ldcache is searched otherwise the // provided path is resolved as a symlink. -func (l library) Locate(libname string) ([]string, error) { - if strings.Contains(libname, "/") { - return l.symlink.Locate(libname) - } - +func (l ldcacheLocator) Locate(libname string) ([]string, error) { paths32, paths64 := l.cache.Lookup(libname) if len(paths32) > 0 { l.logger.Warningf("Ignoring 32-bit libraries for %v: %v", libname, paths32) diff --git a/internal/lookup/merge.go b/internal/lookup/merge.go new file mode 100644 index 00000000..fa20b512 --- /dev/null +++ b/internal/lookup/merge.go @@ -0,0 +1,53 @@ +/** +# Copyright 2023 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package lookup + +import "errors" + +type first []Locator + +// First returns a locator that returns the first non-empty match +func First(locators ...Locator) Locator { + var f first + for _, l := range locators { + if l == nil { + continue + } + f = append(f, l) + } + return f +} + +// Locate returns the results for the first locator that returns a non-empty non-error result. +func (f first) Locate(pattern string) ([]string, error) { + var allErrors []error + for _, l := range f { + if l == nil { + continue + } + candidates, err := l.Locate(pattern) + if err != nil { + allErrors = append(allErrors, err) + continue + } + if len(candidates) > 0 { + return candidates, nil + } + } + + return nil, errors.Join(allErrors...) +}