mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Refactor tests
This commit is contained in:
parent
c4e31d121b
commit
1169f83c36
@ -49,13 +49,10 @@ def construct(m: int, k: int, n: int) -> \
|
|||||||
|
|
||||||
def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
|
def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
|
||||||
Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
|
Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
m = 0
|
alignment = get_m_alignment_for_contiguous_layout()
|
||||||
m_aligned = get_m_alignment_for_contiguous_layout()
|
group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3))]
|
||||||
group_m_list = []
|
m = sum([ceil_div(x, alignment) * alignment for x in group_ms])
|
||||||
for i in range(num_groups):
|
|
||||||
group_m = random.randint(int(expected_m_per_group * 0.7), int(expected_m_per_group * 1.3))
|
|
||||||
m += ceil_div(group_m, m_aligned) * m_aligned
|
|
||||||
group_m_list.append(group_m)
|
|
||||||
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||||
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
||||||
m_indices = torch.empty(m, device='cuda', dtype=torch.int32)
|
m_indices = torch.empty(m, device='cuda', dtype=torch.int32)
|
||||||
@ -63,13 +60,14 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
|
|||||||
ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
|
ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
|
||||||
|
|
||||||
start = 0
|
start = 0
|
||||||
for i, group_m in enumerate(group_m_list):
|
for i, group_m in enumerate(group_ms):
|
||||||
actual_end = start + group_m
|
actual_end = start + group_m
|
||||||
aligned_end = start + ceil_div(group_m, m_aligned) * m_aligned
|
aligned_end = start + ceil_div(group_m, alignment) * alignment
|
||||||
m_indices[start:actual_end] = i
|
m_indices[start:actual_end] = i
|
||||||
m_indices[actual_end:aligned_end] = -1
|
m_indices[actual_end:aligned_end] = -1
|
||||||
ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t()
|
ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t()
|
||||||
start = aligned_end
|
start = aligned_end
|
||||||
|
ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out)
|
||||||
|
|
||||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
assert m % 4 == 0, f'TMA alignment error: {m}'
|
||||||
x_fp8 = per_token_cast_to_fp8(x)
|
x_fp8 = per_token_cast_to_fp8(x)
|
||||||
@ -80,15 +78,15 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
|
|||||||
return m, x_fp8, y_fp8, m_indices, out, ref_out
|
return m, x_fp8, y_fp8, m_indices, out, ref_out
|
||||||
|
|
||||||
|
|
||||||
def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
|
def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: int, k: int, n: int) -> \
|
||||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
|
x = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
|
||||||
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
||||||
out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
|
out = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
|
||||||
ref_out = torch.einsum('gmk,gnk->gmn', x, y)
|
ref_out = torch.einsum('gmk,gnk->gmn', x, y)
|
||||||
|
|
||||||
assert m % 4 == 0, f'TMA alignment error: {m}'
|
assert max_m % 4 == 0, f'TMA alignment error: {max_m}'
|
||||||
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, k // 128), device='cuda', dtype=torch.float))
|
||||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
|
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
|
||||||
for i in range(num_groups):
|
for i in range(num_groups):
|
||||||
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
||||||
@ -96,7 +94,12 @@ def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \
|
|||||||
|
|
||||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||||
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||||
return x_fp8, y_fp8, out, ref_out
|
|
||||||
|
# Construct mask
|
||||||
|
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||||
|
for j in range(num_groups):
|
||||||
|
masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
|
||||||
|
return x_fp8, y_fp8, masked_m, out, ref_out
|
||||||
|
|
||||||
|
|
||||||
def construct_wgrad(m: int, k: int, n: int) -> \
|
def construct_wgrad(m: int, k: int, n: int) -> \
|
||||||
@ -172,9 +175,6 @@ def test_gemm() -> None:
|
|||||||
diff = calc_diff(out, ref_out)
|
diff = calc_diff(out, ref_out)
|
||||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
||||||
|
|
||||||
# Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
|
|
||||||
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
|
||||||
|
|
||||||
# noinspection PyShadowingNames
|
# noinspection PyShadowingNames
|
||||||
def test_func():
|
def test_func():
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
||||||
@ -190,63 +190,50 @@ def test_m_grouped_gemm_contiguous() -> None:
|
|||||||
print('Testing grouped contiguous GEMM:')
|
print('Testing grouped contiguous GEMM:')
|
||||||
|
|
||||||
for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
|
for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
|
||||||
# TODO: make a stronger test
|
# NOTES: we should mask the unfilled part before calculating difference
|
||||||
m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
|
m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
|
||||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||||
out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out)
|
out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out)
|
||||||
ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out)
|
|
||||||
diff = calc_diff(out, ref_out)
|
diff = calc_diff(out, ref_out)
|
||||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
||||||
|
|
||||||
# Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
|
|
||||||
m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
|
|
||||||
|
|
||||||
# noinspection PyShadowingNames
|
# noinspection PyShadowingNames
|
||||||
def test_func():
|
def test_func():
|
||||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||||
|
|
||||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||||
sum_m = (m_indices != -1).sum().item()
|
valid_m = (m_indices != -1).sum().item()
|
||||||
print(f' > Performance ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
print(f' > Performance ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||||
f'throughput: {2 * sum_m * n * k / t / 1e12:4.0f} TFLOPS, '
|
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||||
f'{(sum_m * k + num_groups * k * n + sum_m * n * 2) / 1e9 / t:4.0f} GB/s')
|
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
def test_m_grouped_gemm_masked() -> None:
|
def test_m_grouped_gemm_masked() -> None:
|
||||||
print('Testing grouped masked GEMM:')
|
print('Testing grouped masked GEMM:')
|
||||||
|
|
||||||
m = 4096
|
max_m = 4096
|
||||||
for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)):
|
for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)):
|
||||||
for k, n in ((7168, 4096), (2048, 7168), ):
|
for k, n in ((7168, 4096), (2048, 7168), ):
|
||||||
# Test correctness
|
# Test correctness
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
|
x_fp8, y_fp8, masked_m, out, ref_out = construct_masked_grouped(num_groups, max_m, expected_m_per_group, k, n)
|
||||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group)
|
||||||
for j in range(num_groups):
|
|
||||||
masked_m[j] = random.randint(int(expected_m_per_group * 0.7), int(expected_m_per_group * 1.3))
|
|
||||||
expected_m = min(int(masked_m.float().mean()) + 1, m)
|
|
||||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
|
||||||
for j in range(num_groups):
|
for j in range(num_groups):
|
||||||
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
|
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
|
||||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
|
assert diff < 0.001, f'{max_m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
|
||||||
|
|
||||||
# Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
|
|
||||||
x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
|
|
||||||
for j in range(num_groups):
|
|
||||||
masked_m[j] = random.randint(int(expected_m_per_group * 0.7), int(expected_m_per_group * 1.3))
|
|
||||||
expected_m = min(int(masked_m.float().mean()) + 1, m)
|
|
||||||
sum_m = masked_m.sum().item()
|
|
||||||
|
|
||||||
# noinspection PyShadowingNames
|
# noinspection PyShadowingNames
|
||||||
def test_func():
|
def test_func():
|
||||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group)
|
||||||
|
|
||||||
# Test performance with fixed shapes
|
# Test performance with fixed shapes
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
valid_m = masked_m.sum().item()
|
||||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||||
print(f' > Performance ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
print(f' > Performance ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
||||||
f'throughput: {2 * sum_m * n * k / t / 1e12:4.0f} TFLOPS, '
|
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||||
f'{(sum_m * k + num_groups * k * n + sum_m * n * 2) / 1e9 / t:4.0f} GB/s')
|
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user