Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-04-03 17:40:42 +00:00
Correctly flush L2: reconstructing the tensors on every iteration effectively placed them in L2, and it gave the GPU enough idle time to avoid thermal throttling in a potentially unrealistic way.
The previous behaviour may be representative of some use cases (e.g. a previous kernel filling L2 with the data in a very specific way), but it is not standard benchmarking practice.
parent e1c070fbef
commit 6cbff5778f
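The core of the change: instead of re-allocating the input tensors inside the timed closure (which leaves them resident in L2 and inserts idle gaps between kernels), the benchmark now constructs tensors once and evicts L2 explicitly with a large memset before each timed call. A minimal standalone sketch of the technique, assuming any CUDA-capable PyTorch build (the helper names are illustrative, the 256 MB buffer only needs to exceed the GPU's L2 capacity, and the real harness times kernels via a Kineto trace rather than CUDA events):

import torch

def flush_l2_cache(num_bytes: int = int(256e6)):
    # Zero-filling a buffer much larger than L2 evicts whatever the previous
    # iteration left in the cache, so the next kernel starts cold.
    torch.empty(num_bytes // 4, dtype=torch.int, device='cuda').zero_()

def bench_cold(fn, num_tests: int = 30) -> float:
    # Average time per call (ms) with a cold L2 before every iteration.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total_ms = 0.0
    for _ in range(num_tests):
        flush_l2_cache()
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        total_ms += start.elapsed_time(end)
    return total_ms / num_tests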
@@ -1,5 +1,6 @@
 import os
 import sys
+import time
 import torch
 import torch.distributed as dist
 
@@ -77,10 +78,28 @@ class suppress_stdout_stderr:
 
 
 def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
-                 trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = False):
+                 trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True):
     # Conflict with Nsight Systems
     using_nsys = os.environ.get('DG_NSYS_PROFILING', False)
 
+    # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle;
+    # this avoids thermal throttling while keeping DVFS at max clocks (slight gain vs sleep / more consistent on GH200)
+    sleep_between_tests = 0.0
+    flush_l2_size = int(8e9 // 4)
+    if os.environ.get('DG_BENCH_DISABLE_L2_FLUSH', False):
+        flush_l2 = False
+    if os.environ.get('DG_BENCH_POWER_LIMITED', False):
+        # if we want to be thermally limited, we need to run many iterations non-stop for a fairly long time
+        # and spend as little time as possible doing memset and other setup work (80MiB should be enough to flush L2)
+        num_tests = 2000
+        flush_l2_size = int(80e6 // 4)
+    sleep_val = os.environ.get('DG_BENCH_SLEEP_BETWEEN_TESTS', False)
+    if sleep_val:
+        try:
+            sleep_between_tests = float(sleep_val)
+        except ValueError:
+            pass  # Keep default
+
     # For some auto-tuning kernels with prints
     fn()
 
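For scale: the default flush writes int(8e9 // 4) = 2,000,000,000 int32 elements, i.e. roughly 8 GB per iteration, while the power-limited mode shrinks this to int(80e6 // 4) elements (about 80 MB, still comfortably above the roughly 50 MB of L2 on Hopper-class parts).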
@@ -98,8 +117,10 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
                     lhs @ rhs
                     dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
                 for _ in range(num_tests):
+                    if sleep_between_tests > 0.0:
+                        time.sleep(sleep_between_tests)
                     if flush_l2:
-                        torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
+                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                     fn()
 
                 if not using_nsys:
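The three knobs above are read with os.environ.get at call time, so they can be toggled per run without touching code. A hedged sketch of in-process use (the values mirror the behaviour described in the diff):

import os

os.environ['DG_BENCH_POWER_LIMITED'] = '1'             # 2000 back-to-back tests, ~80 MB flush
# os.environ['DG_BENCH_SLEEP_BETWEEN_TESTS'] = '0.25'  # pause 0.25 s before each test
# os.environ['DG_BENCH_DISABLE_L2_FLUSH'] = '1'        # skip the memset flush entirely

Equivalently from the shell: DG_BENCH_POWER_LIMITED=1 python <benchmark script>.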
@@ -71,10 +71,11 @@ def test_gemm() -> None:
         diff = calc_diff(out, ref_out)
         assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
 
+        # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
+        x_fp8, y_fp8, out, ref_out = construct(m, k, n)
+
         # noinspection PyShadowingNames
         def test_func():
-            # Construct new tensors every time to avoid L2 cache acceleration
-            x_fp8, y_fp8, out, ref_out = construct(m, k, n)
             deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
 
         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
@@ -96,12 +97,13 @@ def test_m_grouped_gemm_contiguous() -> None:
         diff = calc_diff(out, ref_out)
         assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}'
 
+        # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
+        x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
+        m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
+        m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
+
         # noinspection PyShadowingNames
         def test_func():
-            # Construct new tensors every time to avoid L2 cache acceleration
-            x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
-            m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
-            m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
 
         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
@@ -129,11 +131,12 @@ def test_m_grouped_gemm_masked() -> None:
             diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
             assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
 
+        # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
+        x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
+        masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
+
         # noinspection PyShadowingNames
         def test_func():
-            # Construct new tensors every time to avoid L2 cache acceleration
-            x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
-            masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)
 
         # Test performance with fixed shapes
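All three tests follow the same hoisting pattern: build inputs once, outside the timed closure, and let bench_kineto own the cache state. A minimal sketch of calling the patched harness directly (the import path and the kernel-name substring are assumptions, not taken from this diff):

import torch
from deep_gemm.utils import bench_kineto  # import path assumed

# Construct ONCE, outside the timed closure, so allocation neither warms L2
# nor stalls the GPU between timed iterations.
x = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
y = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
out = torch.empty(4096, 4096, device='cuda', dtype=torch.float16)

def test_func():
    # The timed body touches pre-built tensors only.
    torch.matmul(x, y, out=out)

# flush_l2 now defaults to True, so each iteration sees a cold cache.
t = bench_kineto(test_func, 'gemm', suppress_kineto_output=True)
print(f'avg kernel time: {t * 1e6:.1f} us')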