Add test package with GetModuleRoot and PrependToPath function

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2022-03-10 13:37:08 +02:00
parent c5c2ffd68f
commit bf8c3bab72
3 changed files with 57 additions and 50 deletions

View File

@ -3,15 +3,14 @@ package main
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"testing" "testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -35,7 +34,7 @@ func TestMain(m *testing.M) {
// TEST SETUP // TEST SETUP
// Determine the module root and the test binary path // Determine the module root and the test binary path
var err error var err error
moduleRoot, err := getModuleRoot() moduleRoot, err := test.GetModuleRoot()
if err != nil { if err != nil {
logger.Fatalf("error in test setup: could not get module root: %v", err) logger.Fatalf("error in test setup: could not get module root: %v", err)
} }
@ -43,7 +42,7 @@ func TestMain(m *testing.M) {
testInputPath := filepath.Join(moduleRoot, "test", "input") testInputPath := filepath.Join(moduleRoot, "test", "input")
// Set the environment variables for the test // Set the environment variables for the test
os.Setenv("PATH", prependToPath(testBinPath, moduleRoot)) os.Setenv("PATH", test.PrependToPath(testBinPath, moduleRoot))
os.Setenv("XDG_CONFIG_HOME", testInputPath) os.Setenv("XDG_CONFIG_HOME", testInputPath)
// Confirm that the environment is configured correctly // Confirm that the environment is configured correctly
@ -71,31 +70,6 @@ func TestMain(m *testing.M) {
os.Exit(exitCode) os.Exit(exitCode)
} }
func getModuleRoot() (string, error) {
_, filename, _, _ := runtime.Caller(0)
return hasGoMod(filename)
}
func hasGoMod(dir string) (string, error) {
if dir == "" || dir == "/" {
return "", fmt.Errorf("module root not found")
}
_, err := os.Stat(filepath.Join(dir, "go.mod"))
if err != nil {
return hasGoMod(filepath.Dir(dir))
}
return dir, nil
}
func prependToPath(additionalPaths ...string) string {
paths := strings.Split(os.Getenv("PATH"), ":")
paths = append(additionalPaths, paths...)
return strings.Join(paths, ":")
}
// case 1) nvidia-container-runtime run --bundle // case 1) nvidia-container-runtime run --bundle
// case 2) nvidia-container-runtime create --bundle // case 2) nvidia-container-runtime create --bundle
// - Confirm the runtime handles bad input correctly // - Confirm the runtime handles bad input correctly

View File

@ -1,17 +1,16 @@
package oci package oci
import ( import (
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"testing" "testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMaintainSpec(t *testing.T) { func TestMaintainSpec(t *testing.T) {
moduleRoot, err := getModuleRoot() moduleRoot, err := test.GetModuleRoot()
require.NoError(t, err) require.NoError(t, err)
files := []string{ files := []string{
@ -38,21 +37,3 @@ func TestMaintainSpec(t *testing.T) {
require.JSONEq(t, string(inputContents), string(outputContents)) require.JSONEq(t, string(inputContents), string(outputContents))
} }
} }
func getModuleRoot() (string, error) {
_, filename, _, _ := runtime.Caller(0)
return hasGoMod(filename)
}
func hasGoMod(dir string) (string, error) {
if dir == "" || dir == "/" {
return "", fmt.Errorf("module root not found")
}
_, err := os.Stat(filepath.Join(dir, "go.mod"))
if err != nil {
return hasGoMod(filepath.Dir(dir))
}
return dir, nil
}

52
internal/test/test.go Normal file
View File

@ -0,0 +1,52 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package test
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
)
// GetModuleRoot returns the path to the root of the go module
func GetModuleRoot() (string, error) {
_, filename, _, _ := runtime.Caller(0)
return hasGoMod(filename)
}
// PrependToPath prefixes the specified additional paths to the PATH environment variable
func PrependToPath(additionalPaths ...string) string {
paths := strings.Split(os.Getenv("PATH"), ":")
paths = append(additionalPaths, paths...)
return strings.Join(paths, ":")
}
func hasGoMod(dir string) (string, error) {
if dir == "" || dir == "/" {
return "", fmt.Errorf("module root not found")
}
_, err := os.Stat(filepath.Join(dir, "go.mod"))
if err != nil {
return hasGoMod(filepath.Dir(dir))
}
return dir, nil
}