mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-12-01 16:52:54 +00:00
109 lines
3.4 KiB
Go
109 lines
3.4 KiB
Go
|
/**
|
||
|
# Copyright 2024 NVIDIA CORPORATION
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
**/
|
||
|
|
||
|
package discover
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"path/filepath"
|
||
|
)
|
||
|
|
||
|
type additionalSymlinks struct {
|
||
|
Discover
|
||
|
version string
|
||
|
nvidiaCDIHookPath string
|
||
|
}
|
||
|
|
||
|
// WithDriverDotSoSymlinks decorates the provided discoverer.
|
||
|
// A hook is added that checks for specific driver symlinks that need to be created.
|
||
|
func WithDriverDotSoSymlinks(mounts Discover, version string, nvidiaCDIHookPath string) Discover {
|
||
|
if version == "" {
|
||
|
version = "*.*"
|
||
|
}
|
||
|
return &additionalSymlinks{
|
||
|
Discover: mounts,
|
||
|
nvidiaCDIHookPath: nvidiaCDIHookPath,
|
||
|
version: version,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Hooks returns the wrapped discoverer's hooks, extended with one hook that
// creates the additional driver symlinks computed from its mounts.
//
// Each distinct mount path is examined once and each distinct link is
// requested once, in first-seen order. When no links are needed, the wrapped
// discoverer's hooks are returned unchanged. Errors from the wrapped
// discoverer are wrapped with %w so callers can inspect them with errors.Is
// and errors.As.
func (d *additionalSymlinks) Hooks() ([]Hook, error) {
	mounts, err := d.Discover.Mounts()
	if err != nil {
		return nil, fmt.Errorf("failed to get library mounts: %w", err)
	}
	hooks, err := d.Discover.Hooks()
	if err != nil {
		return nil, fmt.Errorf("failed to get hooks: %w", err)
	}

	var links []string
	processedPaths := make(map[string]bool)
	processedLinks := make(map[string]bool)
	for _, mount := range mounts {
		if processedPaths[mount.Path] {
			continue
		}
		processedPaths[mount.Path] = true

		for _, link := range d.getLinksForMount(mount.Path) {
			if processedLinks[link] {
				continue
			}
			processedLinks[link] = true
			links = append(links, link)
		}
	}

	if len(links) == 0 {
		return hooks, nil
	}

	// CreateCreateSymlinkHook returns an interface value. Check its dynamic type
	// rather than asserting unconditionally, so an unexpected implementation
	// surfaces as an error instead of a panic.
	discoverer := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links)
	hook, ok := discoverer.(Hook)
	if !ok {
		return nil, fmt.Errorf("unexpected symlink hook type %T", discoverer)
	}
	return append(hooks, hook), nil
}
|
||
|
|
||
|
// getLinksForMount returns the symlinks to create for the mounted file at
// path, each encoded as "<target>::<link path>", where the link path lies in
// the same directory as the mounted file. It returns nil when the file is not
// one of the recognised versioned driver libraries.
func (d additionalSymlinks) getLinksForMount(path string) []string {
	dir, filename := filepath.Split(path)

	// linkTo builds a single-element spec for a symlink named linkName,
	// placed next to the mounted file and pointing at target.
	linkTo := func(target string, linkName string) []string {
		return []string{target + "::" + filepath.Join(dir, linkName)}
	}

	if d.isDriverLibrary("libcuda.so", filename) {
		// XXX Many applications wrongly assume that libcuda.so exists (e.g. with dlopen).
		// Create the libcuda.so -> libcuda.so.1 symlink.
		return linkTo("libcuda.so.1", "libcuda.so")
	}
	if d.isDriverLibrary("libGLX_nvidia.so", filename) {
		// XXX GLVND requires this symlink for indirect GLX support.
		// Create the libGLX_indirect.so.0 -> libGLX_nvidia.so.VERSION symlink.
		return linkTo(filename, "libGLX_indirect.so.0")
	}
	if d.isDriverLibrary("libnvidia-opticalflow.so", filename) {
		// XXX Fix missing symlink for libnvidia-opticalflow.so.
		// Create the libnvidia-opticalflow.so -> libnvidia-opticalflow.so.1 symlink.
		return linkTo("libnvidia-opticalflow.so.1", "libnvidia-opticalflow.so")
	}
	return nil
}
|
||
|
|
||
|
// isDriverLibrary reports whether filename is a versioned file of the given
// driver library, i.e. whether it matches the glob "<libraryName>.<version>"
// under filepath.Match rules.
func (d additionalSymlinks) isDriverLibrary(libraryName string, filename string) bool {
	glob := libraryName + "." + d.version
	matched, err := filepath.Match(glob, filename)
	if err != nil {
		// The only error filepath.Match reports is ErrBadPattern (e.g. an
		// unterminated '[' in the version), and it then also reports no
		// match; such a version is treated as matching nothing.
		return false
	}
	return matched
}
|