From bf8c3bab72f5bf9bc2f1cfa8c832df1888fa8a5c Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 10 Mar 2022 13:37:08 +0200 Subject: [PATCH] Add test package with GetModuleRoot and PrependToPath function Signed-off-by: Evan Lezar --- cmd/nvidia-container-runtime/main_test.go | 32 ++------------ internal/oci/spec_test.go | 23 +--------- internal/test/test.go | 52 +++++++++++++++++++++++ 3 files changed, 57 insertions(+), 50 deletions(-) create mode 100644 internal/test/test.go diff --git a/cmd/nvidia-container-runtime/main_test.go b/cmd/nvidia-container-runtime/main_test.go index 9ef83b9f..0aac45ab 100644 --- a/cmd/nvidia-container-runtime/main_test.go +++ b/cmd/nvidia-container-runtime/main_test.go @@ -3,15 +3,14 @@ package main import ( "bytes" "encoding/json" - "fmt" "io/ioutil" "os" "os/exec" "path/filepath" - "runtime" "strings" "testing" + "github.com/NVIDIA/nvidia-container-toolkit/internal/test" "github.com/opencontainers/runtime-spec/specs-go" "github.com/stretchr/testify/require" ) @@ -35,7 +34,7 @@ func TestMain(m *testing.M) { // TEST SETUP // Determine the module root and the test binary path var err error - moduleRoot, err := getModuleRoot() + moduleRoot, err := test.GetModuleRoot() if err != nil { logger.Fatalf("error in test setup: could not get module root: %v", err) } @@ -43,7 +42,7 @@ func TestMain(m *testing.M) { testInputPath := filepath.Join(moduleRoot, "test", "input") // Set the environment variables for the test - os.Setenv("PATH", prependToPath(testBinPath, moduleRoot)) + os.Setenv("PATH", test.PrependToPath(testBinPath, moduleRoot)) os.Setenv("XDG_CONFIG_HOME", testInputPath) // Confirm that the environment is configured correctly @@ -71,31 +70,6 @@ func TestMain(m *testing.M) { os.Exit(exitCode) } -func getModuleRoot() (string, error) { - _, filename, _, _ := runtime.Caller(0) - - return hasGoMod(filename) -} - -func hasGoMod(dir string) (string, error) { - if dir == "" || dir == "/" { - return "", fmt.Errorf("module root not found") - } - - _, err := os.Stat(filepath.Join(dir, "go.mod")) - if err != nil { - return hasGoMod(filepath.Dir(dir)) - } - return dir, nil -} - -func prependToPath(additionalPaths ...string) string { - paths := strings.Split(os.Getenv("PATH"), ":") - paths = append(additionalPaths, paths...) - - return strings.Join(paths, ":") -} - // case 1) nvidia-container-runtime run --bundle // case 2) nvidia-container-runtime create --bundle // - Confirm the runtime handles bad input correctly diff --git a/internal/oci/spec_test.go b/internal/oci/spec_test.go index 20fc97b3..03d2d301 100644 --- a/internal/oci/spec_test.go +++ b/internal/oci/spec_test.go @@ -1,17 +1,16 @@ package oci import ( - "fmt" "os" "path/filepath" - "runtime" "testing" + "github.com/NVIDIA/nvidia-container-toolkit/internal/test" "github.com/stretchr/testify/require" ) func TestMaintainSpec(t *testing.T) { - moduleRoot, err := getModuleRoot() + moduleRoot, err := test.GetModuleRoot() require.NoError(t, err) files := []string{ @@ -38,21 +37,3 @@ func TestMaintainSpec(t *testing.T) { require.JSONEq(t, string(inputContents), string(outputContents)) } } - -func getModuleRoot() (string, error) { - _, filename, _, _ := runtime.Caller(0) - - return hasGoMod(filename) -} - -func hasGoMod(dir string) (string, error) { - if dir == "" || dir == "/" { - return "", fmt.Errorf("module root not found") - } - - _, err := os.Stat(filepath.Join(dir, "go.mod")) - if err != nil { - return hasGoMod(filepath.Dir(dir)) - } - return dir, nil -} diff --git a/internal/test/test.go b/internal/test/test.go new file mode 100644 index 00000000..77c4a84c --- /dev/null +++ b/internal/test/test.go @@ -0,0 +1,52 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package test + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" +) + +// GetModuleRoot returns the path to the root of the go module +func GetModuleRoot() (string, error) { + _, filename, _, _ := runtime.Caller(0) + + return hasGoMod(filename) +} + +// PrependToPath prefixes the specified additional paths to the PATH environment variable +func PrependToPath(additionalPaths ...string) string { + paths := strings.Split(os.Getenv("PATH"), ":") + paths = append(additionalPaths, paths...) + + return strings.Join(paths, ":") +} + +func hasGoMod(dir string) (string, error) { + if dir == "" || dir == "/" { + return "", fmt.Errorf("module root not found") + } + + _, err := os.Stat(filepath.Join(dir, "go.mod")) + if err != nil { + return hasGoMod(filepath.Dir(dir)) + } + return dir, nil +}