diff --git a/cmd/nvidia-ctk/hook/create-symlinks/create-symlinks.go b/cmd/nvidia-ctk/hook/create-symlinks/create-symlinks.go index 8fe92e32..39ade7f0 100644 --- a/cmd/nvidia-ctk/hook/create-symlinks/create-symlinks.go +++ b/cmd/nvidia-ctk/hook/create-symlinks/create-symlinks.go @@ -24,6 +24,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" @@ -125,21 +126,18 @@ func (m command) run(c *cli.Context, cfg *config) error { created := make(map[string]bool) // candidates is a list of absolute paths to symlinks in a chain, or the final target of the chain. for _, candidate := range candidates { - targets, err := m.Locate(candidate) + target, err := symlinks.Resolve(candidate) if err != nil { m.logger.Debugf("Skipping invalid link: %v", err) continue - } else if len(targets) != 1 { - m.logger.Debugf("Unexepected number of targets: %v", targets) - continue - } else if targets[0] == candidate { + } else if target == candidate { m.logger.Debugf("%v is not a symlink", candidate) continue } - err = m.createLink(created, cfg.hostRoot, containerRoot, targets[0], candidate) + err = m.createLink(created, cfg.hostRoot, containerRoot, target, candidate) if err != nil { - m.logger.Warnf("Failed to create link %v: %v", []string{targets[0], candidate}, err) + m.logger.Warnf("Failed to create link %v: %v", []string{target, candidate}, err) } } diff --git a/internal/discover/symlinks.go b/internal/discover/symlinks.go index 31e8e64e..a9e15d8e 100644 --- a/internal/discover/symlinks.go +++ b/internal/discover/symlinks.go @@ -21,12 +21,16 @@ import ( "path/filepath" "strings" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" "github.com/sirupsen/logrus" ) -type symlinks struct { +type symlinkHook struct { None logger *logrus.Logger + driverRoot string nvidiaCTKPath string csvFiles []string mountsFrom Discover @@ -34,7 +38,7 @@ type symlinks struct { // NewCreateSymlinksHook creates a discoverer for a hook that creates required symlinks in the container func NewCreateSymlinksHook(logger *logrus.Logger, csvFiles []string, mounts Discover, nvidiaCTKPath string) (Discover, error) { - d := symlinks{ + d := symlinkHook{ logger: logger, nvidiaCTKPath: nvidiaCTKPath, csvFiles: csvFiles, @@ -45,17 +49,17 @@ func NewCreateSymlinksHook(logger *logrus.Logger, csvFiles []string, mounts Disc } // Hooks returns a hook to create the symlinks from the required CSV files -func (d symlinks) Hooks() ([]Hook, error) { - var args []string - for _, f := range d.csvFiles { - args = append(args, "--csv-filename", f) - } - - links, err := d.getSpecificLinkArgs() +func (d symlinkHook) Hooks() ([]Hook, error) { + specificLinks, err := d.getSpecificLinks() if err != nil { return nil, fmt.Errorf("failed to determine specific links: %v", err) } - args = append(args, links...) + + csvSymlinks := d.getCSVFileSymlinks() + var args []string + for _, link := range append(csvSymlinks, specificLinks...) { + args = append(args, "--link", link) + } hook := CreateNvidiaCTKHook( d.nvidiaCTKPath, @@ -66,8 +70,8 @@ func (d symlinks) Hooks() ([]Hook, error) { return []Hook{hook}, nil } -// getSpecificLinkArgs returns the required specic links that need to be created -func (d symlinks) getSpecificLinkArgs() ([]string, error) { +// getSpecificLinks returns the required specic links that need to be created +func (d symlinkHook) getSpecificLinks() ([]string, error) { mounts, err := d.mountsFrom.Mounts() if err != nil { return nil, fmt.Errorf("failed to discover mounts for ldcache update: %v", err) @@ -99,11 +103,60 @@ func (d symlinks) getSpecificLinkArgs() ([]string, error) { if linkProcessed[link] { continue } + linkProcessed[link] = true linkPath := filepath.Join(filepath.Dir(m.Path), link) - links = append(links, "--link", fmt.Sprintf("%v::%v", target, linkPath)) - linkProcessed[link] = true + links = append(links, fmt.Sprintf("%v::%v", target, linkPath)) } return links, nil } + +func (d symlinkHook) getCSVFileSymlinks() []string { + chainLocator := lookup.NewSymlinkChainLocator(d.logger, d.driverRoot) + + var candidates []string + for _, file := range d.csvFiles { + mountSpecs, err := csv.NewCSVFileParser(d.logger, file).Parse() + if err != nil { + d.logger.Debugf("Skipping CSV file %v: %v", file, err) + continue + } + + for _, ms := range mountSpecs { + if ms.Type != csv.MountSpecSym { + continue + } + targets, err := chainLocator.Locate(ms.Path) + if err != nil { + d.logger.Warnf("Failed to locate symlink %v", ms.Path) + } + candidates = append(candidates, targets...) + } + } + + var links []string + created := make(map[string]bool) + // candidates is a list of absolute paths to symlinks in a chain, or the final target of the chain. + for _, candidate := range candidates { + target, err := symlinks.Resolve(candidate) + if err != nil { + d.logger.Debugf("Skipping invalid link: %v", err) + continue + } else if target == candidate { + d.logger.Debugf("%v is not a symlink", candidate) + continue + } + + link := fmt.Sprintf("%v::%v", target, candidate) + if created[link] { + d.logger.Debugf("skipping duplicate link: %v", link) + continue + } + created[link] = true + + links = append(links, link) + } + + return links +} diff --git a/internal/ldcache/ldcache.go b/internal/ldcache/ldcache.go index 5493dc31..1fe7d074 100644 --- a/internal/ldcache/ldcache.go +++ b/internal/ldcache/ldcache.go @@ -29,6 +29,7 @@ import ( "syscall" "unsafe" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" log "github.com/sirupsen/logrus" ) @@ -288,16 +289,7 @@ func (c *ldcache) resolve(target string) (string, error) { c.logger.Debugf("checking %v", string(name)) - info, err := os.Lstat(name) - if err != nil { - return "", fmt.Errorf("failed to get file info: %v", info) - } - if info.Mode()&os.ModeSymlink == 0 { - c.logger.Debugf("Resolved regular file: %v", name) - return name, nil - } - - link, err := os.Readlink(name) + link, err := symlinks.Resolve(name) if err != nil { return "", fmt.Errorf("failed to resolve symlink: %v", err) } diff --git a/internal/lookup/symlinks.go b/internal/lookup/symlinks.go index 0d82a9e1..438e42d5 100644 --- a/internal/lookup/symlinks.go +++ b/internal/lookup/symlinks.go @@ -18,9 +18,9 @@ package lookup import ( "fmt" - "os" "path/filepath" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" "github.com/sirupsen/logrus" ) @@ -74,16 +74,9 @@ func (p symlinkChain) Locate(pattern string) ([]string, error) { } found[candidate] = true - info, err := os.Lstat(candidate) + target, err := symlinks.Resolve(candidate) if err != nil { - return nil, fmt.Errorf("failed to get file info: %v", info) - } - if info.Mode()&os.ModeSymlink == 0 { - continue - } - target, err := os.Readlink(candidate) - if err != nil { - return nil, fmt.Errorf("error checking symlink: %v", err) + return nil, fmt.Errorf("error resolving symlink: %v", err) } if !filepath.IsAbs(target) { diff --git a/internal/lookup/symlinks/symlink.go b/internal/lookup/symlinks/symlink.go new file mode 100644 index 00000000..991d47cb --- /dev/null +++ b/internal/lookup/symlinks/symlink.go @@ -0,0 +1,35 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package symlinks + +import ( + "fmt" + "os" +) + +// Resolve returns the link target of the specified filename or the filename if it is not a link. +func Resolve(filename string) (string, error) { + info, err := os.Lstat(filename) + if err != nil { + return filename, fmt.Errorf("failed to get file info: %v", info) + } + if info.Mode()&os.ModeSymlink == 0 { + return filename, nil + } + + return os.Readlink(filename) +} diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index aa97f728..84eb0c6d 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -81,6 +81,9 @@ func NewCSVModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec) tegra.WithNVIDIACTKPath(cfg.NVIDIACTKConfig.Path), tegra.WithCSVFiles(csvFiles), ) + if err != nil { + return nil, fmt.Errorf("failed to construct discoverer: %v", err) + } discoverModifier, err := NewModifierFromDiscoverer(logger, d) if err != nil {