mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-22 00:08:11 +00:00
Use require package for tests
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
b930487dc5
commit
602eaf0e60
@ -2,8 +2,9 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetNvidiaConfig(t *testing.T) {
|
func TestGetNvidiaConfig(t *testing.T) {
|
||||||
@ -414,7 +415,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
|
|
||||||
// For any tests that are expected to panic, make sure they do.
|
// For any tests that are expected to panic, make sure they do.
|
||||||
if tc.expectedPanic {
|
if tc.expectedPanic {
|
||||||
mustPanic(t, getConfig)
|
require.Panics(t, getConfig)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -422,31 +423,20 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
getConfig()
|
getConfig()
|
||||||
|
|
||||||
// And start comparing the test results to the expected results.
|
// And start comparing the test results to the expected results.
|
||||||
if config == nil && tc.expectedConfig == nil {
|
if tc.expectedConfig == nil {
|
||||||
|
require.Nil(t, config, tc.description)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if config != nil && tc.expectedConfig != nil {
|
|
||||||
if !reflect.DeepEqual(config.Devices, tc.expectedConfig.Devices) {
|
require.NotNil(t, config, tc.description)
|
||||||
t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig)
|
|
||||||
}
|
require.Equal(t, tc.expectedConfig.Devices, config.Devices)
|
||||||
if !reflect.DeepEqual(config.MigConfigDevices, tc.expectedConfig.MigConfigDevices) {
|
require.Equal(t, tc.expectedConfig.MigConfigDevices, config.MigConfigDevices)
|
||||||
t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig)
|
require.Equal(t, tc.expectedConfig.MigMonitorDevices, config.MigMonitorDevices)
|
||||||
}
|
require.Equal(t, tc.expectedConfig.DriverCapabilities, config.DriverCapabilities)
|
||||||
if !reflect.DeepEqual(config.MigMonitorDevices, tc.expectedConfig.MigMonitorDevices) {
|
|
||||||
t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig)
|
require.ElementsMatch(t, tc.expectedConfig.Requirements, config.Requirements)
|
||||||
}
|
require.Equal(t, tc.expectedConfig.DisableRequire, config.DisableRequire)
|
||||||
if !reflect.DeepEqual(config.DriverCapabilities, tc.expectedConfig.DriverCapabilities) {
|
|
||||||
t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig)
|
|
||||||
}
|
|
||||||
if !elementsMatch(config.Requirements, tc.expectedConfig.Requirements) {
|
|
||||||
t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(config.DisableRequire, tc.expectedConfig.DisableRequire) {
|
|
||||||
t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -524,9 +514,7 @@ func TestGetDevicesFromMounts(t *testing.T) {
|
|||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.description, func(t *testing.T) {
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
devices := getDevicesFromMounts(tc.mounts)
|
devices := getDevicesFromMounts(tc.mounts)
|
||||||
if !reflect.DeepEqual(devices, tc.expectedDevices) {
|
require.Equal(t, tc.expectedDevices, devices)
|
||||||
t.Errorf("Unexpected devices (got: %v, wanted: %v)", *devices, *tc.expectedDevices)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -639,36 +627,8 @@ func TestDeviceListSourcePriority(t *testing.T) {
|
|||||||
|
|
||||||
// For all other tests, just grab the devices and check the results
|
// For all other tests, just grab the devices and check the results
|
||||||
getDevices()
|
getDevices()
|
||||||
if !reflect.DeepEqual(devices, tc.expectedDevices) {
|
|
||||||
t.Errorf("Unexpected devices (got: %v, wanted: %v)", *devices, *tc.expectedDevices)
|
require.Equal(t, tc.expectedDevices, devices)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func elementsMatch(slice0, slice1 []string) bool {
|
|
||||||
map0 := make(map[string]int)
|
|
||||||
map1 := make(map[string]int)
|
|
||||||
|
|
||||||
for _, e := range slice0 {
|
|
||||||
map0[e]++
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, e := range slice1 {
|
|
||||||
map1[e]++
|
|
||||||
}
|
|
||||||
|
|
||||||
for k0, v0 := range map0 {
|
|
||||||
if map1[k0] != v0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for k1, v1 := range map1 {
|
|
||||||
if map0[k1] != v1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
@ -3,6 +3,8 @@ package main
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseCudaVersionValid(t *testing.T) {
|
func TestParseCudaVersionValid(t *testing.T) {
|
||||||
@ -16,24 +18,15 @@ func TestParseCudaVersionValid(t *testing.T) {
|
|||||||
{"9.0.116", [3]uint32{9, 0, 116}},
|
{"9.0.116", [3]uint32{9, 0, 116}},
|
||||||
{"4294967295.4294967295.4294967295", [3]uint32{4294967295, 4294967295, 4294967295}},
|
{"4294967295.4294967295.4294967295", [3]uint32{4294967295, 4294967295, 4294967295}},
|
||||||
}
|
}
|
||||||
for _, c := range tests {
|
for i, c := range tests {
|
||||||
vmaj, vmin, vpatch := parseCudaVersion(c.version)
|
vmaj, vmin, vpatch := parseCudaVersion(c.version)
|
||||||
if vmaj != c.expected[0] || vmin != c.expected[1] || vpatch != c.expected[2] {
|
|
||||||
t.Errorf("parseCudaVersion(%s): %d.%d.%d (expected: %v)", c.version, vmaj, vmin, vpatch, c.expected)
|
version := [3]uint32{vmaj, vmin, vpatch}
|
||||||
}
|
|
||||||
|
require.Equal(t, c.expected, version, "%d: %v", i, c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustPanic(t *testing.T, f func()) {
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err == nil {
|
|
||||||
t.Error("Test didn't panic!")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
f()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseCudaVersionInvalid(t *testing.T) {
|
func TestParseCudaVersionInvalid(t *testing.T) {
|
||||||
var tests = []string{
|
var tests = []string{
|
||||||
"foo",
|
"foo",
|
||||||
@ -53,10 +46,9 @@ func TestParseCudaVersionInvalid(t *testing.T) {
|
|||||||
"-9.-1.-116",
|
"-9.-1.-116",
|
||||||
}
|
}
|
||||||
for _, c := range tests {
|
for _, c := range tests {
|
||||||
mustPanic(t, func() {
|
require.Panics(t, func() {
|
||||||
t.Logf("parseCudaVersion(%s)", c)
|
|
||||||
parseCudaVersion(c)
|
parseCudaVersion(c)
|
||||||
})
|
}, "parseCudaVersion(%v)", c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,12 +124,11 @@ func TestIsPrivileged(t *testing.T) {
|
|||||||
false,
|
false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range tests {
|
for i, tc := range tests {
|
||||||
var spec Spec
|
var spec Spec
|
||||||
_ = json.Unmarshal([]byte(tc.spec), &spec)
|
_ = json.Unmarshal([]byte(tc.spec), &spec)
|
||||||
privileged := isPrivileged(&spec)
|
privileged := isPrivileged(&spec)
|
||||||
if privileged != tc.expected {
|
|
||||||
t.Errorf("isPrivileged() returned unexpectred value (privileged: %v, tc.expected: %v)", privileged, tc.expected)
|
require.Equal(t, tc.expected, privileged, "%d: %v", i, tc)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user