diff --git a/internal/lookup/cuda/cuda.go b/internal/lookup/cuda/cuda.go new file mode 100644 index 00000000..98485a5a --- /dev/null +++ b/internal/lookup/cuda/cuda.go @@ -0,0 +1,102 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package cuda + +import ( + "path/filepath" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/sirupsen/logrus" +) + +type cudaLocator struct { + logger *logrus.Logger + driverRoot string +} + +// Options is a function that configures a cudaLocator. +type Options func(*cudaLocator) + +// WithLogger is an option that configures the logger used by the locator. +func WithLogger(logger *logrus.Logger) Options { + return func(c *cudaLocator) { + c.logger = logger + } +} + +// WithDriverRoot is an option that configures the driver root used by the locator. +func WithDriverRoot(driverRoot string) Options { + return func(c *cudaLocator) { + c.driverRoot = driverRoot + } +} + +// New creates a new CUDA library locator. +func New(opts ...Options) lookup.Locator { + c := &cudaLocator{} + for _, opt := range opts { + opt(c) + } + + if c.logger == nil { + c.logger = logrus.StandardLogger() + } + if c.driverRoot == "" { + c.driverRoot = "/" + } + + return c +} + +// Locate returns the path to the libcuda.so.RMVERSION file. +// libcuda.so is prefixed to the specified pattern. +func (l *cudaLocator) Locate(pattern string) ([]string, error) { + ldcacheLocator, err := lookup.NewLibraryLocator( + l.logger, + l.driverRoot, + ) + if err != nil { + l.logger.Debugf("Failed to create LDCache locator: %v", err) + } + + fullPattern := "libcuda.so" + pattern + + candidates, err := ldcacheLocator.Locate("libcuda.so") + if err == nil { + for _, c := range candidates { + if match, err := filepath.Match(fullPattern, filepath.Base(c)); err != nil || !match { + l.logger.Debugf("Skipping non-matching candidate %v: %v", c, err) + continue + } + return []string{c}, nil + } + } + l.logger.Debugf("Could not locate %q in LDCache: Checking predefined library paths.", pattern) + + pathLocator := lookup.NewFileLocator( + lookup.WithLogger(l.logger), + lookup.WithRoot(l.driverRoot), + lookup.WithSearchPaths( + "/usr/lib64", + "/usr/lib/x86_64-linux-gnu", + "/usr/lib/aarch64-linux-gnu", + ), + lookup.WithCount(1), + ) + + return pathLocator.Locate(fullPattern) +} diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index f623ae30..408da55a 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -23,6 +23,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" "github.com/sirupsen/logrus" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) @@ -135,11 +136,10 @@ func NewDriverBinariesDiscoverer(logger *logrus.Logger, driverRoot string) disco func getVersionLibs(logger *logrus.Logger, driverRoot string, version string) ([]string, error) { logger.Infof("Using driver version %v", version) - l := cudaLocator{ - logger: logger, - driverRoot: driverRoot, - } - libCudaPaths, err := l.Locate("libcuda.so." + version) + libCudaPaths, err := cuda.New( + cuda.WithLogger(logger), + cuda.WithDriverRoot(driverRoot), + ).Locate("." + version) if err != nil { return nil, fmt.Errorf("failed to locate libcuda.so.%v: %v", version, err) } @@ -167,43 +167,3 @@ func getVersionLibs(logger *logrus.Logger, driverRoot string, version string) ([ return relative, nil } - -type cudaLocator struct { - logger *logrus.Logger - driverRoot string -} - -// Locate returns the path to the libcuda.so.RMVERSION file. -func (l *cudaLocator) Locate(pattern string) ([]string, error) { - ldcacheLocator, err := lookup.NewLibraryLocator( - l.logger, - l.driverRoot, - ) - if err != nil { - l.logger.Debugf("Failed to create LDCache locator: %v", err) - } - candidates, err := ldcacheLocator.Locate("libcuda.so") - if err == nil { - for _, c := range candidates { - if match, err := filepath.Match(pattern, filepath.Base(c)); err != nil || !match { - l.logger.Debugf("Skipping non-matching candidate %v: %v", c, err) - continue - } - return []string{c}, nil - } - } - l.logger.Debugf("Could not locate %q in LDCache: Checking predefined library paths.", pattern) - - pathLocator := lookup.NewFileLocator( - lookup.WithLogger(l.logger), - lookup.WithRoot(l.driverRoot), - lookup.WithSearchPaths( - "/usr/lib64", - "/usr/lib/x86_64-linux-gnu", - "/usr/lib/aarch64-linux-gnu", - ), - lookup.WithCount(1), - ) - - return pathLocator.Locate(pattern) -} diff --git a/pkg/nvcdi/management.go b/pkg/nvcdi/management.go index 305023ff..6643c559 100644 --- a/pkg/nvcdi/management.go +++ b/pkg/nvcdi/management.go @@ -23,6 +23,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/container-orchestrated-devices/container-device-interface/specs-go" @@ -84,12 +85,10 @@ func (m *managementlib) getCudaVersion() (string, error) { return version, nil } - l := cudaLocator{ - logger: m.logger, - driverRoot: m.driverRoot, - } - - libCudaPaths, err := l.Locate("libcuda.so.*.*.*") + libCudaPaths, err := cuda.New( + cuda.WithLogger(m.logger), + cuda.WithDriverRoot(m.driverRoot), + ).Locate(".*.*.*") if err != nil { return "", fmt.Errorf("failed to locate libcuda.so: %v", err) }