mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-02-23 12:49:53 +00:00
172 lines
4.0 KiB
Go
172 lines
4.0 KiB
Go
/*
|
|
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package e2e
|
|
|
|
import (
	"bytes"
	"fmt"
	"net"
	"os"
	"os/exec"
	"time"

	"golang.org/x/crypto/ssh"
)
|
|
|
|
type (
	// localRunner runs scripts on the local machine.
	localRunner struct{}

	// remoteRunner runs scripts on a remote machine over SSH. Its fields are
	// filled in by the runnerOption values handed to NewRunner.
	remoteRunner struct {
		sshKey  string // path of the private key file used to authenticate
		sshUser string // user name to log in as
		host    string // remote host name or address
		port    string // remote SSH port
	}

	// runnerOption mutates a remoteRunner while it is being built.
	runnerOption func(*remoteRunner)

	// Runner executes a shell script. The first string result carries the
	// script's standard output; the second is intended for its standard
	// error. A non-nil error reports that the script could not be run or
	// did not succeed.
	Runner interface {
		Run(script string) (string, string, error)
	}
)
|
|
|
|
func WithSshKey(key string) runnerOption {
|
|
return func(r *remoteRunner) {
|
|
r.sshKey = key
|
|
}
|
|
}
|
|
|
|
func WithSshUser(user string) runnerOption {
|
|
return func(r *remoteRunner) {
|
|
r.sshUser = user
|
|
}
|
|
}
|
|
|
|
func WithHost(host string) runnerOption {
|
|
return func(r *remoteRunner) {
|
|
r.host = host
|
|
}
|
|
}
|
|
|
|
func WithPort(port string) runnerOption {
|
|
return func(r *remoteRunner) {
|
|
r.port = port
|
|
}
|
|
}
|
|
|
|
func NewRunner(opts ...runnerOption) Runner {
|
|
r := &remoteRunner{}
|
|
for _, opt := range opts {
|
|
opt(r)
|
|
}
|
|
|
|
// If the Host is empty, return a local runner
|
|
if r.host == "" {
|
|
return localRunner{}
|
|
}
|
|
|
|
// Otherwise, return a remote runner
|
|
return r
|
|
}
|
|
|
|
// Run executes script locally via `bash -c` and returns the captured
// standard output and standard error, in that order.
//
// If the command cannot be started or exits unsuccessfully, both output
// strings are empty. The returned error then wraps the error from cmd.Run
// and embeds the captured stdout and stderr for diagnostics.
func (l localRunner) Run(script string) (string, string, error) {
	cmd := exec.Command("bash", "-c", script)

	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Run(); err != nil {
		// %w keeps the error chain intact so callers can use errors.As to
		// reach an *exec.ExitError and inspect the exit code.
		return "", "", fmt.Errorf("script execution failed: %w\nSTDOUT: %s\nSTDERR: %s", err, stdout.String(), stderr.String())
	}

	// Stderr was previously captured and then thrown away; return it so the
	// second result of the Runner interface is actually populated.
	return stdout.String(), stderr.String(), nil
}
|
|
|
|
// Run executes script on the configured remote host and returns the
// captured standard output and standard error, in that order.
//
// Every call opens a new SSH connection (via connectOrDie) and a new
// session, and closes both before returning. On any failure both output
// strings are empty and the returned error wraps the underlying cause. When
// the script itself fails, the error also embeds the captured stdout and
// stderr.
func (r remoteRunner) Run(script string) (string, string, error) {
	client, err := connectOrDie(r.sshKey, r.sshUser, r.host, r.port)
	if err != nil {
		return "", "", fmt.Errorf("failed to connect to %s: %w", r.host, err)
	}
	defer client.Close()

	session, err := client.NewSession()
	if err != nil {
		return "", "", fmt.Errorf("failed to create session: %w", err)
	}
	defer session.Close()

	var stdout, stderr bytes.Buffer
	session.Stdout = &stdout
	session.Stderr = &stderr

	if err := session.Run(script); err != nil {
		// %w preserves errors such as *ssh.ExitError so callers can inspect
		// the remote exit status.
		return "", "", fmt.Errorf("script execution failed: %w\nSTDOUT: %s\nSTDERR: %s", err, stdout.String(), stderr.String())
	}

	// Return the captured stderr as well; it was previously discarded.
	return stdout.String(), stderr.String(), nil
}
|
|
|
|
// connectOrDie dials an SSH connection to host:port, logging in as sshUser
// with the private key read from the file at sshKey.
//
// The dial is attempted up to 20 times, pausing one second between attempts,
// so that a host which is not yet accepting connections can be waited for.
// Despite its name, the function never terminates the process. Every failure,
// including running out of attempts, is returned as an error that wraps the
// underlying cause.
func connectOrDie(sshKey, sshUser, host, port string) (*ssh.Client, error) {
	const (
		maxAttempts = 20
		retryDelay  = time.Second
	)

	key, err := os.ReadFile(sshKey)
	if err != nil {
		return nil, fmt.Errorf("failed to read key file: %w", err)
	}
	signer, err := ssh.ParsePrivateKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to parse private key: %w", err)
	}
	sshConfig := &ssh.ClientConfig{
		User: sshUser,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		},
		// Host keys are deliberately not verified. This is only acceptable
		// because this client is used by end-to-end tests.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	// JoinHostPort brackets IPv6 literals; plain "host:port" concatenation
	// produces an address the dialer cannot parse for those.
	addr := net.JoinHostPort(host, port)

	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		client, dialErr := ssh.Dial("tcp", addr, sshConfig)
		if dialErr == nil {
			return client, nil
		}
		lastErr = dialErr
		// Sleeping after the final attempt would only delay the error.
		if attempt < maxAttempts {
			time.Sleep(retryDelay)
		}
	}

	return nil, fmt.Errorf("failed to connect to %s after %d attempts, giving up: %w", host, maxAttempts, lastErr)
}
|