This commit is contained in:
Evan Lezar 2024-11-05 02:04:54 -08:00 committed by GitHub
commit 482331e551
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 122 additions and 109 deletions

View File

@ -30,11 +30,11 @@ func getDefaultHookConfig() (HookConfig, error) {
} }
// loadConfig loads the required paths for the hook config. // loadConfig loads the required paths for the hook config.
func loadConfig() (*config.Config, error) { func (a *app) loadConfig() (*config.Config, error) {
var configPaths []string var configPaths []string
var required bool var required bool
if len(*configflag) != 0 { if len(a.configFile) != 0 {
configPaths = append(configPaths, *configflag) configPaths = append(configPaths, a.configFile)
required = true required = true
} else { } else {
configPaths = append(configPaths, path.Join(driverPath, configPath), configPath) configPaths = append(configPaths, path.Join(driverPath, configPath), configPath)
@ -56,8 +56,8 @@ func loadConfig() (*config.Config, error) {
return config.GetDefault() return config.GetDefault()
} }
func getHookConfig() (*HookConfig, error) { func (a *app) getHookConfig() (*HookConfig, error) {
cfg, err := loadConfig() cfg, err := a.loadConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load config: %v", err) return nil, fmt.Errorf("failed to load config: %v", err)
} }

View File

@ -72,16 +72,17 @@ func TestGetHookConfig(t *testing.T) {
if len(filename) > 0 { if len(filename) > 0 {
os.Remove(filename) os.Remove(filename)
} }
configflag = nil
}() }()
a := &app{}
if tc.lines != nil { if tc.lines != nil {
configFile, err := os.CreateTemp("", "*.toml") configFile, err := os.CreateTemp("", "*.toml")
require.NoError(t, err) require.NoError(t, err)
defer configFile.Close() defer configFile.Close()
filename = configFile.Name() filename = configFile.Name()
configflag = &filename a.configFile = filename
for _, line := range tc.lines { for _, line := range tc.lines {
_, err := configFile.WriteString(fmt.Sprintf("%s\n", line)) _, err := configFile.WriteString(fmt.Sprintf("%s\n", line))
@ -91,7 +92,7 @@ func TestGetHookConfig(t *testing.T) {
var config HookConfig var config HookConfig
getHookConfig := func() { getHookConfig := func() {
c, _ := getHookConfig() c, _ := a.getHookConfig()
config = *c config = *c
} }

View File

@ -1,7 +1,7 @@
package main package main
import ( import (
"flag" "errors"
"fmt" "fmt"
"log" "log"
"os" "os"
@ -13,29 +13,26 @@ import (
"strings" "strings"
"syscall" "syscall"
cli "github.com/urfave/cli/v2"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
) )
var ( func (a *app) recoverIfRequired() error {
debugflag = flag.Bool("debug", false, "enable debug output")
versionflag = flag.Bool("version", false, "enable version output")
configflag = flag.String("config", "", "configuration file")
)
func exit() {
if err := recover(); err != nil { if err := recover(); err != nil {
if _, ok := err.(runtime.Error); ok { rerr, ok := err.(runtime.Error)
if ok {
log.Println(err) log.Println(err)
} }
if *debugflag { if a.isDebug {
log.Printf("%s", debug.Stack()) log.Printf("%s", debug.Stack())
} }
os.Exit(1) return rerr
} }
os.Exit(0) return nil
} }
func getCLIPath(config config.ContainerCLIConfig) string { func getCLIPath(config config.ContainerCLIConfig) string {
@ -63,15 +60,15 @@ func getRootfsPath(config containerConfig) string {
return rootfs return rootfs
} }
func doPrestart() { func (a *app) doPrestart() (rerr error) {
var err error defer func() {
rerr = errors.Join(rerr, a.recoverIfRequired())
defer exit() }()
log.SetFlags(0) log.SetFlags(0)
hook, err := getHookConfig() hook, err := a.getHookConfig()
if err != nil || hook == nil { if err != nil || hook == nil {
log.Panicln("error getting hook config:", err) return fmt.Errorf("error getting hook config: %w", err)
} }
cli := hook.NVIDIAContainerCLIConfig cli := hook.NVIDIAContainerCLIConfig
@ -79,11 +76,11 @@ func doPrestart() {
nvidia := container.Nvidia nvidia := container.Nvidia
if nvidia == nil { if nvidia == nil {
// Not a GPU container, nothing to do. // Not a GPU container, nothing to do.
return return nil
} }
if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" { if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" {
log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.") return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead")
} }
rootfs := getRootfsPath(container) rootfs := getRootfsPath(container)
@ -101,7 +98,7 @@ func doPrestart() {
if cli.NoPivot { if cli.NoPivot {
args = append(args, "--no-pivot") args = append(args, "--no-pivot")
} }
if *debugflag { if a.isDebug {
args = append(args, "--debug=/dev/stderr") args = append(args, "--debug=/dev/stderr")
} else if cli.Debug != "" { } else if cli.Debug != "" {
args = append(args, fmt.Sprintf("--debug=%s", cli.Debug)) args = append(args, fmt.Sprintf("--debug=%s", cli.Debug))
@ -149,45 +146,61 @@ func doPrestart() {
env := append(os.Environ(), cli.Environment...) env := append(os.Environ(), cli.Environment...)
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection? //nolint:gosec // TODO: Can we harden this so that there is less risk of command injection?
err = syscall.Exec(args[0], args, env) return syscall.Exec(args[0], args, env)
log.Panicln("exec failed:", err)
} }
func usage() { type options struct {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) isDebug bool
flag.PrintDefaults() configFile string
fmt.Fprintf(os.Stderr, "\nCommands:\n") }
fmt.Fprintf(os.Stderr, " prestart\n run the prestart hook\n") type app struct {
fmt.Fprintf(os.Stderr, " poststart\n no-op\n") options
fmt.Fprintf(os.Stderr, " poststop\n no-op\n")
} }
func main() { func main() {
flag.Usage = usage a := &app{}
flag.Parse() // Create the top-level CLI
c := cli.NewApp()
c.Name = "NVIDIA Container Runtime Hook"
c.Version = info.GetVersionString()
if *versionflag { c.Flags = []cli.Flag{
fmt.Printf("%v version %v\n", "NVIDIA Container Runtime Hook", info.GetVersionString()) &cli.BoolFlag{
return Name: "debug",
Destination: &a.isDebug,
Usage: "Enabled debug output",
},
&cli.StringFlag{
Name: "config",
Destination: &a.configFile,
Usage: "The path to the configuration file to use",
EnvVars: []string{config.FilePathOverrideEnvVar},
},
} }
args := flag.Args() c.Commands = []*cli.Command{
if len(args) == 0 { {
flag.Usage() Name: "prestart",
os.Exit(2) Usage: "run the prestart hook",
Action: func(ctx *cli.Context) error {
return a.doPrestart()
},
},
{
Name: "poststart",
Aliases: []string{"poststop"},
Usage: "no-op",
Action: func(ctx *cli.Context) error {
return nil
},
},
} }
c.DefaultCommand = "prestart"
switch args[0] { // Run the CLI
case "prestart": err := c.Run(os.Args)
doPrestart() if err != nil {
os.Exit(0) os.Exit(1)
case "poststart":
fallthrough
case "poststop":
os.Exit(0)
default:
flag.Usage()
os.Exit(2)
} }
} }

View File

@ -30,8 +30,10 @@ import (
) )
const ( const (
configOverride = "XDG_CONFIG_HOME" FilePathOverrideEnvVar = "NVCTK_CONFIG_FILE_PATH"
configFilePath = "nvidia-container-runtime/config.toml" RelativeFilePath = "nvidia-container-runtime/config.toml"
configRootOverride = "XDG_CONFIG_HOME"
nvidiaCTKExecutable = "nvidia-ctk" nvidiaCTKExecutable = "nvidia-ctk"
nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk" nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk"
@ -71,11 +73,15 @@ type Config struct {
// GetConfigFilePath returns the path to the config file for the configured system // GetConfigFilePath returns the path to the config file for the configured system
func GetConfigFilePath() string { func GetConfigFilePath() string {
if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 { if configFilePathOverride := os.Getenv(FilePathOverrideEnvVar); configFilePathOverride != "" {
return filepath.Join(XDGConfigDir, configFilePath) return configFilePathOverride
}
configRoot := "/etc"
if XDGConfigDir := os.Getenv(configRootOverride); len(XDGConfigDir) != 0 {
configRoot = XDGConfigDir
} }
return filepath.Join("/etc", configFilePath) return filepath.Join(configRoot, RelativeFilePath)
} }
// GetConfig sets up the config struct. Values are read from a toml file // GetConfig sets up the config struct. Values are read from a toml file

View File

@ -27,9 +27,26 @@ import (
func TestGetConfigWithCustomConfig(t *testing.T) { func TestGetConfigWithCustomConfig(t *testing.T) {
testDir := t.TempDir() testDir := t.TempDir()
t.Setenv(configOverride, testDir) t.Setenv(configRootOverride, testDir)
filename := filepath.Join(testDir, configFilePath) filename := filepath.Join(testDir, RelativeFilePath)
// By default debug is disabled
contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"")
require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766))
require.NoError(t, os.WriteFile(filename, contents, 0600))
cfg, err := GetConfig()
require.NoError(t, err)
require.Equal(t, "/nvidia-container-toolkit.log", cfg.NVIDIAContainerRuntimeConfig.DebugFilePath)
}
func TestGetConfigWithConfigFilePathOverride(t *testing.T) {
testDir := t.TempDir()
filename := filepath.Join(testDir, RelativeFilePath)
t.Setenv(FilePathOverrideEnvVar, filename)
// By default debug is disabled // By default debug is disabled
contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"") contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"")

View File

@ -37,7 +37,6 @@ type executable struct {
target executableTarget target executableTarget
env map[string]string env map[string]string
preLines []string preLines []string
argLines []string
} }
// install installs an executable component of the NVIDIA container toolkit. The source executable // install installs an executable component of the NVIDIA container toolkit. The source executable
@ -128,10 +127,6 @@ func (e executable) writeWrapperTo(wrapper io.Writer, destFolder string, dotfile
// Add the call to the target executable // Add the call to the target executable
fmt.Fprintf(wrapper, "%s \\\n", dotfileName) fmt.Fprintf(wrapper, "%s \\\n", dotfileName)
// Insert additional lines in the `arg` list
for _, line := range e.argLines {
fmt.Fprintf(wrapper, "\t%s \\\n", r.apply(line))
}
// Add the script arguments "$@" // Add the script arguments "$@"
fmt.Fprintln(wrapper, "\t\"$@\"") fmt.Fprintln(wrapper, "\t\"$@\"")

View File

@ -76,23 +76,6 @@ func TestWrapper(t *testing.T) {
"", "",
}, },
}, },
{
e: executable{
argLines: []string{
"argline1",
"argline2",
},
},
expectedLines: []string{
shebang,
"PATH=/dest/folder:$PATH \\",
"source.real \\",
"\targline1 \\",
"\targline2 \\",
"\t\"$@\"",
"",
},
},
} }
for i, tc := range testCases { for i, tc := range testCases {

View File

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator" "github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator"
) )
@ -29,10 +30,10 @@ const (
// installContainerRuntimes sets up the NVIDIA container runtimes, copying the executables // installContainerRuntimes sets up the NVIDIA container runtimes, copying the executables
// and implementing the required wrapper // and implementing the required wrapper
func installContainerRuntimes(toolkitDir string, driverRoot string) error { func installContainerRuntimes(toolkitDir string, configFilePath string) error {
runtimes := operator.GetRuntimes() runtimes := operator.GetRuntimes()
for _, runtime := range runtimes { for _, runtime := range runtimes {
r := newNvidiaContainerRuntimeInstaller(runtime.Path) r := newNvidiaContainerRuntimeInstaller(runtime.Path, configFilePath)
_, err := r.install(toolkitDir) _, err := r.install(toolkitDir)
if err != nil { if err != nil {
@ -46,17 +47,17 @@ func installContainerRuntimes(toolkitDir string, driverRoot string) error {
// This installer will copy the specified source executable to the toolkit directory. // This installer will copy the specified source executable to the toolkit directory.
// The executable is copied to a file with the same name as the source, but with a ".real" suffix and a wrapper is // The executable is copied to a file with the same name as the source, but with a ".real" suffix and a wrapper is
// created to allow for the configuration of the runtime environment. // created to allow for the configuration of the runtime environment.
func newNvidiaContainerRuntimeInstaller(source string) *executable { func newNvidiaContainerRuntimeInstaller(source string, configFilePath string) *executable {
wrapperName := filepath.Base(source) wrapperName := filepath.Base(source)
dotfileName := wrapperName + ".real" dotfileName := wrapperName + ".real"
target := executableTarget{ target := executableTarget{
dotfileName: dotfileName, dotfileName: dotfileName,
wrapperName: wrapperName, wrapperName: wrapperName,
} }
return newRuntimeInstaller(source, target, nil) return newRuntimeInstaller(source, target, configFilePath, nil)
} }
func newRuntimeInstaller(source string, target executableTarget, env map[string]string) *executable { func newRuntimeInstaller(source string, target executableTarget, configFilePath string, env map[string]string) *executable {
preLines := []string{ preLines := []string{
"", "",
"cat /proc/modules | grep -e \"^nvidia \" >/dev/null 2>&1", "cat /proc/modules | grep -e \"^nvidia \" >/dev/null 2>&1",
@ -68,7 +69,7 @@ func newRuntimeInstaller(source string, target executableTarget, env map[string]
} }
runtimeEnv := make(map[string]string) runtimeEnv := make(map[string]string)
runtimeEnv["XDG_CONFIG_HOME"] = filepath.Join(destDirPattern, ".config") runtimeEnv[config.FilePathOverrideEnvVar] = configFilePath
for k, v := range env { for k, v := range env {
runtimeEnv[k] = v runtimeEnv[k] = v
} }

View File

@ -25,7 +25,7 @@ import (
) )
func TestNvidiaContainerRuntimeInstallerWrapper(t *testing.T) { func TestNvidiaContainerRuntimeInstallerWrapper(t *testing.T) {
r := newNvidiaContainerRuntimeInstaller(nvidiaContainerRuntimeSource) r := newNvidiaContainerRuntimeInstaller(nvidiaContainerRuntimeSource, "/config/file/path/config.toml")
const shebang = "#! /bin/sh" const shebang = "#! /bin/sh"
const destFolder = "/dest/folder" const destFolder = "/dest/folder"
@ -45,8 +45,8 @@ func TestNvidiaContainerRuntimeInstallerWrapper(t *testing.T) {
" exec runc \"$@\"", " exec runc \"$@\"",
"fi", "fi",
"", "",
"NVCTK_CONFIG_FILE_PATH=/config/file/path/config.toml \\",
"PATH=/dest/folder:$PATH \\", "PATH=/dest/folder:$PATH \\",
"XDG_CONFIG_HOME=/dest/folder/.config \\",
"source.real \\", "source.real \\",
"\t\"$@\"", "\t\"$@\"",
"", "",

View File

@ -297,10 +297,9 @@ func Install(cli *cli.Context, opts *Options, toolkitRoot string) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error removing toolkit directory: %v", err)) log.Errorf("Ignoring error: %v", fmt.Errorf("error removing toolkit directory: %v", err))
} }
toolkitConfigDir := filepath.Join(toolkitRoot, ".config", "nvidia-container-runtime") toolkitConfigFilePath := filepath.Join(toolkitRoot, ".config", config.RelativeFilePath)
toolkitConfigPath := filepath.Join(toolkitConfigDir, configFilename)
err = createDirectories(toolkitRoot, toolkitConfigDir) err = createDirectories(toolkitRoot, filepath.Dir(toolkitConfigFilePath))
if err != nil && !opts.ignoreErrors { if err != nil && !opts.ignoreErrors {
return fmt.Errorf("could not create required directories: %v", err) return fmt.Errorf("could not create required directories: %v", err)
} else if err != nil { } else if err != nil {
@ -314,7 +313,7 @@ func Install(cli *cli.Context, opts *Options, toolkitRoot string) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA container library: %v", err)) log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA container library: %v", err))
} }
err = installContainerRuntimes(toolkitRoot, opts.DriverRoot) err = installContainerRuntimes(toolkitRoot, toolkitConfigFilePath)
if err != nil && !opts.ignoreErrors { if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error installing NVIDIA container runtime: %v", err) return fmt.Errorf("error installing NVIDIA container runtime: %v", err)
} else if err != nil { } else if err != nil {
@ -328,7 +327,7 @@ func Install(cli *cli.Context, opts *Options, toolkitRoot string) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA container CLI: %v", err)) log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA container CLI: %v", err))
} }
nvidiaContainerRuntimeHookPath, err := installRuntimeHook(toolkitRoot, toolkitConfigPath) nvidiaContainerRuntimeHookPath, err := installRuntimeHook(toolkitRoot, toolkitConfigFilePath)
if err != nil && !opts.ignoreErrors { if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error installing NVIDIA container runtime hook: %v", err) return fmt.Errorf("error installing NVIDIA container runtime hook: %v", err)
} else if err != nil { } else if err != nil {
@ -349,7 +348,7 @@ func Install(cli *cli.Context, opts *Options, toolkitRoot string) error {
log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA Container CDI Hook CLI: %v", err)) log.Errorf("Ignoring error: %v", fmt.Errorf("error installing NVIDIA Container CDI Hook CLI: %v", err))
} }
err = installToolkitConfig(cli, toolkitConfigPath, nvidiaContainerCliExecutable, nvidiaCTKPath, nvidiaContainerRuntimeHookPath, opts) err = installToolkitConfig(cli, toolkitConfigFilePath, nvidiaContainerCliExecutable, nvidiaCTKPath, nvidiaContainerRuntimeHookPath, opts)
if err != nil && !opts.ignoreErrors { if err != nil && !opts.ignoreErrors {
return fmt.Errorf("error installing NVIDIA container toolkit config: %v", err) return fmt.Errorf("error installing NVIDIA container toolkit config: %v", err)
} else if err != nil { } else if err != nil {
@ -423,8 +422,8 @@ func installLibrary(libName string, toolkitRoot string) error {
// installToolkitConfig installs the config file for the NVIDIA container toolkit ensuring // installToolkitConfig installs the config file for the NVIDIA container toolkit ensuring
// that the settings are updated to match the desired install and nvidia driver directories. // that the settings are updated to match the desired install and nvidia driver directories.
func installToolkitConfig(c *cli.Context, toolkitConfigPath string, nvidiaContainerCliExecutablePath string, nvidiaCTKPath string, nvidaContainerRuntimeHookPath string, opts *Options) error { func installToolkitConfig(c *cli.Context, toolkitConfigFilePath string, nvidiaContainerCliExecutablePath string, nvidiaCTKPath string, nvidaContainerRuntimeHookPath string, opts *Options) error {
log.Infof("Installing NVIDIA container toolkit config '%v'", toolkitConfigPath) log.Infof("Installing NVIDIA container toolkit config '%v'", toolkitConfigFilePath)
cfg, err := config.New( cfg, err := config.New(
config.WithConfigFile(nvidiaContainerToolkitConfigSource), config.WithConfigFile(nvidiaContainerToolkitConfigSource),
@ -433,7 +432,7 @@ func installToolkitConfig(c *cli.Context, toolkitConfigPath string, nvidiaContai
return fmt.Errorf("could not open source config file: %v", err) return fmt.Errorf("could not open source config file: %v", err)
} }
targetConfig, err := os.Create(toolkitConfigPath) targetConfig, err := os.Create(toolkitConfigFilePath)
if err != nil { if err != nil {
return fmt.Errorf("could not create target config file: %v", err) return fmt.Errorf("could not create target config file: %v", err)
} }
@ -579,17 +578,15 @@ func installContainerCLI(toolkitRoot string) (string, error) {
func installRuntimeHook(toolkitRoot string, configFilePath string) (string, error) { func installRuntimeHook(toolkitRoot string, configFilePath string) (string, error) {
log.Infof("Installing NVIDIA container runtime hook from '%v'", nvidiaContainerRuntimeHookSource) log.Infof("Installing NVIDIA container runtime hook from '%v'", nvidiaContainerRuntimeHookSource)
argLines := []string{
fmt.Sprintf("-config \"%s\"", configFilePath),
}
e := executable{ e := executable{
source: nvidiaContainerRuntimeHookSource, source: nvidiaContainerRuntimeHookSource,
target: executableTarget{ target: executableTarget{
dotfileName: "nvidia-container-runtime-hook.real", dotfileName: "nvidia-container-runtime-hook.real",
wrapperName: "nvidia-container-runtime-hook", wrapperName: "nvidia-container-runtime-hook",
}, },
argLines: argLines, env: map[string]string{
config.FilePathOverrideEnvVar: configFilePath,
},
} }
installedPath, err := e.install(toolkitRoot) installedPath, err := e.install(toolkitRoot)