nvidia-container-toolkit/tools/container/nvidia-toolkit/run.go

286 lines
6.6 KiB
Go
Raw Normal View History

package main
import (
"fmt"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
cli "github.com/urfave/cli/v2"
unix "golang.org/x/sys/unix"
)
const (
	// Name of the installer executable and of the sub-directory (below the
	// DESTINATION argument) that the toolkit is installed into.
	toolkitCommand = "toolkit"
	toolkitSubDir  = "toolkit"

	// The pidfile is created and flock'ed during initialization and removed
	// again on shutdown.
	runDir  = "/run/nvidia"
	pidFile = runDir + "/toolkit.pid"

	// Defaults for the --runtime and --runtime-args flags.
	defaultRuntime     = "docker"
	defaultRuntimeArgs = ""

	// Not referenced in the visible portion of this file.
	defaultToolkitArgs = ""
)
var (
	// availableRuntimes is the set of values accepted for the runtime flag.
	availableRuntimes = map[string]struct{}{
		"docker":     {},
		"crio":       {},
		"containerd": {},
	}

	// waitingForSignal and signalReceived form the handshake between the
	// main goroutine and the signal-handling goroutine. Both have a buffer
	// of one so a single send never blocks.
	waitingForSignal = make(chan bool, 1)
	signalReceived   = make(chan bool, 1)

	// destinationArg holds the positional DESTINATION argument; the
	// remaining variables are the destinations of the CLI flags.
	destinationArg  string
	noDaemonFlag    bool
	runtimeFlag     string
	runtimeArgsFlag string

	// Version defines the CLI version. It is overridden at build time
	// through linker flags (-ldflags).
	Version = "development"
)
// main builds the nvidia-toolkit CLI application and runs it.
//
// os.Args is first passed through ParseArgs, which records the positional
// DESTINATION argument and returns only the program name plus flags; that
// reduced argument list is what the CLI library parses. Any failure is logged
// and the process exits with status 1.
func main() {
	// Create the top-level CLI
	c := cli.NewApp()
	c.Name = "nvidia-toolkit"
	c.Usage = "Install the nvidia-container-toolkit for use by a given runtime"
	c.UsageText = "DESTINATION [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]"
	c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit"
	c.Version = Version
	c.Action = Run

	// Setup flags for the CLI. Every flag can also be supplied through the
	// environment variable listed in its EnvVars.
	c.Flags = []cli.Flag{
		&cli.BoolFlag{
			Name:        "no-daemon",
			Aliases:     []string{"n"},
			Usage:       "terminate immediately after setting up the runtime. Note that no cleanup will be performed",
			Destination: &noDaemonFlag,
			EnvVars:     []string{"NO_DAEMON"},
		},
		&cli.StringFlag{
			Name:        "runtime",
			Aliases:     []string{"r"},
			Usage:       "the runtime to setup on this node. One of {'docker', 'crio', 'containerd'}",
			Value:       defaultRuntime,
			Destination: &runtimeFlag,
			EnvVars:     []string{"RUNTIME"},
		},
		&cli.StringFlag{
			Name:        "runtime-args",
			Aliases:     []string{"u"},
			Usage:       "arguments to pass to 'docker', 'crio', or 'containerd' setup command",
			Value:       defaultRuntimeArgs,
			Destination: &runtimeArgsFlag,
			EnvVars:     []string{"RUNTIME_ARGS"},
		},
	}

	// Run the CLI
	log.Infof("Starting %v", c.Name)

	remainingArgs, err := ParseArgs(os.Args)
	if err != nil {
		log.Errorf("Error: unable to parse arguments: %v", err)
		os.Exit(1)
	}

	if err := c.Run(remainingArgs); err != nil {
		log.Errorf("error running nvidia-toolkit: %v", err)
		os.Exit(1)
	}

	log.Infof("Completed %v", c.Name)
}
// Run is the action executed by the CLI. It validates the flags, creates the
// pidfile, installs the toolkit and configures the selected runtime. Unless
// the no-daemon flag is set, it then blocks until a signal arrives and undoes
// the runtime configuration before returning.
//
// The first failing step aborts the sequence; its error is returned with a
// short description of the step prepended.
func Run(c *cli.Context) error {
	if err := verifyFlags(); err != nil {
		return fmt.Errorf("unable to verify flags: %v", err)
	}

	if err := initialize(); err != nil {
		return fmt.Errorf("unable to initialize: %v", err)
	}
	// Initialization succeeded, so run shutdown on every return path below.
	defer shutdown()

	if err := installToolkit(); err != nil {
		return fmt.Errorf("unable to install toolkit: %v", err)
	}

	if err := setupRuntime(); err != nil {
		return fmt.Errorf("unable to setup runtime: %v", err)
	}

	// In no-daemon mode there is nothing to wait for and no cleanup to do.
	if noDaemonFlag {
		return nil
	}

	if err := waitForSignal(); err != nil {
		return fmt.Errorf("unable to wait for signal: %v", err)
	}

	if err := cleanupRuntime(); err != nil {
		return fmt.Errorf("unable to cleanup runtime: %v", err)
	}

	return nil
}
// ParseArgs validates the raw command line and separates the positional
// DESTINATION argument from the flags.
//
// args is expected to be laid out as: program name, DESTINATION, then only
// arguments that start with "-". On success DESTINATION is stored in
// destinationArg and the returned slice holds the program name followed by
// the flags. If a help or version flag appears anywhere in args, the result
// is just the program name and that flag, and destinationArg is not touched.
//
// An error is returned when fewer than two arguments are given, when one of
// the first two arguments starts with "-", or when any later argument does
// not.
func ParseArgs(args []string) ([]string, error) {
	log.Infof("Parsing arguments")

	// The program name counts as the first positional argument.
	const positionalCount = 2
	if len(args) < positionalCount {
		return nil, fmt.Errorf("missing arguments")
	}

	// Help and version requests take precedence over every other check.
	for _, candidate := range args {
		switch candidate {
		case "--help", "-h", "--version", "-v":
			return []string{args[0], candidate}, nil
		}
	}

	positional, flags := args[:positionalCount], args[positionalCount:]

	for i := range positional {
		if strings.HasPrefix(positional[i], "-") {
			return nil, fmt.Errorf("unexpected flag where argument should be")
		}
	}

	for i := range flags {
		if !strings.HasPrefix(flags[i], "-") {
			return nil, fmt.Errorf("unexpected argument where flag should be")
		}
	}

	destinationArg = positional[1]

	// Build a fresh slice so the caller's args are never aliased.
	remaining := make([]string, 0, 1+len(flags))
	remaining = append(remaining, args[0])
	remaining = append(remaining, flags...)
	return remaining, nil
}
// verifyFlags checks that the runtime selected through runtimeFlag is one of
// the keys of availableRuntimes and returns an error naming the value
// otherwise.
func verifyFlags() error {
	log.Infof("Verifying Flags")

	_, known := availableRuntimes[runtimeFlag]
	if known {
		return nil
	}
	return fmt.Errorf("unknown runtime: %v", runtimeFlag)
}
// pidFileHandle keeps the locked pidfile open for the lifetime of the
// process. The flock taken in initialize is tied to this file's descriptor;
// if the *os.File became unreachable, its finalizer could close the
// descriptor and silently release the lock.
var pidFileHandle *os.File

// initialize takes an exclusive, non-blocking lock on the pidfile, records
// the current PID in it and installs the signal handler.
//
// It returns an error if the pidfile cannot be opened, if the lock is already
// held (normally because another instance is running) or if the PID cannot be
// written. On those paths the file is closed again and the previous pidfile
// contents are left as they were whenever the lock was not obtained.
//
// The signal handler reacts to the first of SIGHUP, SIGINT, SIGQUIT, SIGPIPE
// or SIGTERM only. If waitForSignal has already announced itself, the handler
// wakes it up; otherwise it runs shutdown and exits the process with status 0.
func initialize() error {
	log.Infof("Initializing")

	// Open without O_TRUNC: truncating here (as os.Create does) would wipe
	// the PID written by an instance that is already running before we know
	// whether the lock is available.
	f, err := os.OpenFile(pidFile, os.O_RDWR|os.O_CREATE, 0666)
	if err != nil {
		return fmt.Errorf("unable to create pidfile: %v", err)
	}

	err = unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB)
	if err != nil {
		_ = f.Close() // best effort; the flock error is the one worth reporting
		log.Warnf("Unable to get exclusive lock on '%v'", pidFile)
		log.Warnf("This normally means an instance of the NVIDIA toolkit Container is already running, aborting")
		return fmt.Errorf("unable to get flock on pidfile: %v", err)
	}

	// The lock is ours, so stale contents can now be discarded safely.
	if err = f.Truncate(0); err != nil {
		_ = f.Close()
		return fmt.Errorf("unable to truncate pidfile: %v", err)
	}

	if _, err = fmt.Fprintf(f, "%v\n", os.Getpid()); err != nil {
		_ = f.Close()
		return fmt.Errorf("unable to write PID to pidfile: %v", err)
	}

	// Keep the file reachable so the lock is held until the process exits.
	pidFileHandle = f

	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGPIPE, syscall.SIGTERM)

	go func() {
		<-sigs
		select {
		case <-waitingForSignal:
			// The main goroutine is parked in waitForSignal: release it so
			// that it can run the regular cleanup path.
			signalReceived <- true
		default:
			// Setup is still in progress; there is nobody to hand over to.
			log.Infof("Signal received, exiting early")
			shutdown()
			os.Exit(0)
		}
	}()

	return nil
}
// installToolkit runs "toolkit install --toolkit-root DIR", where DIR is the
// toolkit sub-directory below destinationArg. The child's stdout and stderr
// are connected to those of this process.
//
// The command is executed directly rather than through "sh -c", so a
// destination path containing whitespace or shell metacharacters reaches the
// installer as a single, unmodified argument.
//
// It returns an error if the command cannot be started or exits unsuccessfully.
func installToolkit() error {
	log.Infof("Installing toolkit")

	args := []string{
		"install",
		"--toolkit-root",
		filepath.Join(destinationArg, toolkitSubDir),
	}

	cmd := exec.Command(toolkitCommand, args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	if err := cmd.Run(); err != nil {
		// cmd.Args is the full argv, command name included.
		return fmt.Errorf("error running %v command: %v", cmd.Args, err)
	}

	return nil
}
// setupRuntime runs "<runtime> setup <runtime-args> <toolkit-dir>" through
// the shell, with the child's stdout and stderr connected to those of this
// process.
//
// runtimeArgsFlag is a free-form string holding any number of arguments, so
// the shell is kept to split it into words. The toolkit directory, however,
// is handed over as positional parameter $1 and expanded inside double
// quotes, which keeps a path containing whitespace or shell metacharacters
// intact as one argument.
//
// It returns an error if the command cannot be started or exits unsuccessfully.
func setupRuntime() error {
	toolkitDir := filepath.Join(destinationArg, toolkitSubDir)

	log.Infof("Setting up runtime")

	script := fmt.Sprintf(`%v setup %v "$1"`, runtimeFlag, runtimeArgsFlag)

	// Arguments after the script are $0 and $1 of the script respectively.
	cmd := exec.Command("sh", "-c", script, "sh", toolkitDir)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	if err := cmd.Run(); err != nil {
		return fmt.Errorf("error running %v command: %v", runtimeFlag, err)
	}

	return nil
}
// waitForSignal blocks the calling goroutine until the signal-handling
// goroutine started in initialize reports that a signal has arrived.
// It always returns nil.
func waitForSignal() error {
log.Infof("Waiting for signal")
// Announce that the main goroutine is now waiting. The signal handler polls
// this channel to decide between waking us up and exiting the process
// directly. The channel has a buffer of one, so this send does not block.
waitingForSignal <- true
// Park here until the signal handler hands control back.
<-signalReceived
return nil
}
// cleanupRuntime runs "<runtime> cleanup <runtime-args> <toolkit-dir>"
// through the shell, with the child's stdout and stderr connected to those of
// this process.
//
// runtimeArgsFlag is a free-form string holding any number of arguments, so
// the shell is kept to split it into words. The toolkit directory, however,
// is handed over as positional parameter $1 and expanded inside double
// quotes, which keeps a path containing whitespace or shell metacharacters
// intact as one argument.
//
// It returns an error if the command cannot be started or exits unsuccessfully.
func cleanupRuntime() error {
	toolkitDir := filepath.Join(destinationArg, toolkitSubDir)

	log.Infof("Cleaning up Runtime")

	script := fmt.Sprintf(`%v cleanup %v "$1"`, runtimeFlag, runtimeArgsFlag)

	// Arguments after the script are $0 and $1 of the script respectively.
	cmd := exec.Command("sh", "-c", script, "sh", toolkitDir)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	if err := cmd.Run(); err != nil {
		return fmt.Errorf("error running %v command: %v", runtimeFlag, err)
	}

	return nil
}
// shutdown removes the pidfile. A failure to remove it is logged as a
// warning and otherwise ignored.
func shutdown() {
	log.Infof("Shutting Down")

	if err := os.Remove(pidFile); err != nil {
		log.Warnf("Unable to remove pidfile: %v", err)
	}
}