This commit is contained in:
kavioyu
2025-03-13 07:04:56 +00:00
parent 6e53c6613d
commit 094d0421ec
3 changed files with 4 additions and 7 deletions

View File

@@ -109,7 +109,6 @@ def test_gemm_backward_w() -> None:
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
torch.cuda.synchronize()
print(diff)
# noinspection PyShadowingNames
def test_func():
@@ -195,5 +194,5 @@ if __name__ == '__main__':
test_gemm_backward_w()
test_gemm()
# test_m_grouped_gemm_contiguous()
# test_m_grouped_gemm_masked()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()