Several cleanups

Chenggang Zhao 2025-05-14 14:18:43 +08:00
parent 6233709c67
commit c4a7116e0a
7 changed files with 38 additions and 49 deletions

View File

@@ -8,6 +8,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
## News
- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may cause a performance loss in some cases); a usage sketch follows this list.
- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
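Below is a minimal usage sketch for the `DG_JIT_USE_NVRTC` switch mentioned in the news entry above; the assumption that the flag is read at JIT-compilation time (so it can be set from Python before the first DeepGEMM call) is mine, not stated here.

```python
import os

# Hedged sketch: enable the NVRTC-based JIT compiler via the environment
# variable from the news entry above. Assumption: the flag is consulted when a
# kernel is JIT-compiled, so setting it before the first DeepGEMM call suffices.
# Equivalent shell form: DG_JIT_USE_NVRTC=1 python your_script.py
os.environ['DG_JIT_USE_NVRTC'] = '1'

import deep_gemm  # noqa: E402  (imported after setting the flag on purpose)
```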
@@ -22,9 +23,9 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- [x] NVRTC as a faster compiler
- [ ] Stolen JIT cache
- [ ] Sanitizer for testing
- [ ] Weight gradient kernels for dense models
- [ ] Weight gradient kernels for MoE models
- [ ] Utility kernels for MoE models (as a pre-built CUDA library)
- [x] Weight gradient kernels for dense models
- [x] Weight gradient kernels for MoE models
- [ ] Utility kernels for MoE models (maybe with [tile-lang](https://github.com/tile-ai/tilelang))
- [ ] CUDA PDL support
- [ ] More scaling granularity support via templates
- [ ] Larger TMA multicast size for some shapes

View File

@@ -352,7 +352,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
}
}
#else
if (blockIdx.x == 0 && threadIdx.x == 0)
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false && "This kernel only supports sm_90a");
#endif
}

View File

@@ -36,19 +36,20 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
assert block_k == 128
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
# NOTES: `scales_b` is stored either in total (dense case) or per-stage (wgrad case)
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = block_m * 4
smem_b_per_stage = block_n * block_k
if is_wgrad:
smem_scales_b_per_stage = ceil_div(block_n * 4, 128) * 128
else:
smem_scales_b = ceil_div(k, block_k) * 4
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
smem_barrier = num_stages * 8 * 2
smem_size = 0
@@ -56,11 +57,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
if is_wgrad:
smem_size += num_stages * smem_scales_b_per_stage
else:
smem_size += ceil_div(smem_scales_b * (1 if block_k %
block_n == 0 else 2), 8) * 8
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
# Swizzle and padding are not compatible
@@ -80,7 +78,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
# Avoid bank conflicts for fp32 output
# Avoid bank conflicts for FP32 output
if is_fp32_out:
block_ns = [x for x in block_ns if x % 16 == 8]
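The `scales_b` branch removal in the hunk above works because each unused term is forced to zero, so both shared-memory additions can run unconditionally. A self-contained sketch of that accounting, with a local `ceil_div` stand-in and made-up tile shapes (not taken from the repository's tuner):

```python
def ceil_div(a: int, b: int) -> int:
    # Local stand-in for DeepGEMM's helper (assumed semantics: round-up division).
    return (a + b - 1) // b

def scales_b_smem_bytes(k: int, block_n: int, block_k: int, num_stages: int, is_wgrad: bool) -> int:
    # Mirrors the refactor above: the branch not taken contributes a zero term,
    # so the two `smem_size` additions no longer need an if/else.
    smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
    smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
    smem_size = num_stages * smem_scales_b_per_stage
    smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
    return smem_size

# Made-up configuration, just to show both paths.
print(scales_b_smem_bytes(k=4096, block_n=128, block_k=128, num_stages=5, is_wgrad=True))   # 2560
print(scales_b_smem_bytes(k=4096, block_n=128, block_k=128, num_stages=5, is_wgrad=False))  # 128
```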

View File

@@ -269,7 +269,7 @@ static void __instantiate_kernel() {{
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
class FP8WgradGemmRuntime(Runtime):
class FP8WGradGemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, [
'NUM_TMA_MULTICAST',
@@ -320,7 +320,7 @@ static void __instantiate_kernel() {{
}};
'''
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Generated FP8 Wgrad GEMM code:\n{code}')
print(f'Generated FP8 WGrad GEMM code:\n{code}')
return code
# noinspection PyMethodOverriding

View File

@@ -1,9 +1,8 @@
import math
import torch
from typing import List, Tuple
from .runtime import (
FP8WgradGemmRuntime, GemmType,
FP8WGradGemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
from .gemm import get_best_configs
@@ -122,7 +121,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs,
runtime_cls=FP8WgradGemmRuntime,
runtime_cls=FP8WGradGemmRuntime,
)
# Run the kernel
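For context on the entry point whose runtime class is renamed above, here is a hypothetical call sketch; the FP8 dtype, the per-128-channel FP32 scale layout, and the tensor shapes are assumptions for illustration, not the documented contract (see the repository's tests for the real one).

```python
import torch
import deep_gemm

def fp8_with_scales(rows: int, cols: int):
    # Assumption: each operand is an (FP8 data, FP32 scales) pair, as suggested by
    # the `Tuple[torch.Tensor, torch.Tensor]` annotation in the hunk above; the
    # per-128-column scale granularity is a guess for illustration only.
    data = torch.randn(rows, cols, device='cuda').to(torch.float8_e4m3fn)
    scales = torch.ones(rows, (cols + 127) // 128, device='cuda', dtype=torch.float32)
    return data, scales

m, n, k = 4096, 7168, 4096          # hypothetical problem sizes
lhs, rhs = fp8_with_scales(m, k), fp8_with_scales(n, k)
out = torch.zeros(m, n, device='cuda', dtype=torch.float32)  # FP32 accumulator output
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(lhs, rhs, out)
```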

View File

@@ -78,7 +78,8 @@ class suppress_stdout_stderr:
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, is_multiple: bool = False):
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
with_multiple_kernels: bool = False):
# Conflict with Nsight Systems
using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))
@@ -119,7 +120,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
if not is_multiple:
if not with_multiple_kernels:
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Kernel {name} must appear exactly once in the profiling table'
@@ -131,16 +132,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
units = {'ms': 1e3, 'us': 1e6}
kernel_times = []
for name in kernel_names:
if not is_multiple:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
break
break
else:
total_time = 0
total_num = 0
for line in prof_lines:
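The deleted single-match branch above leaves only the accumulation path that begins at `total_time = 0`; its body is cut off by this hunk, so the following is just a sketch of that kind of aggregation (sum every matching Kineto table row, then average), with fabricated table lines and the `units` mapping from the code above.

```python
units = {'ms': 1e3, 'us': 1e6}

def average_kernel_time(prof_lines, name):
    # Hedged sketch only: the real loop body is not visible in this hunk and may
    # weight or normalize differently (e.g. by call counts or num_tests).
    total_time, total_num = 0.0, 0
    for line in prof_lines:
        if name in line:
            time_str = line.split()[-2]          # second-to-last column holds a duration
            for unit, scale in units.items():
                if unit in time_str:
                    total_time += float(time_str.replace(unit, '')) / scale
                    total_num += 1
                    break
    return total_time / max(total_num, 1)

# Fabricated profiler rows, shaped loosely like Kineto's key_averages() table.
sample = ['fp8_wgrad_gemm ... 96.0% 480.0us 120.0us 4',
          'fp8_wgrad_gemm ... 4.0% 20.0us 20.0us 1']
print(average_kernel_time(sample, 'fp8_wgrad_gemm'))  # 7e-05 (seconds)
```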

View File

@@ -291,7 +291,7 @@ def test_k_grouped_wgrad_gemm():
def test_func():
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, is_multiple=True) * num_groups
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups
print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, '
f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s')
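To make the printed metrics concrete, here is the same arithmetic evaluated standalone with made-up sizes and a made-up kernel time; the byte count simply mirrors the expression in the print statement above.

```python
# Hypothetical problem sizes and timing; none of these come from a real run.
num_groups, m, n, total_k = 4, 4096, 7168, 16384
t = 800e-6  # seconds, a made-up averaged kernel time as bench_kineto would return

# TFLOPS = FLOPs / time / 1e12; GB/s = bytes / 1e9 / time, using the test's own byte count.
tflops = 2 * num_groups * m * n * (total_k / num_groups) / t / 1e12
gbps = (m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t
print(f'{tflops:4.0f} TFLOPS, {gbps:4.0f} GB/s')  # roughly 1203 TFLOPS, 524 GB/s
```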