/** # Copyright 2024 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. **/ package discover import ( "fmt" "path/filepath" ) type additionalSymlinks struct { Discover version string nvidiaCDIHookPath string } // WithDriverDotSoSymlinks decorates the provided discoverer. // A hook is added that checks for specific driver symlinks that need to be created. func WithDriverDotSoSymlinks(mounts Discover, version string, nvidiaCDIHookPath string) Discover { if version == "" { version = "*.*" } return &additionalSymlinks{ Discover: mounts, nvidiaCDIHookPath: nvidiaCDIHookPath, version: version, } } // Hooks returns a hook to create the additional symlinks based on the mounts. func (d *additionalSymlinks) Hooks() ([]Hook, error) { mounts, err := d.Discover.Mounts() if err != nil { return nil, fmt.Errorf("failed to get library mounts: %v", err) } hooks, err := d.Discover.Hooks() if err != nil { return nil, fmt.Errorf("failed to get hooks: %v", err) } var links []string processedPaths := make(map[string]bool) processedLinks := make(map[string]bool) for _, mount := range mounts { if processedPaths[mount.Path] { continue } processedPaths[mount.Path] = true for _, link := range d.getLinksForMount(mount.Path) { if processedLinks[link] { continue } processedLinks[link] = true links = append(links, link) } } if len(links) == 0 { return hooks, nil } hook := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links).(Hook) return append(hooks, hook), nil } // getLinksForMount maps the path to created links if any. func (d additionalSymlinks) getLinksForMount(path string) []string { dir, filename := filepath.Split(path) switch { case d.isDriverLibrary("libcuda.so", filename): // XXX Many applications wrongly assume that libcuda.so exists (e.g. with dlopen). // create libcuda.so -> libcuda.so.1 symlink link := fmt.Sprintf("%s::%s", "libcuda.so.1", filepath.Join(dir, "libcuda.so")) return []string{link} case d.isDriverLibrary("libGLX_nvidia.so", filename): // XXX GLVND requires this symlink for indirect GLX support. // create libGLX_indirect.so.0 -> libGLX_nvidia.so.VERSION symlink link := fmt.Sprintf("%s::%s", filename, filepath.Join(dir, "libGLX_indirect.so.0")) return []string{link} case d.isDriverLibrary("libnvidia-opticalflow.so", filename): // XXX Fix missing symlink for libnvidia-opticalflow.so. // create libnvidia-opticalflow.so -> libnvidia-opticalflow.so.1 symlink link := fmt.Sprintf("%s::%s", "libnvidia-opticalflow.so.1", filepath.Join(dir, "libnvidia-opticalflow.so")) return []string{link} } return nil } // isDriverLibrary checks whether the specified filename is a specific driver library. func (d additionalSymlinks) isDriverLibrary(libraryName string, filename string) bool { pattern := libraryName + "." + d.version match, _ := filepath.Match(pattern, filename) return match }