mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-04-06 05:25:01 +00:00
[no-relnote] Use urfave for nvidia-container-runtime-hook CLI
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
efb18a72ad
commit
1a83552aa6
@ -30,11 +30,11 @@ func getDefaultHookConfig() (HookConfig, error) {
|
||||
}
|
||||
|
||||
// loadConfig loads the required paths for the hook config.
|
||||
func loadConfig() (*config.Config, error) {
|
||||
func (a *app) loadConfig() (*config.Config, error) {
|
||||
var configPaths []string
|
||||
var required bool
|
||||
if len(*configflag) != 0 {
|
||||
configPaths = append(configPaths, *configflag)
|
||||
if len(a.configFile) != 0 {
|
||||
configPaths = append(configPaths, a.configFile)
|
||||
required = true
|
||||
} else {
|
||||
configPaths = append(configPaths, path.Join(driverPath, configPath), configPath)
|
||||
@ -56,8 +56,8 @@ func loadConfig() (*config.Config, error) {
|
||||
return config.GetDefault()
|
||||
}
|
||||
|
||||
func getHookConfig() (*HookConfig, error) {
|
||||
cfg, err := loadConfig()
|
||||
func (a *app) getHookConfig() (*HookConfig, error) {
|
||||
cfg, err := a.loadConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load config: %v", err)
|
||||
}
|
||||
|
@ -72,16 +72,17 @@ func TestGetHookConfig(t *testing.T) {
|
||||
if len(filename) > 0 {
|
||||
os.Remove(filename)
|
||||
}
|
||||
configflag = nil
|
||||
}()
|
||||
|
||||
a := &app{}
|
||||
|
||||
if tc.lines != nil {
|
||||
configFile, err := os.CreateTemp("", "*.toml")
|
||||
require.NoError(t, err)
|
||||
defer configFile.Close()
|
||||
|
||||
filename = configFile.Name()
|
||||
configflag = &filename
|
||||
a.configFile = filename
|
||||
|
||||
for _, line := range tc.lines {
|
||||
_, err := configFile.WriteString(fmt.Sprintf("%s\n", line))
|
||||
@ -91,7 +92,7 @@ func TestGetHookConfig(t *testing.T) {
|
||||
|
||||
var config HookConfig
|
||||
getHookConfig := func() {
|
||||
c, _ := getHookConfig()
|
||||
c, _ := a.getHookConfig()
|
||||
config = *c
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@ -13,29 +13,26 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
cli "github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
)
|
||||
|
||||
var (
|
||||
debugflag = flag.Bool("debug", false, "enable debug output")
|
||||
versionflag = flag.Bool("version", false, "enable version output")
|
||||
configflag = flag.String("config", "", "configuration file")
|
||||
)
|
||||
|
||||
func exit() {
|
||||
func (a *app) recoverIfRequired() error {
|
||||
if err := recover(); err != nil {
|
||||
if _, ok := err.(runtime.Error); ok {
|
||||
rerr, ok := err.(runtime.Error)
|
||||
if ok {
|
||||
log.Println(err)
|
||||
}
|
||||
if *debugflag {
|
||||
if a.isDebug {
|
||||
log.Printf("%s", debug.Stack())
|
||||
}
|
||||
os.Exit(1)
|
||||
return rerr
|
||||
}
|
||||
os.Exit(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getCLIPath(config config.ContainerCLIConfig) string {
|
||||
@ -63,15 +60,15 @@ func getRootfsPath(config containerConfig) string {
|
||||
return rootfs
|
||||
}
|
||||
|
||||
func doPrestart() {
|
||||
var err error
|
||||
|
||||
defer exit()
|
||||
func (a *app) doPrestart() (rerr error) {
|
||||
defer func() {
|
||||
rerr = errors.Join(rerr, a.recoverIfRequired())
|
||||
}()
|
||||
log.SetFlags(0)
|
||||
|
||||
hook, err := getHookConfig()
|
||||
hook, err := a.getHookConfig()
|
||||
if err != nil || hook == nil {
|
||||
log.Panicln("error getting hook config:", err)
|
||||
return fmt.Errorf("error getting hook config: %w", err)
|
||||
}
|
||||
cli := hook.NVIDIAContainerCLIConfig
|
||||
|
||||
@ -79,11 +76,11 @@ func doPrestart() {
|
||||
nvidia := container.Nvidia
|
||||
if nvidia == nil {
|
||||
// Not a GPU container, nothing to do.
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" {
|
||||
log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.")
|
||||
return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead")
|
||||
}
|
||||
|
||||
rootfs := getRootfsPath(container)
|
||||
@ -101,7 +98,7 @@ func doPrestart() {
|
||||
if cli.NoPivot {
|
||||
args = append(args, "--no-pivot")
|
||||
}
|
||||
if *debugflag {
|
||||
if a.isDebug {
|
||||
args = append(args, "--debug=/dev/stderr")
|
||||
} else if cli.Debug != "" {
|
||||
args = append(args, fmt.Sprintf("--debug=%s", cli.Debug))
|
||||
@ -149,45 +146,60 @@ func doPrestart() {
|
||||
|
||||
env := append(os.Environ(), cli.Environment...)
|
||||
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection?
|
||||
err = syscall.Exec(args[0], args, env)
|
||||
log.Panicln("exec failed:", err)
|
||||
return syscall.Exec(args[0], args, env)
|
||||
}
|
||||
|
||||
func usage() {
|
||||
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, "\nCommands:\n")
|
||||
fmt.Fprintf(os.Stderr, " prestart\n run the prestart hook\n")
|
||||
fmt.Fprintf(os.Stderr, " poststart\n no-op\n")
|
||||
fmt.Fprintf(os.Stderr, " poststop\n no-op\n")
|
||||
type options struct {
|
||||
isDebug bool
|
||||
configFile string
|
||||
}
|
||||
type app struct {
|
||||
options
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = usage
|
||||
flag.Parse()
|
||||
a := &app{}
|
||||
// Create the top-level CLI
|
||||
c := cli.NewApp()
|
||||
c.Name = "NVIDIA Container Runtime Hook"
|
||||
c.Version = info.GetVersionString()
|
||||
|
||||
if *versionflag {
|
||||
fmt.Printf("%v version %v\n", "NVIDIA Container Runtime Hook", info.GetVersionString())
|
||||
return
|
||||
c.Flags = []cli.Flag{
|
||||
&cli.BoolFlag{
|
||||
Name: "debug",
|
||||
Destination: &a.isDebug,
|
||||
Usage: "Enabled debug output",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "config",
|
||||
Destination: &a.configFile,
|
||||
Usage: "The path to the configuration file to use",
|
||||
},
|
||||
}
|
||||
|
||||
args := flag.Args()
|
||||
if len(args) == 0 {
|
||||
flag.Usage()
|
||||
os.Exit(2)
|
||||
c.Commands = []*cli.Command{
|
||||
{
|
||||
Name: "prestart",
|
||||
Usage: "run the prestart hook",
|
||||
Action: func(ctx *cli.Context) error {
|
||||
return a.doPrestart()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "poststart",
|
||||
Aliases: []string{"poststop"},
|
||||
Usage: "no-op",
|
||||
Action: func(ctx *cli.Context) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
c.DefaultCommand = "prestart"
|
||||
|
||||
switch args[0] {
|
||||
case "prestart":
|
||||
doPrestart()
|
||||
os.Exit(0)
|
||||
case "poststart":
|
||||
fallthrough
|
||||
case "poststop":
|
||||
os.Exit(0)
|
||||
default:
|
||||
flag.Usage()
|
||||
os.Exit(2)
|
||||
// Run the CLI
|
||||
err := c.Run(os.Args)
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user