diff --git a/internal/ldcache/ldcache.go b/internal/ldcache/ldcache.go index 2f6de2fe..4daf95bc 100644 --- a/internal/ldcache/ldcache.go +++ b/internal/ldcache/ldcache.go @@ -22,15 +22,12 @@ import ( "bytes" "encoding/binary" "errors" - "fmt" "os" "path/filepath" - "strings" "syscall" "unsafe" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" ) const ldcachePath = "/etc/ld.so.cache" @@ -82,10 +79,9 @@ type entry2 struct { // LDCache represents the interface for performing lookups into the LDCache // -//go:generate moq -out ldcache_mock.go . LDCache +//go:generate moq -rm -out ldcache_mock.go . LDCache type LDCache interface { List() ([]string, []string) - Lookup(...string) ([]string, []string) } type ldcache struct { @@ -105,14 +101,7 @@ func New(logger logger.Interface, root string) (LDCache, error) { logger.Debugf("Opening ld.conf at %v", path) f, err := os.Open(path) - if os.IsNotExist(err) { - logger.Warningf("Could not find ld.so.cache at %v; creating empty cache", path) - e := &empty{ - logger: logger, - path: path, - } - return e, nil - } else if err != nil { + if err != nil { return nil, err } defer f.Close() @@ -196,7 +185,7 @@ type entry struct { } // getEntries returns the entires of the ldcache in a go-friendly struct. -func (c *ldcache) getEntries(selected func(string) bool) []entry { +func (c *ldcache) getEntries() []entry { var entries []entry for _, e := range c.entries { bits := 0 @@ -223,9 +212,6 @@ func (c *ldcache) getEntries(selected func(string) bool) []entry { c.logger.Debugf("Skipping invalid lib") continue } - if !selected(lib) { - continue - } value := bytesToString(c.libs[e.Value:]) if value == "" { c.logger.Debugf("Skipping invalid value for lib %v", lib) @@ -236,51 +222,19 @@ func (c *ldcache) getEntries(selected func(string) bool) []entry { bits: bits, value: value, } - entries = append(entries, e) } - return entries } // List creates a list of libraries in the ldcache. // The 32-bit and 64-bit libraries are returned separately. func (c *ldcache) List() ([]string, []string) { - all := func(s string) bool { return true } - - return c.resolveSelected(all) -} - -// Lookup searches the ldcache for the specified prefixes. -// The 32-bit and 64-bit libraries matching the prefixes are returned. -func (c *ldcache) Lookup(libPrefixes ...string) ([]string, []string) { - c.logger.Debugf("Looking up %v in cache", libPrefixes) - - // We define a functor to check whether a given library name matches any of the prefixes - matchesAnyPrefix := func(s string) bool { - for _, p := range libPrefixes { - if strings.HasPrefix(s, p) { - return true - } - } - return false - } - - return c.resolveSelected(matchesAnyPrefix) -} - -// resolveSelected process the entries in the LDCach based on the supplied filter and returns the resolved paths. -// The paths are separated by bittage. -func (c *ldcache) resolveSelected(selected func(string) bool) ([]string, []string) { paths := make(map[int][]string) processed := make(map[string]bool) - for _, e := range c.getEntries(selected) { - path, err := c.resolve(e.value) - if err != nil { - c.logger.Debugf("Could not resolve entry: %v", err) - continue - } + for _, e := range c.getEntries() { + path := filepath.Join(c.root, e.value) if processed[path] { continue } @@ -291,29 +245,6 @@ func (c *ldcache) resolveSelected(selected func(string) bool) ([]string, []strin return paths[32], paths[64] } -// resolve resolves the specified ldcache entry based on the value being processed. -// The input is the name of the entry in the cache. -func (c *ldcache) resolve(target string) (string, error) { - name := filepath.Join(c.root, target) - - c.logger.Debugf("checking %v", name) - - link, err := symlinks.Resolve(name) - if err != nil { - return "", fmt.Errorf("failed to resolve symlink: %v", err) - } - if link == name { - return name, nil - } - - // We return absolute paths for all targets - if !filepath.IsAbs(link) || strings.HasPrefix(link, ".") { - link = filepath.Join(filepath.Dir(target), link) - } - - return c.resolve(link) -} - // bytesToString converts a byte slice to a string. // This assumes that the byte slice is null-terminated func bytesToString(value []byte) string { diff --git a/internal/ldcache/ldcache_mock.go b/internal/ldcache/ldcache_mock.go index 092a3766..5aa53235 100644 --- a/internal/ldcache/ldcache_mock.go +++ b/internal/ldcache/ldcache_mock.go @@ -20,9 +20,6 @@ var _ LDCache = &LDCacheMock{} // ListFunc: func() ([]string, []string) { // panic("mock out the List method") // }, -// LookupFunc: func(strings ...string) ([]string, []string) { -// panic("mock out the Lookup method") -// }, // } // // // use mockedLDCache in code that requires LDCache @@ -33,22 +30,13 @@ type LDCacheMock struct { // ListFunc mocks the List method. ListFunc func() ([]string, []string) - // LookupFunc mocks the Lookup method. - LookupFunc func(strings ...string) ([]string, []string) - // calls tracks calls to the methods. calls struct { // List holds details about calls to the List method. List []struct { } - // Lookup holds details about calls to the Lookup method. - Lookup []struct { - // Strings is the strings argument value. - Strings []string - } } - lockList sync.RWMutex - lockLookup sync.RWMutex + lockList sync.RWMutex } // List calls ListFunc. @@ -77,35 +65,3 @@ func (mock *LDCacheMock) ListCalls() []struct { mock.lockList.RUnlock() return calls } - -// Lookup calls LookupFunc. -func (mock *LDCacheMock) Lookup(strings ...string) ([]string, []string) { - if mock.LookupFunc == nil { - panic("LDCacheMock.LookupFunc: method is nil but LDCache.Lookup was just called") - } - callInfo := struct { - Strings []string - }{ - Strings: strings, - } - mock.lockLookup.Lock() - mock.calls.Lookup = append(mock.calls.Lookup, callInfo) - mock.lockLookup.Unlock() - return mock.LookupFunc(strings...) -} - -// LookupCalls gets all the calls that were made to Lookup. -// Check the length with: -// -// len(mockedLDCache.LookupCalls()) -func (mock *LDCacheMock) LookupCalls() []struct { - Strings []string -} { - var calls []struct { - Strings []string - } - mock.lockLookup.RLock() - calls = mock.calls.Lookup - mock.lockLookup.RUnlock() - return calls -} diff --git a/internal/lookup/ldcache.go b/internal/lookup/ldcache.go new file mode 100644 index 00000000..677dafaa --- /dev/null +++ b/internal/lookup/ldcache.go @@ -0,0 +1,118 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package lookup + +import ( + "fmt" + "path/filepath" + "slices" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/ldcache" +) + +type ldcacheLocator struct { + *builder + resolvesTo map[string]string +} + +var _ Locator = (*ldcacheLocator)(nil) + +func NewLdcacheLocator(opts ...Option) Locator { + b := newBuilder(opts...) + + cache, err := ldcache.New(b.logger, b.root) + if err != nil { + b.logger.Warningf("Failed to load ldcache: %v", err) + if b.isOptional { + return &null{} + } + return ¬Found{} + } + + chain := NewSymlinkChainLocator(WithOptional(true)) + + resolvesTo := make(map[string]string) + _, libs64 := cache.List() + for _, library := range libs64 { + if _, processed := resolvesTo[library]; processed { + continue + } + candidates, err := chain.Locate(library) + if err != nil { + b.logger.Errorf("error processing library %s from ldcache: %v", library, err) + continue + } + + if len(candidates) == 0 { + resolvesTo[library] = library + continue + } + + // candidates represents a symlink chain. + // The first element represents the start of the chain and the last + // element the final target. + target := candidates[len(candidates)-1] + for _, candidate := range candidates { + resolvesTo[candidate] = target + } + } + + return &ldcacheLocator{ + builder: b, + resolvesTo: resolvesTo, + } +} + +// Locate finds the specified libraryname. +// If the input is a library name, the ldcache is searched otherwise the +// provided path is resolved as a symlink. +func (l ldcacheLocator) Locate(libname string) ([]string, error) { + var matcher func(string, string) bool + + if filepath.IsAbs(libname) { + matcher = func(p string, c string) bool { + m, _ := filepath.Match(filepath.Join(l.root, p), c) + return m + } + } else { + matcher = func(p string, c string) bool { + m, _ := filepath.Match(p, filepath.Base(c)) + return m + } + } + + var matches []string + seen := make(map[string]bool) + for name, target := range l.resolvesTo { + if !matcher(libname, name) { + continue + } + if seen[target] { + continue + } + seen[target] = true + matches = append(matches, target) + } + + slices.Sort(matches) + + if len(matches) == 0 && !l.isOptional { + return nil, fmt.Errorf("%s: %w", libname, ErrNotFound) + } + + return matches, nil +} diff --git a/internal/lookup/ldcache_test.go b/internal/lookup/ldcache_test.go index 4a40c0ef..66701a51 100644 --- a/internal/lookup/ldcache_test.go +++ b/internal/lookup/ldcache_test.go @@ -28,22 +28,40 @@ func TestLDCacheLookup(t *testing.T) { expectedError: ErrNotFound, }, { - rootFs: "rootfs-1", - inputs: []string{"libcuda.so.1", "libcuda.so.*", "libcuda.so.*.*", "libcuda.so.999.88.77"}, + rootFs: "rootfs-1", + inputs: []string{ + "libcuda.so.1", + "libcuda.so.*", + "libcuda.so.*.*", + "libcuda.so.999.88.77", + "/lib/x86_64-linux-gnu/libcuda.so.1", + "/lib/x86_64-linux-gnu/libcuda.so.*", + "/lib/x86_64-linux-gnu/libcuda.so.*.*", + "/lib/x86_64-linux-gnu/libcuda.so.999.88.77", + }, expected: "/lib/x86_64-linux-gnu/libcuda.so.999.88.77", }, { - rootFs: "rootfs-2", - inputs: []string{"libcuda.so.1", "libcuda.so.*", "libcuda.so.*.*", "libcuda.so.999.88.77"}, + rootFs: "rootfs-2", + inputs: []string{ + "libcuda.so.1", + "libcuda.so.*", + "libcuda.so.*.*", + "libcuda.so.999.88.77", + "/var/lib/nvidia/lib64/libcuda.so.1", + "/var/lib/nvidia/lib64/libcuda.so.*", + "/var/lib/nvidia/lib64/libcuda.so.*.*", + "/var/lib/nvidia/lib64/libcuda.so.999.88.77", + }, expected: "/var/lib/nvidia/lib64/libcuda.so.999.88.77", }, } for _, tc := range testCases { for _, input := range tc.inputs { - t.Run(tc.rootFs+input, func(t *testing.T) { + t.Run(tc.rootFs+" "+input, func(t *testing.T) { rootfs := filepath.Join(moduleRoot, "testdata", "lookup", tc.rootFs) - l := newLdcacheLocator( + l := NewLdcacheLocator( WithLogger(logger), WithRoot(rootfs), ) diff --git a/internal/lookup/library.go b/internal/lookup/library.go index 7f5cf7c8..6c403d08 100644 --- a/internal/lookup/library.go +++ b/internal/lookup/library.go @@ -16,20 +16,6 @@ package lookup -import ( - "fmt" - - "github.com/NVIDIA/nvidia-container-toolkit/internal/ldcache" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" -) - -type ldcacheLocator struct { - logger logger.Interface - cache ldcache.LDCache -} - -var _ Locator = (*ldcacheLocator)(nil) - // NewLibraryLocator creates a library locator using the specified options. func NewLibraryLocator(opts ...Option) Locator { b := newBuilder(opts...) @@ -63,39 +49,7 @@ func NewLibraryLocator(opts ...Option) Locator { l := First( symlinkLocator, - newLdcacheLocator(opts...), + NewLdcacheLocator(opts...), ) return l } - -func newLdcacheLocator(opts ...Option) Locator { - b := newBuilder(opts...) - - cache, err := ldcache.New(b.logger, b.root) - if err != nil { - // If we failed to open the LDCache, we default to a symlink locator. - b.logger.Warningf("Failed to load ldcache: %v", err) - return nil - } - - return &ldcacheLocator{ - logger: b.logger, - cache: cache, - } -} - -// Locate finds the specified libraryname. -// If the input is a library name, the ldcache is searched otherwise the -// provided path is resolved as a symlink. -func (l ldcacheLocator) Locate(libname string) ([]string, error) { - paths32, paths64 := l.cache.Lookup(libname) - if len(paths32) > 0 { - l.logger.Warningf("Ignoring 32-bit libraries for %v: %v", libname, paths32) - } - - if len(paths64) == 0 { - return nil, fmt.Errorf("64-bit library %v: %w", libname, ErrNotFound) - } - - return paths64, nil -} diff --git a/internal/lookup/library_test.go b/internal/lookup/library_test.go index dd686b75..8837ae99 100644 --- a/internal/lookup/library_test.go +++ b/internal/lookup/library_test.go @@ -24,82 +24,8 @@ import ( testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" - - "github.com/NVIDIA/nvidia-container-toolkit/internal/ldcache" ) -func TestLDCacheLocator(t *testing.T) { - logger, _ := testlog.NewNullLogger() - - testDir := t.TempDir() - symlinkDir := filepath.Join(testDir, "/lib/symlink") - require.NoError(t, os.MkdirAll(symlinkDir, 0755)) - - versionLib := filepath.Join(symlinkDir, "libcuda.so.1.2.3") - soLink := filepath.Join(symlinkDir, "libcuda.so") - sonameLink := filepath.Join(symlinkDir, "libcuda.so.1") - - _, err := os.Create(versionLib) - require.NoError(t, err) - require.NoError(t, os.Symlink(versionLib, sonameLink)) - require.NoError(t, os.Symlink(sonameLink, soLink)) - - lut := newLdcacheLocator( - WithLogger(logger), - WithRoot(testDir), - ) - - testCases := []struct { - description string - libname string - ldcacheMap map[string]string - expected []string - expectedError error - }{ - { - description: "lib only resolves in LDCache", - libname: "libcuda.so", - ldcacheMap: map[string]string{ - "libcuda.so": "/lib/from/ldcache/libcuda.so.4.5.6", - }, - expected: []string{"/lib/from/ldcache/libcuda.so.4.5.6"}, - }, - { - description: "lib only not in LDCache returns error", - libname: "libnotcuda.so", - expectedError: ErrNotFound, - }, - } - - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - // We override the LDCache with a mock implementation - l := lut.(*ldcacheLocator) - l.cache = &ldcache.LDCacheMock{ - LookupFunc: func(strings ...string) ([]string, []string) { - var result []string - for _, s := range strings { - if v, ok := tc.ldcacheMap[s]; ok { - result = append(result, v) - } - } - return nil, result - }, - } - - candidates, err := lut.Locate(tc.libname) - require.ErrorIs(t, err, tc.expectedError) - - var cleanedCandidates []string - for _, c := range candidates { - // On MacOS /var and /tmp symlink to /private/var and /private/tmp which is included in the resolved path. - cleanedCandidates = append(cleanedCandidates, strings.TrimPrefix(c, "/private")) - } - require.EqualValues(t, tc.expected, cleanedCandidates) - }) - } -} - func TestLibraryLocator(t *testing.T) { logger, _ := testlog.NewNullLogger() diff --git a/internal/ldcache/empty.go b/internal/lookup/null.go similarity index 51% rename from internal/ldcache/empty.go rename to internal/lookup/null.go index 30d3f4c8..938e481b 100644 --- a/internal/ldcache/empty.go +++ b/internal/lookup/null.go @@ -1,5 +1,5 @@ /** -# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# Copyright 2024 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,24 +14,23 @@ # limitations under the License. **/ -package ldcache +package lookup -import "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" +import "fmt" -type empty struct { - logger logger.Interface - path string +// A null locator always returns an empty response. +type null struct { } -var _ LDCache = (*empty)(nil) - -// List always returns nil for an empty ldcache -func (e *empty) List() ([]string, []string) { +// Locate always returns empty for a null locator. +func (l *null) Locate(string) ([]string, error) { return nil, nil } -// Lookup logs a debug message and returns nil for an empty ldcache -func (e *empty) Lookup(prefixes ...string) ([]string, []string) { - e.logger.Debugf("Calling Lookup(%v) on empty ldcache: %v", prefixes, e.path) - return nil, nil +// A notFound locator always returns an ErrNotFound error. +type notFound struct { +} + +func (l *notFound) Locate(s string) ([]string, error) { + return nil, fmt.Errorf("%s: %w", s, ErrNotFound) }