mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-29 16:02:10 +00:00
2ce23c9af3
Signed-off-by: Evan Lezar <elezar@nvidia.com>
154 lines
3.6 KiB
Go
154 lines
3.6 KiB
Go
/**
|
|
# Copyright 2024 NVIDIA CORPORATION
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
**/
|
|
|
|
package installer
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"os"
|
|
"path/filepath"
|
|
)
|
|
|
|
// Installer is implemented by anything that can install itself into a
// destination directory.
type Installer interface {
	// Install performs the installation into destDir and reports any failure.
	Install(destDir string) error
}
|
|
|
|
func New(opts ...Option) (Installer, error) {
|
|
t := &toolkitInstaller{}
|
|
for _, opt := range opts {
|
|
opt(t)
|
|
}
|
|
|
|
if t.artifactRoot == nil {
|
|
resolvedPackageType, err := resolvePackageType(t.hostRoot, t.packageType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
artifactRoot, err := newArtifactRoot(resolvedPackageType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.artifactRoot = artifactRoot
|
|
}
|
|
|
|
return t, nil
|
|
}
|
|
|
|
type toolkitInstaller struct {
|
|
artifactRoot *artifactRoot
|
|
|
|
hostRoot string
|
|
packageType string
|
|
|
|
ignoreErrors bool
|
|
}
|
|
|
|
// Install ensures that the required toolkit files are installed in the specified directory.
|
|
// The process is as follows:
|
|
func (t *toolkitInstaller) Install(destDir string) error {
|
|
var installers []Installer
|
|
|
|
libraries, err := t.collectLibraries()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to collect libraries: %w", err)
|
|
}
|
|
installers = append(installers, libraries...)
|
|
|
|
executables, err := t.collectExecutables(destDir)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to collect executables: %w", err)
|
|
}
|
|
installers = append(installers, executables...)
|
|
|
|
var errs error
|
|
for _, i := range installers {
|
|
errs = errors.Join(errs, i.Install(destDir))
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
// symlink describes a single symbolic link to be created below a destination
// directory.
type symlink struct {
	// linkname is the path of the link itself; it is joined onto the
	// destination directory at install time.
	linkname string
	// target is the path the link points to; it is used as given.
	target string
}
|
|
|
|
func (s symlink) Install(destDir string) error {
|
|
symlinkPath := filepath.Join(destDir, s.linkname)
|
|
return installSymlink(s.target, symlinkPath)
|
|
}
|
|
|
|
// fileInstaller groups the low-level installation primitives. Its methods
// have the same signatures as the package-level installContent, installFile
// and installSymlink functions; the go:generate directive below produces a
// moq mock for it.
//
//go:generate moq -rm -out file-installer_mock.go . fileInstaller
type fileInstaller interface {
	// installContent writes content to dest with the given mode.
	installContent(content io.Reader, dest string, mode os.FileMode) error
	// installFile installs src at dest and returns a file mode.
	installFile(src string, dest string) (os.FileMode, error)
	// installSymlink creates link pointing at target.
	installSymlink(target string, link string) error
}
|
|
|
|
// installSymlink is the function used to create symlinks; it defaults to
// installSymlinkStub and is what symlink.Install calls.
// NOTE(review): presumably a variable so that tests can substitute it — confirm.
var installSymlink = installSymlinkStub
|
|
|
|
// installSymlinkStub creates a symbolic link at link that points to target.
//
// target is passed to os.Symlink as given. An existing entry at link is not
// replaced; in that case, as for any other failure, the returned error
// describes the link and target and wraps the underlying os.Symlink error so
// that callers can inspect it with errors.Is / errors.As.
func installSymlinkStub(target string, link string) error {
	if err := os.Symlink(target, link); err != nil {
		// %w (not %v) keeps the cause in the error chain; the message text is unchanged.
		return fmt.Errorf("error creating symlink '%v' => '%v': %w", link, target, err)
	}
	return nil
}
|
|
|
|
// installFile is the function used to install a file from a source path to a
// destination path; it defaults to installFileStub.
// NOTE(review): presumably a variable so that tests can substitute it — confirm.
var installFile = installFileStub
|
|
|
|
func installFileStub(src string, dest string) (os.FileMode, error) {
|
|
sourceInfo, err := os.Stat(src)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("error getting file info for '%v': %v", src, err)
|
|
}
|
|
|
|
source, err := os.Open(src)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("error opening source: %w", err)
|
|
}
|
|
defer source.Close()
|
|
|
|
mode := sourceInfo.Mode()
|
|
if err := installContent(source, dest, mode); err != nil {
|
|
return 0, err
|
|
}
|
|
return mode, nil
|
|
}
|
|
|
|
// installContent is the function used to write content to a destination path
// with a given mode; it defaults to installContentStub and is what
// installFileStub calls.
// NOTE(review): presumably a variable so that tests can substitute it — confirm.
var installContent = installContentStub
|
|
|
|
// installContentStub writes everything read from content to the file at dest
// and then sets the mode of dest to mode.
//
// dest is created with os.Create, so an existing file is truncated. The mode
// is applied with an explicit chmod after the copy rather than at creation
// time. Every returned error wraps its underlying cause. A failure to close
// the destination is reported as well, unless an earlier error is already
// being returned, because a failed close can mean the written data did not
// reach the file.
func installContentStub(content io.Reader, dest string, mode fs.FileMode) (err error) {
	destination, err := os.Create(dest)
	if err != nil {
		return fmt.Errorf("error creating destination: %w", err)
	}
	defer func() {
		// Do not mask an earlier, more specific error with the close error.
		if closeErr := destination.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing destination: %w", closeErr)
		}
	}()

	if _, err = io.Copy(destination, content); err != nil {
		return fmt.Errorf("error copying file: %w", err)
	}

	if err = os.Chmod(dest, mode); err != nil {
		// %w (not %v) keeps the cause in the error chain; the message text is unchanged.
		return fmt.Errorf("error setting mode for '%v': %w", dest, err)
	}
	return nil
}
|