mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Merge branch 'improve-library-lookup' into 'main'
Make CDI-based library discovery more robust See merge request nvidia/container-toolkit/container-toolkit!488
This commit is contained in:
		
						commit
						da0755769f
					
				| @ -17,13 +17,12 @@ | |||||||
| package cuda | package cuda | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"path/filepath" |  | ||||||
| 
 |  | ||||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" | 	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" | ||||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" | 	"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type cudaLocator struct { | type cudaLocator struct { | ||||||
|  | 	lookup.Locator | ||||||
| 	logger     logger.Interface | 	logger     logger.Interface | ||||||
| 	driverRoot string | 	driverRoot string | ||||||
| } | } | ||||||
| @ -59,46 +58,18 @@ func New(opts ...Options) lookup.Locator { | |||||||
| 		c.driverRoot = "/" | 		c.driverRoot = "/" | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// TODO: Do we want to set the Count to 1 here?
 | ||||||
|  | 	l, _ := lookup.NewLibraryLocator( | ||||||
|  | 		c.logger, | ||||||
|  | 		c.driverRoot, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	c.Locator = l | ||||||
| 	return c | 	return c | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Locate returns the path to the libcuda.so.RMVERSION file.
 | // Locate returns the path to the libcuda.so.RMVERSION file.
 | ||||||
| // libcuda.so is prefixed to the specified pattern.
 | // libcuda.so is prefixed to the specified pattern.
 | ||||||
| func (l *cudaLocator) Locate(pattern string) ([]string, error) { | func (l *cudaLocator) Locate(pattern string) ([]string, error) { | ||||||
| 	ldcacheLocator, err := lookup.NewLibraryLocator( | 	return l.Locator.Locate("libcuda.so" + pattern) | ||||||
| 		l.logger, |  | ||||||
| 		l.driverRoot, |  | ||||||
| 	) |  | ||||||
| 	if err != nil { |  | ||||||
| 		l.logger.Debugf("Failed to create LDCache locator: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	fullPattern := "libcuda.so" + pattern |  | ||||||
| 
 |  | ||||||
| 	candidates, err := ldcacheLocator.Locate("libcuda.so") |  | ||||||
| 	if err == nil { |  | ||||||
| 		for _, c := range candidates { |  | ||||||
| 			if match, err := filepath.Match(fullPattern, filepath.Base(c)); err != nil || !match { |  | ||||||
| 				l.logger.Debugf("Skipping non-matching candidate %v: %v", c, err) |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			return []string{c}, nil |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	l.logger.Debugf("Could not locate %q in LDCache: Checking predefined library paths.", pattern) |  | ||||||
| 
 |  | ||||||
| 	pathLocator := lookup.NewFileLocator( |  | ||||||
| 		lookup.WithLogger(l.logger), |  | ||||||
| 		lookup.WithRoot(l.driverRoot), |  | ||||||
| 		lookup.WithSearchPaths( |  | ||||||
| 			"/usr/lib64", |  | ||||||
| 			"/usr/lib/x86_64-linux-gnu", |  | ||||||
| 			"/usr/lib/aarch64-linux-gnu", |  | ||||||
| 			"/usr/lib/x86_64-linux-gnu/nvidia/current", |  | ||||||
| 			"/usr/lib/aarch64-linux-gnu/nvidia/current", |  | ||||||
| 		), |  | ||||||
| 		lookup.WithCount(1), |  | ||||||
| 	) |  | ||||||
| 
 |  | ||||||
| 	return pathLocator.Locate(fullPattern) |  | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										102
									
								
								internal/lookup/cuda/cuda_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								internal/lookup/cuda/cuda_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,102 @@ | |||||||
|  | /** | ||||||
|  | # Copyright 2023 NVIDIA CORPORATION | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | **/ | ||||||
|  | 
 | ||||||
|  | package cuda | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" | ||||||
|  | 	testlog "github.com/sirupsen/logrus/hooks/test" | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestLocate(t *testing.T) { | ||||||
|  | 	logger, _ := testlog.NewNullLogger() | ||||||
|  | 
 | ||||||
|  | 	testCases := []struct { | ||||||
|  | 		description   string | ||||||
|  | 		libcudaPath   string | ||||||
|  | 		expected      []string | ||||||
|  | 		expectedError error | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			description:   "no libcuda does not resolve library", | ||||||
|  | 			libcudaPath:   "", | ||||||
|  | 			expected:      []string{}, | ||||||
|  | 			expectedError: lookup.ErrNotFound, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description:   "no-ldcache searches /usr/lib64", | ||||||
|  | 			libcudaPath:   "/usr/lib64/libcuda.so.123.34", | ||||||
|  | 			expected:      []string{"/usr/lib64/libcuda.so.123.34"}, | ||||||
|  | 			expectedError: nil, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range testCases { | ||||||
|  | 		t.Run(tc.description, func(t *testing.T) { | ||||||
|  | 			driverRoot, err := setupDriverRoot(t, tc.libcudaPath) | ||||||
|  | 			require.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 			l := New( | ||||||
|  | 				WithLogger(logger), | ||||||
|  | 				WithDriverRoot(driverRoot), | ||||||
|  | 			) | ||||||
|  | 
 | ||||||
|  | 			candidates, err := l.Locate(".*") | ||||||
|  | 			require.ErrorIs(t, err, tc.expectedError) | ||||||
|  | 
 | ||||||
|  | 			var strippedCandidates []string | ||||||
|  | 			for _, c := range candidates { | ||||||
|  | 				// NOTE: We need to strip `/private` on MacOs due to symlink resolution
 | ||||||
|  | 				strippedCandidates = append(strippedCandidates, strings.TrimPrefix(c, "/private")) | ||||||
|  | 			} | ||||||
|  | 			var expectedWithRoot []string | ||||||
|  | 			for _, e := range tc.expected { | ||||||
|  | 				expectedWithRoot = append(expectedWithRoot, filepath.Join(driverRoot, e)) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			require.EqualValues(t, expectedWithRoot, strippedCandidates) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // setupDriverRoot creates a folder that can be used to represent a driver root.
 | ||||||
|  | // The path to libcuda can be specified and an empty file is created at this location in the driver root.
 | ||||||
|  | func setupDriverRoot(t *testing.T, libCudaPath string) (string, error) { | ||||||
|  | 	driverRoot := t.TempDir() | ||||||
|  | 
 | ||||||
|  | 	if libCudaPath == "" { | ||||||
|  | 		return driverRoot, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := os.MkdirAll(filepath.Join(driverRoot, filepath.Dir(libCudaPath)), 0755); err != nil { | ||||||
|  | 		return "", fmt.Errorf("falied to create required driver root folder: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	libCuda, err := os.Create(filepath.Join(driverRoot, libCudaPath)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("failed to create dummy libcuda.so: %w", err) | ||||||
|  | 	} | ||||||
|  | 	defer libCuda.Close() | ||||||
|  | 
 | ||||||
|  | 	return driverRoot, nil | ||||||
|  | } | ||||||
| @ -17,6 +17,7 @@ | |||||||
| package lookup | package lookup | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| @ -24,6 +25,9 @@ import ( | |||||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" | 	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // ErrNotFound indicates that a specified pattern or file could not be found.
 | ||||||
|  | var ErrNotFound = errors.New("not found") | ||||||
|  | 
 | ||||||
| // file can be used to locate file (or file-like elements) at a specified set of
 | // file can be used to locate file (or file-like elements) at a specified set of
 | ||||||
| // prefixes. The validity of a file is determined by a filter function.
 | // prefixes. The validity of a file is determined by a filter function.
 | ||||||
| type file struct { | type file struct { | ||||||
| @ -168,7 +172,7 @@ visit: | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !p.isOptional && len(filenames) == 0 { | 	if !p.isOptional && len(filenames) == 0 { | ||||||
| 		return nil, fmt.Errorf("pattern %v not found", pattern) | 		return nil, fmt.Errorf("pattern %v %w", pattern, ErrNotFound) | ||||||
| 	} | 	} | ||||||
| 	return filenames, nil | 	return filenames, nil | ||||||
| } | } | ||||||
|  | |||||||
| @ -18,44 +18,64 @@ package lookup | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strings" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/ldcache" | 	"github.com/NVIDIA/nvidia-container-toolkit/internal/ldcache" | ||||||
| 	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" | 	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type library struct { | type ldcacheLocator struct { | ||||||
| 	logger  logger.Interface | 	logger logger.Interface | ||||||
| 	symlink Locator | 	cache  ldcache.LDCache | ||||||
| 	cache   ldcache.LDCache |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| var _ Locator = (*library)(nil) | var _ Locator = (*ldcacheLocator)(nil) | ||||||
| 
 | 
 | ||||||
| // NewLibraryLocator creates a library locator using the specified logger.
 | // NewLibraryLocator creates a library locator using the specified logger.
 | ||||||
| func NewLibraryLocator(logger logger.Interface, root string) (Locator, error) { | func NewLibraryLocator(logger logger.Interface, root string) (Locator, error) { | ||||||
|  | 	// We construct a symlink locator for expected library locations.
 | ||||||
|  | 	symlinkLocator := NewSymlinkLocator( | ||||||
|  | 		WithLogger(logger), | ||||||
|  | 		WithRoot(root), | ||||||
|  | 		WithSearchPaths([]string{ | ||||||
|  | 			"/", | ||||||
|  | 			"/usr/lib64", | ||||||
|  | 			"/usr/lib/x86_64-linux-gnu", | ||||||
|  | 			"/usr/lib/aarch64-linux-gnu", | ||||||
|  | 			"/usr/lib/x86_64-linux-gnu/nvidia/current", | ||||||
|  | 			"/usr/lib/aarch64-linux-gnu/nvidia/current", | ||||||
|  | 			"/lib64", | ||||||
|  | 			"/lib/x86_64-linux-gnu", | ||||||
|  | 			"/lib/aarch64-linux-gnu", | ||||||
|  | 			"/lib/x86_64-linux-gnu/nvidia/current", | ||||||
|  | 			"/lib/aarch64-linux-gnu/nvidia/current", | ||||||
|  | 		}...), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	l := First( | ||||||
|  | 		symlinkLocator, | ||||||
|  | 		newLdcacheLocator(logger, root), | ||||||
|  | 	) | ||||||
|  | 	return l, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newLdcacheLocator(logger logger.Interface, root string) Locator { | ||||||
| 	cache, err := ldcache.New(logger, root) | 	cache, err := ldcache.New(logger, root) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("error loading ldcache: %v", err) | 		// If we failed to open the LDCache, we default to a symlink locator.
 | ||||||
|  | 		logger.Warningf("Failed to load ldcache: %v", err) | ||||||
|  | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	l := library{ | 	return ldcacheLocator{ | ||||||
| 		logger:  logger, | 		logger: logger, | ||||||
| 		symlink: NewSymlinkLocator(WithLogger(logger), WithRoot(root)), | 		cache:  cache, | ||||||
| 		cache:   cache, |  | ||||||
| 	} | 	} | ||||||
| 
 |  | ||||||
| 	return &l, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Locate finds the specified libraryname.
 | // Locate finds the specified libraryname.
 | ||||||
| // If the input is a library name, the ldcache is searched otherwise the
 | // If the input is a library name, the ldcache is searched otherwise the
 | ||||||
| // provided path is resolved as a symlink.
 | // provided path is resolved as a symlink.
 | ||||||
| func (l library) Locate(libname string) ([]string, error) { | func (l ldcacheLocator) Locate(libname string) ([]string, error) { | ||||||
| 	if strings.Contains(libname, "/") { |  | ||||||
| 		return l.symlink.Locate(libname) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	paths32, paths64 := l.cache.Lookup(libname) | 	paths32, paths64 := l.cache.Lookup(libname) | ||||||
| 	if len(paths32) > 0 { | 	if len(paths32) > 0 { | ||||||
| 		l.logger.Warningf("Ignoring 32-bit libraries for %v: %v", libname, paths32) | 		l.logger.Warningf("Ignoring 32-bit libraries for %v: %v", libname, paths32) | ||||||
|  | |||||||
							
								
								
									
										53
									
								
								internal/lookup/merge.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								internal/lookup/merge.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | |||||||
|  | /** | ||||||
|  | # Copyright 2023 NVIDIA CORPORATION | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | **/ | ||||||
|  | 
 | ||||||
|  | package lookup | ||||||
|  | 
 | ||||||
|  | import "errors" | ||||||
|  | 
 | ||||||
|  | type first []Locator | ||||||
|  | 
 | ||||||
|  | // First returns a locator that returns the first non-empty match
 | ||||||
|  | func First(locators ...Locator) Locator { | ||||||
|  | 	var f first | ||||||
|  | 	for _, l := range locators { | ||||||
|  | 		if l == nil { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		f = append(f, l) | ||||||
|  | 	} | ||||||
|  | 	return f | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Locate returns the results for the first locator that returns a non-empty non-error result.
 | ||||||
|  | func (f first) Locate(pattern string) ([]string, error) { | ||||||
|  | 	var allErrors []error | ||||||
|  | 	for _, l := range f { | ||||||
|  | 		if l == nil { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		candidates, err := l.Locate(pattern) | ||||||
|  | 		if err != nil { | ||||||
|  | 			allErrors = append(allErrors, err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if len(candidates) > 0 { | ||||||
|  | 			return candidates, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil, errors.Join(allErrors...) | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user