diff --git a/tests/test_core.py b/tests/test_core.py index 41c9069..bdc1841 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -40,7 +40,7 @@ def construct(m: int, k: int, n: int) -> \ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ - Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: m = 0 m_aligned = get_m_alignment_for_contiguous_layout() group_m_list = []