mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
update unitest
This commit is contained in:
parent
7db1b0ef63
commit
e29e996a42
@ -374,25 +374,26 @@ def test_k_grouped_wgrad_gemm():
|
||||
def test_m_grouped_gemm_offset() -> None:
|
||||
print('Testing grouped offset GEMM:')
|
||||
|
||||
for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),(9, 64, 7168, 4096),(32, 32, 7168, 4096)):
|
||||
# NOTES: we should mask the unfilled part before calculating difference
|
||||
ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n)
|
||||
pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1])
|
||||
for num_groups, expected_m_per_group in ((2, 16), (4, 16), (2, 32), (9, 32), (2, 32), (4, 32), (32, 64)):
|
||||
for k, n in ((7168, 4096),):
|
||||
# NOTES: we should mask the unfilled part before calculating difference
|
||||
ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n)
|
||||
pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1])
|
||||
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group)
|
||||
diff = calc_diff(out_offset, ref_out_offset)
|
||||
assert diff < 0.001.1, f'{m_offset=}, {k=}, {n=}, {diff:.5f}'
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group)
|
||||
diff = calc_diff(out_offset, ref_out_offset)
|
||||
assert diff < 0.001, f'{m_offset=}, {k=}, {n=}, {diff:.5f}'
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
valid_m = m_offset
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group)
|
||||
|
||||
print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
valid_m = m_offset
|
||||
|
||||
print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user