nvidia-container-toolkit/tools/container/toolkit/installer/installer.go
Evan Lezar 2ce23c9af3 [no-relnote] Refactor toolkit installation
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2024-10-22 14:05:03 +02:00

154 lines
3.6 KiB
Go

/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package installer
import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
)
// Installer is implemented by anything that can place its files into a
// destination directory.
type Installer interface {
	// Install performs the installation into destDir and reports any failure.
	Install(destDir string) error
}
// New builds a toolkit Installer from the supplied options.
//
// Options are applied in order. If none of them set an artifact root, one is
// derived: the package type is resolved from the configured host root and
// requested package type, and an artifact root is created for it. Any error
// from that derivation is returned unchanged together with a nil Installer.
func New(opts ...Option) (Installer, error) {
	ti := &toolkitInstaller{}
	for _, apply := range opts {
		apply(ti)
	}

	// An explicitly configured artifact root takes precedence.
	if ti.artifactRoot != nil {
		return ti, nil
	}

	pkgType, err := resolvePackageType(ti.hostRoot, ti.packageType)
	if err != nil {
		return nil, err
	}
	root, err := newArtifactRoot(pkgType)
	if err != nil {
		return nil, err
	}
	ti.artifactRoot = root

	return ti, nil
}
// toolkitInstaller is the Installer returned by New; its Install method places
// the required toolkit files into a destination directory.
type toolkitInstaller struct {
	// artifactRoot describes where the installable artifacts come from. If it
	// is still nil after the options are applied, New derives it from the
	// resolved package type.
	artifactRoot *artifactRoot

	// hostRoot and packageType are the inputs New hands to resolvePackageType
	// when it has to derive artifactRoot.
	hostRoot, packageType string

	// ignoreErrors is not read by any code in this section of the file.
	// NOTE(review): presumably it makes installation tolerate failures —
	// confirm against collectLibraries/collectExecutables.
	ignoreErrors bool
}
// Install ensures that the required toolkit files are installed in the
// specified directory.
//
// The process is as follows:
//  1. Collect the installers for the toolkit libraries.
//  2. Collect the installers for the toolkit executables (collectExecutables
//     is given destDir).
//  3. Run every collected installer, libraries first, against destDir.
//
// A failure while collecting aborts immediately with a wrapped error. A
// failure of an individual installer does not stop the remaining ones; all
// such errors are combined with errors.Join and returned together, so the
// result is nil only if every installer succeeded.
func (t *toolkitInstaller) Install(destDir string) error {
	libraries, err := t.collectLibraries()
	if err != nil {
		return fmt.Errorf("failed to collect libraries: %w", err)
	}

	executables, err := t.collectExecutables(destDir)
	if err != nil {
		return fmt.Errorf("failed to collect executables: %w", err)
	}

	installers := make([]Installer, 0, len(libraries)+len(executables))
	installers = append(installers, libraries...)
	installers = append(installers, executables...)

	// Keep going after a failure so a single run reports every problem.
	var errs error
	for _, i := range installers {
		errs = errors.Join(errs, i.Install(destDir))
	}

	return errs
}
// symlink is an Installer that creates a single symbolic link inside the
// destination directory.
type symlink struct {
	linkname string // path of the link, relative to the destination directory
	target   string // value stored in the link (what it points to)
}

// Install creates the link destDir/linkname pointing at s.target, delegating
// to installSymlink.
func (s symlink) Install(destDir string) error {
	return installSymlink(s.target, filepath.Join(destDir, s.linkname))
}
// fileInstaller groups the low-level file operations used by this package.
// Its methods have the same signatures as the package-level installContent,
// installFile and installSymlink function variables. The go:generate
// directive below produces a moq mock of it in file-installer_mock.go.
// NOTE(review): presumably the mock lets tests stand in for those variables —
// confirm against the tests.
//
//go:generate moq -rm -out file-installer_mock.go . fileInstaller
type fileInstaller interface {
	// installContent mirrors installContentStub: write the reader's data to
	// the destination path and apply the given mode.
	installContent(io.Reader, string, os.FileMode) error
	// installFile mirrors installFileStub: copy src to dest and return the
	// mode that was applied.
	installFile(string, string) (os.FileMode, error)
	// installSymlink mirrors installSymlinkStub: create a link (second
	// argument) pointing at a target (first argument).
	installSymlink(string, string) error
}
// installSymlink is the function used to create symlinks. NOTE(review): it is
// presumably a variable so tests can substitute it.
var installSymlink = installSymlinkStub

// installSymlinkStub creates a symbolic link at link that points to target.
//
// os.Symlink does not replace an existing entry, so this fails if link
// already exists. The underlying error is wrapped (%w) rather than just
// formatted, so callers can detect that case with errors.Is(err, fs.ErrExist).
func installSymlinkStub(target string, link string) error {
	if err := os.Symlink(target, link); err != nil {
		return fmt.Errorf("error creating symlink '%v' => '%v': %w", link, target, err)
	}
	return nil
}
// installFile is the function used to copy a single file. NOTE(review): it is
// presumably a variable so tests can substitute it.
var installFile = installFileStub

// installFileStub copies the file at src to dest via installContent and
// applies the source's mode (as reported by os.Stat, which follows symlinks)
// to dest. On success it returns that mode.
//
// Errors from os.Stat and os.Open are wrapped (%w), so callers can inspect
// them with errors.Is, e.g. for fs.ErrNotExist. Errors from installContent are
// returned unchanged.
func installFileStub(src string, dest string) (os.FileMode, error) {
	sourceInfo, err := os.Stat(src)
	if err != nil {
		return 0, fmt.Errorf("error getting file info for '%v': %w", src, err)
	}

	source, err := os.Open(src)
	if err != nil {
		return 0, fmt.Errorf("error opening source: %w", err)
	}
	// The source is only read, so its Close error carries no data-loss risk.
	defer source.Close()

	mode := sourceInfo.Mode()
	if err := installContent(source, dest, mode); err != nil {
		return 0, err
	}
	return mode, nil
}
// installContent is the function used to write content to a destination
// path. NOTE(review): it is presumably a variable so tests can substitute it.
var installContent = installContentStub

// installContentStub writes everything read from content to dest and then
// sets dest's mode to mode. dest is created if missing and truncated if it
// already exists.
//
// The error from closing dest is reported (unless an earlier error is already
// being returned): for a file that was written to, a failing Close can mean
// the data never reached the file. Underlying errors are wrapped with %w.
func installContentStub(content io.Reader, dest string, mode fs.FileMode) (rerr error) {
	destination, err := os.Create(dest)
	if err != nil {
		return fmt.Errorf("error creating destination: %w", err)
	}
	defer func() {
		if cerr := destination.Close(); cerr != nil && rerr == nil {
			rerr = fmt.Errorf("error closing destination: %w", cerr)
		}
	}()

	if _, err := io.Copy(destination, content); err != nil {
		return fmt.Errorf("error copying file: %w", err)
	}

	if err := os.Chmod(dest, mode); err != nil {
		return fmt.Errorf("error setting mode for '%v': %w", dest, err)
	}
	return nil
}