/* * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package e2e import ( "bytes" "fmt" "os" "os/exec" "time" "golang.org/x/crypto/ssh" ) type localRunner struct{} type remoteRunner struct { sshKey string sshUser string host string port string } type runnerOption func(*remoteRunner) type Runner interface { Run(script string) (string, string, error) } func WithSshKey(key string) runnerOption { return func(r *remoteRunner) { r.sshKey = key } } func WithSshUser(user string) runnerOption { return func(r *remoteRunner) { r.sshUser = user } } func WithHost(host string) runnerOption { return func(r *remoteRunner) { r.host = host } } func WithPort(port string) runnerOption { return func(r *remoteRunner) { r.port = port } } func NewRunner(opts ...runnerOption) Runner { r := &remoteRunner{} for _, opt := range opts { opt(r) } // If the Host is empty, return a local runner if r.host == "" { return localRunner{} } // Otherwise, return a remote runner return r } func (l localRunner) Run(script string) (string, string, error) { // Create a command to run the script using bash cmd := exec.Command("bash", "-c", script) // Buffer to capture standard output var stdout bytes.Buffer cmd.Stdout = &stdout // Buffer to capture standard error var stderr bytes.Buffer cmd.Stderr = &stderr // Run the command err := cmd.Run() if err != nil { return "", "", fmt.Errorf("script execution failed: %v\nSTDOUT: %s\nSTDERR: %s", err, stdout.String(), stderr.String()) } // Return the captured stdout and nil error return stdout.String(), "", nil } func (r remoteRunner) Run(script string) (string, string, error) { // Create a new SSH connection client, err := connectOrDie(r.sshKey, r.sshUser, r.host, r.port) if err != nil { return "", "", fmt.Errorf("failed to connect to %s: %v", r.host, err) } defer client.Close() // Create a session session, err := client.NewSession() if err != nil { return "", "", fmt.Errorf("failed to create session: %v", err) } defer session.Close() // Capture stdout and stderr var stdout, stderr bytes.Buffer session.Stdout = &stdout session.Stderr = &stderr // Run the script err = session.Run(script) if err != nil { return "", "", fmt.Errorf("script execution failed: %v\nSTDOUT: %s\nSTDERR: %s", err, stdout.String(), stderr.String()) } // Return stdout as string if no errors return stdout.String(), "", nil } // createSshClient creates a ssh client, and retries if it fails to connect func connectOrDie(sshKey, sshUser, host, port string) (*ssh.Client, error) { var client *ssh.Client var err error key, err := os.ReadFile(sshKey) if err != nil { return nil, fmt.Errorf("failed to read key file: %v", err) } signer, err := ssh.ParsePrivateKey(key) if err != nil { return nil, fmt.Errorf("failed to parse private key: %v", err) } sshConfig := &ssh.ClientConfig{ User: sshUser, Auth: []ssh.AuthMethod{ ssh.PublicKeys(signer), }, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } connectionFailed := false for i := 0; i < 20; i++ { client, err = ssh.Dial("tcp", host+":"+port, sshConfig) if err == nil { return client, nil // Connection succeeded, return the client. } connectionFailed = true // Sleep for a brief moment before retrying. // You can adjust the duration based on your requirements. time.Sleep(1 * time.Second) } if connectionFailed { return nil, fmt.Errorf("failed to connect to %s after 10 retries, giving up", host) } return client, nil }