Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00
commit 780b4098e4
parent 81de208430

    Fix tests
@@ -50,7 +50,7 @@ def construct(m: int, k: int, n: int) -> \
 def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
         Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
     alignment = get_m_alignment_for_contiguous_layout()
-    group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3))]
+    group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
     m = sum([ceil_div(x, alignment) * alignment for x in group_ms])

     x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
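This hunk is the substance of the fix: the old expression built a single-element list, so the constructor always produced one group no matter what `num_groups` was. A minimal standalone sketch of the corrected row-count math; `alignment = 128` here is an assumption standing in for `get_m_alignment_for_contiguous_layout()`:

    import random

    def ceil_div(a: int, b: int) -> int:
        # Round-up integer division, matching the helper used in the test.
        return (a + b - 1) // b

    num_groups, expected_m_per_group = 4, 8192
    alignment = 128  # assumed stand-in for get_m_alignment_for_contiguous_layout()

    # One randomized row count per group (the buggy version built only one).
    group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]

    # Pad each group's rows up to a multiple of `alignment`, then stack all
    # groups contiguously along the m axis.
    m = sum(ceil_div(x, alignment) * alignment for x in group_ms)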
@@ -180,7 +180,7 @@ def test_gemm() -> None:
                 deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)

             t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-            print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
+            print(f' > Perf (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
                   f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
                   f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
     print()
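The printed figures follow standard dense-GEMM accounting: an m-by-k times k-by-n multiply performs 2*m*n*k floating-point operations (one multiply and one add per accumulated term), and the traffic estimate charges 1 byte per FP8 input element plus 2 bytes per BF16 output element. A worked check of the two formulas, with an invented timing value:

    m, n, k = 4096, 4096, 7168
    t = 150e-6  # seconds; illustrative, not a measured result

    flops = 2 * m * n * k                    # one multiply + one add per m*n*k term
    bytes_moved = m * k + k * n + m * n * 2  # FP8 A + FP8 B (1 B each) + BF16 out (2 B)

    print(f'{flops / t / 1e12:4.0f} TFLOPS, {bytes_moved / 1e9 / t:4.0f} GB/s')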
@@ -189,7 +189,9 @@ def test_gemm() -> None:
 def test_m_grouped_gemm_contiguous() -> None:
     print('Testing grouped contiguous GEMM:')

-    for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
+    for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168),
+                                                   (8, 4096, 7168, 4096), (8, 4096, 2048, 7168),
+                                                   (32, 256, 7168, 4096), (32, 256, 2048, 7168)):
         # NOTES: we should mask the unfilled part before calculating difference
         m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
@@ -203,7 +205,7 @@ def test_m_grouped_gemm_contiguous() -> None:

         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
         valid_m = (m_indices != -1).sum().item()
-        print(f' > Performance ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
+        print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
               f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
               f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
     print()
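`valid_m` exists because the contiguous layout pads each group up to the alignment boundary: padding rows belong to no group and are marked with -1 in `m_indices`, so they are excluded from the FLOP and byte counts. A small illustration of that sentinel convention (the index values are invented):

    import torch

    # Two groups of 3 and 2 real rows, padded to an assumed alignment of 4:
    # entries hold the owning group id, and -1 marks padding rows.
    m_indices = torch.tensor([0, 0, 0, -1, 1, 1, -1, -1])

    valid_m = (m_indices != -1).sum().item()  # 5 real rows out of m = 8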
@@ -231,7 +233,7 @@ def test_m_grouped_gemm_masked() -> None:
         # noinspection PyUnboundLocalVariable
         valid_m = masked_m.sum().item()
         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-        print(f' > Performance ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
+        print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
               f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
               f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
     print()
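The masked variant uses the same accounting, except the per-group row counts arrive directly as a `masked_m` tensor rather than via sentinel indices, so the valid row total is just their sum. For instance (values invented):

    import torch

    masked_m = torch.tensor([100, 73, 128, 51])  # valid rows per group, illustrative
    valid_m = masked_m.sum().item()              # 352 rows enter the throughput math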