diff --git a/pkg/container_test.go b/pkg/container_test.go index 4fec7dc3..b3c258f7 100644 --- a/pkg/container_test.go +++ b/pkg/container_test.go @@ -2,8 +2,9 @@ package main import ( "path/filepath" - "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestGetNvidiaConfig(t *testing.T) { @@ -414,7 +415,7 @@ func TestGetNvidiaConfig(t *testing.T) { // For any tests that are expected to panic, make sure they do. if tc.expectedPanic { - mustPanic(t, getConfig) + require.Panics(t, getConfig) return } @@ -422,31 +423,20 @@ func TestGetNvidiaConfig(t *testing.T) { getConfig() // And start comparing the test results to the expected results. - if config == nil && tc.expectedConfig == nil { + if tc.expectedConfig == nil { + require.Nil(t, config, tc.description) return } - if config != nil && tc.expectedConfig != nil { - if !reflect.DeepEqual(config.Devices, tc.expectedConfig.Devices) { - t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig) - } - if !reflect.DeepEqual(config.MigConfigDevices, tc.expectedConfig.MigConfigDevices) { - t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig) - } - if !reflect.DeepEqual(config.MigMonitorDevices, tc.expectedConfig.MigMonitorDevices) { - t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig) - } - if !reflect.DeepEqual(config.DriverCapabilities, tc.expectedConfig.DriverCapabilities) { - t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig) - } - if !elementsMatch(config.Requirements, tc.expectedConfig.Requirements) { - t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig) - } - if !reflect.DeepEqual(config.DisableRequire, tc.expectedConfig.DisableRequire) { - t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig) - } - return - } - t.Errorf("Unexpected nvidiaConfig (got: %v, wanted: %v)", config, tc.expectedConfig) + + require.NotNil(t, config, tc.description) + + require.Equal(t, tc.expectedConfig.Devices, config.Devices) + require.Equal(t, tc.expectedConfig.MigConfigDevices, config.MigConfigDevices) + require.Equal(t, tc.expectedConfig.MigMonitorDevices, config.MigMonitorDevices) + require.Equal(t, tc.expectedConfig.DriverCapabilities, config.DriverCapabilities) + + require.ElementsMatch(t, tc.expectedConfig.Requirements, config.Requirements) + require.Equal(t, tc.expectedConfig.DisableRequire, config.DisableRequire) }) } } @@ -524,9 +514,7 @@ func TestGetDevicesFromMounts(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { devices := getDevicesFromMounts(tc.mounts) - if !reflect.DeepEqual(devices, tc.expectedDevices) { - t.Errorf("Unexpected devices (got: %v, wanted: %v)", *devices, *tc.expectedDevices) - } + require.Equal(t, tc.expectedDevices, devices) }) } } @@ -639,36 +627,8 @@ func TestDeviceListSourcePriority(t *testing.T) { // For all other tests, just grab the devices and check the results getDevices() - if !reflect.DeepEqual(devices, tc.expectedDevices) { - t.Errorf("Unexpected devices (got: %v, wanted: %v)", *devices, *tc.expectedDevices) - } + + require.Equal(t, tc.expectedDevices, devices) }) } } - -func elementsMatch(slice0, slice1 []string) bool { - map0 := make(map[string]int) - map1 := make(map[string]int) - - for _, e := range slice0 { - map0[e]++ - } - - for _, e := range slice1 { - map1[e]++ - } - - for k0, v0 := range map0 { - if map1[k0] != v0 { - return false - } - } - - for k1, v1 := range map1 { - if map0[k1] != v1 { - return false - } - } - - return true -} diff --git a/pkg/hook_test.go b/pkg/hook_test.go index 1788cdae..07d65bb4 100644 --- a/pkg/hook_test.go +++ b/pkg/hook_test.go @@ -3,6 +3,8 @@ package main import ( "encoding/json" "testing" + + "github.com/stretchr/testify/require" ) func TestParseCudaVersionValid(t *testing.T) { @@ -16,24 +18,15 @@ func TestParseCudaVersionValid(t *testing.T) { {"9.0.116", [3]uint32{9, 0, 116}}, {"4294967295.4294967295.4294967295", [3]uint32{4294967295, 4294967295, 4294967295}}, } - for _, c := range tests { + for i, c := range tests { vmaj, vmin, vpatch := parseCudaVersion(c.version) - if vmaj != c.expected[0] || vmin != c.expected[1] || vpatch != c.expected[2] { - t.Errorf("parseCudaVersion(%s): %d.%d.%d (expected: %v)", c.version, vmaj, vmin, vpatch, c.expected) - } + + version := [3]uint32{vmaj, vmin, vpatch} + + require.Equal(t, c.expected, version, "%d: %v", i, c) } } -func mustPanic(t *testing.T, f func()) { - defer func() { - if err := recover(); err == nil { - t.Error("Test didn't panic!") - } - }() - - f() -} - func TestParseCudaVersionInvalid(t *testing.T) { var tests = []string{ "foo", @@ -53,10 +46,9 @@ func TestParseCudaVersionInvalid(t *testing.T) { "-9.-1.-116", } for _, c := range tests { - mustPanic(t, func() { - t.Logf("parseCudaVersion(%s)", c) + require.Panics(t, func() { parseCudaVersion(c) - }) + }, "parseCudaVersion(%v)", c) } } @@ -132,12 +124,11 @@ func TestIsPrivileged(t *testing.T) { false, }, } - for _, tc := range tests { + for i, tc := range tests { var spec Spec _ = json.Unmarshal([]byte(tc.spec), &spec) privileged := isPrivileged(&spec) - if privileged != tc.expected { - t.Errorf("isPrivileged() returned unexpectred value (privileged: %v, tc.expected: %v)", privileged, tc.expected) - } + + require.Equal(t, tc.expected, privileged, "%d: %v", i, tc) } }