package main import ( "fmt" "os" "os/exec" "os/signal" "path/filepath" "strings" "syscall" log "github.com/sirupsen/logrus" cli "github.com/urfave/cli/v2" unix "golang.org/x/sys/unix" ) const ( runDir = "/run/nvidia" pidFile = runDir + "/toolkit.pid" toolkitCommand = "toolkit" toolkitSubDir = "toolkit" defaultToolkitArgs = "" defaultRuntime = "docker" defaultRuntimeArgs = "" ) var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "containerd": {}} var waitingForSignal = make(chan bool, 1) var signalReceived = make(chan bool, 1) var destinationArg string var noDaemonFlag bool var runtimeFlag string var runtimeArgsFlag string // Version defines the CLI version. This is set at build time using LD FLAGS var Version = "development" func main() { // Create the top-level CLI c := cli.NewApp() c.Name = "nvidia-toolkit" c.Usage = "Install the nvidia-container-toolkit for use by a given runtime" c.UsageText = "DESTINATION [-n | --no-daemon] [-r | --runtime] [-u | --runtime-args]" c.Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\nIt will be installed at ${DESTINATION}/toolkit" c.Version = Version c.Action = Run // Setup flags for the CLI c.Flags = []cli.Flag{ &cli.BoolFlag{ Name: "no-daemon", Aliases: []string{"n"}, Usage: "terminate immediatly after setting up the runtime. Note that no cleanup will be performed", Destination: &noDaemonFlag, EnvVars: []string{"NO_DAEMON"}, }, &cli.StringFlag{ Name: "runtime", Aliases: []string{"r"}, Usage: "the runtime to setup on this node. One of {'docker', 'crio', 'containerd'}", Value: defaultRuntime, Destination: &runtimeFlag, EnvVars: []string{"RUNTIME"}, }, &cli.StringFlag{ Name: "runtime-args", Aliases: []string{"u"}, Usage: "arguments to pass to 'docker', 'crio', or 'containerd' setup command", Value: defaultRuntimeArgs, Destination: &runtimeArgsFlag, EnvVars: []string{"RUNTIME_ARGS"}, }, } // Run the CLI log.Infof("Starting %v", c.Name) remainingArgs, err := ParseArgs(os.Args) if err != nil { log.Errorf("Error: unable to parse arguments: %v", err) os.Exit(1) } if err := c.Run(remainingArgs); err != nil { log.Errorf("error running nvidia-toolkit: %v", err) os.Exit(1) } log.Infof("Completed %v", c.Name) } // Run runs the core logic of the CLI func Run(c *cli.Context) error { err := verifyFlags() if err != nil { return fmt.Errorf("unable to verify flags: %v", err) } err = initialize() if err != nil { return fmt.Errorf("unable to initialize: %v", err) } defer shutdown() err = installToolkit() if err != nil { return fmt.Errorf("unable to install toolkit: %v", err) } err = setupRuntime() if err != nil { return fmt.Errorf("unable to setup runtime: %v", err) } if !noDaemonFlag { err = waitForSignal() if err != nil { return fmt.Errorf("unable to wait for signal: %v", err) } err = cleanupRuntime() if err != nil { return fmt.Errorf("unable to cleanup runtime: %v", err) } } return nil } // ParseArgs parses the command line arguments and returns the remaining arguments func ParseArgs(args []string) ([]string, error) { log.Infof("Parsing arguments") numPositionalArgs := 2 // Includes command itself if len(args) < numPositionalArgs { return nil, fmt.Errorf("missing arguments") } for _, arg := range args { if arg == "--help" || arg == "-h" { return []string{args[0], arg}, nil } if arg == "--version" || arg == "-v" { return []string{args[0], arg}, nil } } for _, arg := range args[:numPositionalArgs] { if strings.HasPrefix(arg, "-") { return nil, fmt.Errorf("unexpected flag where argument should be") } } for _, arg := range args[numPositionalArgs:] { if !strings.HasPrefix(arg, "-") { return nil, fmt.Errorf("unexpected argument where flag should be") } } destinationArg = args[1] return append([]string{args[0]}, args[numPositionalArgs:]...), nil } func verifyFlags() error { log.Infof("Verifying Flags") if _, exists := availableRuntimes[runtimeFlag]; !exists { return fmt.Errorf("unknown runtime: %v", runtimeFlag) } return nil } func initialize() error { log.Infof("Initializing") f, err := os.Create(pidFile) if err != nil { return fmt.Errorf("unable to create pidfile: %v", err) } err = unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB) if err != nil { log.Warnf("Unable to get exclusive lock on '%v'", pidFile) log.Warnf("This normally means an instance of the NVIDIA toolkit Container is already running, aborting") return fmt.Errorf("unable to get flock on pidfile: %v", err) } _, err = f.WriteString(fmt.Sprintf("%v\n", os.Getpid())) if err != nil { return fmt.Errorf("unable to write PID to pidfile: %v", err) } sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGPIPE, syscall.SIGTERM) go func() { <-sigs select { case <-waitingForSignal: signalReceived <- true default: log.Infof("Signal received, exiting early") shutdown() os.Exit(0) } }() return nil } func installToolkit() error { log.Infof("Installing toolkit") cmdline := []string{ toolkitCommand, "install", "--toolkit-root", filepath.Join(destinationArg, toolkitSubDir), } cmd := exec.Command("sh", "-c", strings.Join(cmdline, " ")) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() if err != nil { return fmt.Errorf("error running %v command: %v", cmdline, err) } return nil } func setupRuntime() error { toolkitDir := filepath.Join(destinationArg, toolkitSubDir) log.Infof("Setting up runtime") cmdline := fmt.Sprintf("%v setup %v %v\n", runtimeFlag, runtimeArgsFlag, toolkitDir) cmd := exec.Command("sh", "-c", cmdline) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() if err != nil { return fmt.Errorf("error running %v command: %v", runtimeFlag, err) } return nil } func waitForSignal() error { log.Infof("Waiting for signal") waitingForSignal <- true <-signalReceived return nil } func cleanupRuntime() error { toolkitDir := filepath.Join(destinationArg, toolkitSubDir) log.Infof("Cleaning up Runtime") cmdline := fmt.Sprintf("%v cleanup %v %v\n", runtimeFlag, runtimeArgsFlag, toolkitDir) cmd := exec.Command("sh", "-c", cmdline) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() if err != nil { return fmt.Errorf("error running %v command: %v", runtimeFlag, err) } return nil } func shutdown() { log.Infof("Shutting Down") err := os.Remove(pidFile) if err != nil { log.Warnf("Unable to remove pidfile: %v", err) } }