mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
tested
This commit is contained in:
@@ -1,9 +1,15 @@
|
||||
import random
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
||||
from deep_gemm import (
|
||||
bench_kineto,
|
||||
calc_diff,
|
||||
ceil_div,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
)
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -33,8 +39,18 @@ def construct(m: int, k: int, n: int) -> \
|
||||
ref_out = x @ y.t()
|
||||
|
||||
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
def construct_backward_w(m: int, k: int, n: int) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
ref_out = x @ y.t()
|
||||
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_token_cast_to_fp8(y)
|
||||
#x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
||||
#y_fp8 = (y_fp8[0], get_col_major_tma_aligned_tensor(y_fp8[1]))
|
||||
return x_fp8, y_fp8, out, ref_out
|
||||
|
||||
|
||||
@@ -84,6 +100,30 @@ def test_gemm() -> None:
|
||||
print()
|
||||
|
||||
|
||||
def test_gemm_backward_w() -> None:
|
||||
print('Testing GEMM Backward W:')
|
||||
for m in (64, 128, 4096):
|
||||
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
|
||||
x_fp8, y_fp8, out, ref_out = construct_backward_w(m, k, n)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_bw_nt(x_fp8, y_fp8, out)
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
||||
torch.cuda.synchronize()
|
||||
print(diff)
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
# Construct new tensors every time to avoid L2 cache acceleration
|
||||
x_fp8, y_fp8, out, ref_out = construct_backward_w(m, k, n)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_bw_nt(x_fp8, y_fp8, out)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm_bw', suppress_kineto_output=True)
|
||||
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
|
||||
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
||||
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing grouped contiguous GEMM:')
|
||||
|
||||
@@ -153,6 +193,7 @@ if __name__ == '__main__':
|
||||
print('Library path:')
|
||||
print(f' > {deep_gemm.__path__}\n')
|
||||
|
||||
test_gemm_backward_w()
|
||||
test_gemm()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
# test_m_grouped_gemm_contiguous()
|
||||
# test_m_grouped_gemm_masked()
|
||||
|
||||
Reference in New Issue
Block a user