diff --git a/internal/lookup/library.go b/internal/lookup/library.go index bdb2248e..7f5cf7c8 100644 --- a/internal/lookup/library.go +++ b/internal/lookup/library.go @@ -32,6 +32,17 @@ var _ Locator = (*ldcacheLocator)(nil) // NewLibraryLocator creates a library locator using the specified options. func NewLibraryLocator(opts ...Option) Locator { + b := newBuilder(opts...) + + // If search paths are already specified, we return a locator for the specified search paths. + if len(b.searchPaths) > 0 { + return NewSymlinkLocator( + WithLogger(b.logger), + WithSearchPaths(b.searchPaths...), + WithRoot("/"), + ) + } + opts = append(opts, WithSearchPaths([]string{ "/", diff --git a/internal/lookup/library_test.go b/internal/lookup/library_test.go index 21485dca..60dba77c 100644 --- a/internal/lookup/library_test.go +++ b/internal/lookup/library_test.go @@ -66,7 +66,7 @@ func TestLDCacheLocator(t *testing.T) { { description: "lib only not in LDCache returns error", libname: "libnotcuda.so", - expectedError: errNotFound, + expectedError: ErrNotFound, }, } @@ -127,16 +127,12 @@ func TestLibraryLocator(t *testing.T) { require.NoError(t, os.Symlink(libTarget1, source1)) require.NoError(t, os.Symlink(source1, source2)) - lut := NewLibraryLocator( - WithLogger(logger), - WithRoot(testDir), - ) - testCases := []struct { - description string - libname string - expected []string - expectedError error + description string + libname string + librarySearchPaths []string + expected []string + expectedError error }{ { description: "slash in path resoves symlink", @@ -156,7 +152,7 @@ func TestLibraryLocator(t *testing.T) { { description: "library not found returns error", libname: "/lib/symlink/libnotcuda.so", - expectedError: errNotFound, + expectedError: ErrNotFound, }, { description: "slash in path with pattern resoves symlink", @@ -176,10 +172,31 @@ func TestLibraryLocator(t *testing.T) { filepath.Join(testDir, "/lib/symlink/libtarget.so.1.2.3"), }, }, + { + description: "search paths are searched", + libname: "lib*.so.1.2.3", + librarySearchPaths: []string{filepath.Join(testDir, "/lib/symlink")}, + expected: []string{ + filepath.Join(testDir, "/lib/symlink/libcuda.so.1.2.3"), + filepath.Join(testDir, "/lib/symlink/libtarget.so.1.2.3"), + }, + }, + { + description: "search paths are absolute to root", + libname: "lib*.so.1.2.3", + librarySearchPaths: []string{"/lib/symlink"}, + expectedError: ErrNotFound, + }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { + lut := NewLibraryLocator( + WithLogger(logger), + WithRoot(testDir), + WithSearchPaths(tc.librarySearchPaths...), + ) + candidates, err := lut.Locate(tc.libname) require.ErrorIs(t, err, tc.expectedError) diff --git a/internal/lookup/locator.go b/internal/lookup/locator.go index af5633fb..73ade232 100644 --- a/internal/lookup/locator.go +++ b/internal/lookup/locator.go @@ -25,4 +25,5 @@ type Locator interface { Locate(string) ([]string, error) } -var errNotFound = errors.New("not found") +// ErrNotFound indicates that a specified pattern or file could not be found. +var ErrNotFound = errors.New("not found")