update unittest

This commit is contained in:
wangzhe_ant 2025-06-24 18:24:08 +08:00
parent ccd63bb234
commit 7db1b0ef63

View File

@ -372,16 +372,16 @@ def test_k_grouped_wgrad_gemm():
def test_m_grouped_gemm_offset() -> None: def test_m_grouped_gemm_offset() -> None:
print('Testing grouped contiguous GEMM:') print('Testing grouped offset GEMM:')
for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),(9, 64, 7168, 4096)): for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),(9, 64, 7168, 4096),(32, 32, 7168, 4096)):
# NOTES: we should mask the unfilled part before calculating difference # NOTES: we should mask the unfilled part before calculating difference
ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n) ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n)
pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1]) pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1])
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group)
diff = calc_diff(out_offset, ref_out_offset) diff = calc_diff(out_offset, ref_out_offset)
assert diff < 0.1, f'{m_offset=}, {k=}, {n=}, {diff:.5f}' assert diff < 0.001, f'{m_offset=}, {k=}, {n=}, {diff:.5f}'
# noinspection PyShadowingNames # noinspection PyShadowingNames
def test_func(): def test_func():
@ -390,7 +390,7 @@ def test_m_grouped_gemm_offset() -> None:
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
valid_m = m_offset valid_m = m_offset
print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_offset: Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
print() print()