This commit is contained in:
kavioyu
2025-03-12 13:48:02 +00:00
parent 9d3222a93e
commit 6e53c6613d
6 changed files with 678 additions and 6 deletions

View File

@@ -1,9 +1,15 @@
import random
import torch
from typing import Tuple
import torch
import deep_gemm
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
from deep_gemm import (
bench_kineto,
calc_diff,
ceil_div,
get_col_major_tma_aligned_tensor,
)
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -33,8 +39,18 @@ def construct(m: int, k: int, n: int) -> \
ref_out = x @ y.t()
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
# Transpose earlier so that the testing will not trigger transposing kernels
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
return x_fp8, y_fp8, out, ref_out
def construct_backward_w(m: int, k: int, n: int) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
ref_out = x @ y.t()
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_token_cast_to_fp8(y)
#x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
#y_fp8 = (y_fp8[0], get_col_major_tma_aligned_tensor(y_fp8[1]))
return x_fp8, y_fp8, out, ref_out
@@ -84,6 +100,30 @@ def test_gemm() -> None:
print()
def test_gemm_backward_w() -> None:
print('Testing GEMM Backward W:')
for m in (64, 128, 4096):
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
x_fp8, y_fp8, out, ref_out = construct_backward_w(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_bw_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
torch.cuda.synchronize()
print(diff)
# noinspection PyShadowingNames
def test_func():
# Construct new tensors every time to avoid L2 cache acceleration
x_fp8, y_fp8, out, ref_out = construct_backward_w(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_bw_nt(x_fp8, y_fp8, out)
t = bench_kineto(test_func, 'fp8_gemm_bw', suppress_kineto_output=True)
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_contiguous() -> None:
print('Testing grouped contiguous GEMM:')
@@ -153,6 +193,7 @@ if __name__ == '__main__':
print('Library path:')
print(f' > {deep_gemm.__path__}\n')
test_gemm_backward_w()
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
# test_m_grouped_gemm_contiguous()
# test_m_grouped_gemm_masked()