diff --git a/README.md b/README.md
index f14601c..2aa53ce 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 ## News

+- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
 - 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases).
 - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
@@ -22,9 +23,9 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 - [x] NVRTC as a faster compiler
 - [ ] Stolen JIT cache
 - [ ] Sanitizer for testing
-- [ ] Weight gradient kernels for dense models
-- [ ] Weight gradient kernels for MoE models
-- [ ] Utility kernels for MoE models (as a pre-built CUDA library)
+- [x] Weight gradient kernels for dense models
+- [x] Weight gradient kernels for MoE models
+- [ ] Utility kernels for MoE models (maybe with [tile-lang](https://github.com/tile-ai/tilelang))
 - [ ] CUDA PDL support
 - [ ] More scaling granularity support via templates
 - [ ] Larger TMA multicast size for some shapes
diff --git a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh
index ffb2926..4bf179e 100644
--- a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh
@@ -23,11 +23,11 @@ template
 __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1)
 fp8_wgrad_gemm_kernel(uint32_t shape_k,
-                     const __grid_constant__ CUtensorMap tensor_map_a,
-                     const __grid_constant__ CUtensorMap tensor_map_b,
-                     const __grid_constant__ CUtensorMap tensor_map_scales_a,
-                     const __grid_constant__ CUtensorMap tensor_map_scales_b,
-                     const __grid_constant__ CUtensorMap tensor_map_d) {
+                      const __grid_constant__ CUtensorMap tensor_map_a,
+                      const __grid_constant__ CUtensorMap tensor_map_b,
+                      const __grid_constant__ CUtensorMap tensor_map_scales_a,
+                      const __grid_constant__ CUtensorMap tensor_map_scales_b,
+                      const __grid_constant__ CUtensorMap tensor_map_d) {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
@@ -352,7 +352,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
         }
     }
 #else
-    if (blockIdx.x == 0 && threadIdx.x == 0)
+    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false && "This kernel only support sm_90a");
 #endif
 }
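# Illustrative sketch, not part of the patch: the kernel above statically asserts
# per-128-channel FP8 scaling (BLOCK_K == 128), i.e. each operand is carried as an FP8
# tensor plus one FP32 scale per 128 contiguous elements along K (fed to the kernel via
# the `scales` tensor maps). The helper below shows one way such a (data, scales) pair
# can be produced; its name, the e4m3 dtype choice and the exact scale layout are
# assumptions for demonstration, not DeepGEMM's API.
import torch

def per_128_channel_cast_to_fp8(x: torch.Tensor):
    # x: (rows, k) with k divisible by 128; returns (FP8 data, per-block FP32 scales)
    rows, k = x.shape
    assert k % 128 == 0
    blocks = x.reshape(rows, k // 128, 128)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scales = amax / 448.0                              # 448 is the finite max of float8_e4m3fn
    fp8 = (blocks / scales).to(torch.float8_e4m3fn).reshape(rows, k)
    return fp8, scales.squeeze(-1).float()             # scales shape: (rows, k // 128)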
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 16cde20..cad4fb1 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -36,19 +36,20 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:

 def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
                     is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
+    assert block_k == 128
+
     # Try swizzle first, as it does not waste shared memory
     swizzle_mode = get_swizzle_mode(block_n)
     block_n_padding = get_block_n_padding_for_smem_d(
         block_n) if swizzle_mode == 0 else 0

+    # NOTES: `scales_b` is kept in shared memory in total (dense case) or per stage (wgrad case)
     smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
     smem_a_per_stage = block_m * block_k
     smem_scales_a_per_stage = block_m * 4
     smem_b_per_stage = block_n * block_k
-    if is_wgrad:
-        smem_scales_b_per_stage = ceil_div(block_n * 4, 128) * 128
-    else:
-        smem_scales_b = ceil_div(k, block_k) * 4
+    smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
+    smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
     smem_barrier = num_stages * 8 * 2

     smem_size = 0
@@ -56,11 +57,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
     smem_size += num_stages * smem_a_per_stage
     smem_size += num_stages * smem_scales_a_per_stage
     smem_size += num_stages * smem_b_per_stage
-    if is_wgrad:
-        smem_size += num_stages * smem_scales_b_per_stage
-    else:
-        smem_size += ceil_div(smem_scales_b * (1 if block_k %
-                              block_n == 0 else 2), 8) * 8
+    smem_size += num_stages * smem_scales_b_per_stage
+    smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
     smem_size += smem_barrier

     # Swizzle and padding are not compatible
@@ -80,7 +78,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     block_ms = (get_m_alignment_for_contiguous_layout(), )
     block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))

-    # Avoid bank conflicts for fp32 output
+    # Avoid bank conflicts for FP32 output
     if is_fp32_out:
         block_ns = [x for x in block_ns if x % 16 == 8]
diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py
index cdd1714..8fb1a28 100644
--- a/deep_gemm/jit_kernels/runtime.py
+++ b/deep_gemm/jit_kernels/runtime.py
@@ -269,7 +269,7 @@ static void __instantiate_kernel() {{
         return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)


-class FP8WgradGemmRuntime(Runtime):
+class FP8WGradGemmRuntime(Runtime):
     def __init__(self, path: str) -> None:
         super().__init__(path, [
             'NUM_TMA_MULTICAST',
@@ -320,7 +320,7 @@ static void __instantiate_kernel() {{
 }};
 '''
         if int(os.getenv('DG_JIT_DEBUG', 0)):
-            print(f'Generated FP8 Wgrad GEMM code:\n{code}')
+            print(f'Generated FP8 WGrad GEMM code:\n{code}')
         return code

     # noinspection PyMethodOverriding
diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py
index 4fad99a..fbefb7b 100644
--- a/deep_gemm/jit_kernels/wgrad_gemm.py
+++ b/deep_gemm/jit_kernels/wgrad_gemm.py
@@ -1,9 +1,8 @@
-import math
 import torch
 from typing import List, Tuple

 from .runtime import (
-    FP8WgradGemmRuntime, GemmType,
+    FP8WGradGemmRuntime, GemmType,
     make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc,
     make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
 from .gemm import get_best_configs
@@ -122,7 +121,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
         space=(),
         kwargs=kwargs,
-        runtime_cls=FP8WgradGemmRuntime,
+        runtime_cls=FP8WGradGemmRuntime,
     )

     # Run the kernel
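# Illustrative sketch, not part of the patch: the shared-memory accounting that the
# unified get_smem_config() path above now performs for both the forward and wgrad
# cases. The formulas mirror the diff (FP8 operand tiles, FP32 scales, two 8-byte
# barriers per stage); the standalone function name and defaults below are ours,
# not DeepGEMM's API.
def _ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def smem_bytes(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
               block_n_padding: int = 0, is_fp32_out: bool = False, is_wgrad: bool = False) -> int:
    smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)   # output tile (FP32 or BF16)
    smem_a_per_stage = block_m * block_k                                         # FP8 A tile
    smem_scales_a_per_stage = block_m * 4                                        # FP32 A scales
    smem_b_per_stage = block_n * block_k                                         # FP8 B tile
    # B scales: buffered per stage for wgrad, held in total for the dense forward path
    smem_scales_b_per_stage = _ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
    smem_scales_b = _ceil_div(k, block_k) * 4 if not is_wgrad else 0
    smem_barrier = num_stages * 8 * 2                                            # two 8-byte barriers per stage

    size = smem_d
    size += num_stages * (smem_a_per_stage + smem_scales_a_per_stage + smem_b_per_stage)
    size += num_stages * smem_scales_b_per_stage
    size += _ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
    size += smem_barrier
    return size

# e.g. smem_bytes(num_stages=5, k=7168, block_m=64, block_n=128, is_fp32_out=True, is_wgrad=True)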
diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py
index 9068de1..55a9aff 100644
--- a/deep_gemm/utils.py
+++ b/deep_gemm/utils.py
@@ -78,7 +78,8 @@ class suppress_stdout_stderr:


 def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
-                 trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, is_multiple: bool = False):
+                 trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
+                 with_multiple_kernels: bool = False):
     # Conflict with Nsight Systems
     using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))

@@ -119,7 +120,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
     prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
     kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
     assert all([isinstance(name, str) for name in kernel_names])
-    if not is_multiple:
+    if not with_multiple_kernels:
         for name in kernel_names:
             assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'

@@ -131,28 +132,18 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
     units = {'ms': 1e3, 'us': 1e6}
     kernel_times = []
     for name in kernel_names:
-        if not is_multiple:
-            for line in prof_lines:
-                if name in line:
-                    time_str = line.split()[-2]
-                    for unit, scale in units.items():
-                        if unit in time_str:
-                            kernel_times.append(float(time_str.replace(unit, '')) / scale)
-                            break
-                    break
-        else:
-            total_time = 0
-            total_num = 0
-            for line in prof_lines:
-                if name in line:
-                    time_str = line.split()[-2]
-                    num_str = line.split()[-1]
-                    for unit, scale in units.items():
-                        if unit in time_str:
-                            total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
-                            total_num += int(num_str)
-                            break
-            kernel_times.append(total_time / total_num)
+        total_time = 0
+        total_num = 0
+        for line in prof_lines:
+            if name in line:
+                time_str = line.split()[-2]
+                num_str = line.split()[-1]
+                for unit, scale in units.items():
+                    if unit in time_str:
+                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
+                        total_num += int(num_str)
+                        break
+        kernel_times.append(total_time / total_num)

     return tuple(kernel_times) if is_tupled else kernel_times[0]
diff --git a/tests/test_core.py b/tests/test_core.py
index c45b511..36c1c34 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -291,7 +291,7 @@ def test_k_grouped_wgrad_gemm():
         def test_func():
             deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)

-        t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, is_multiple=True) * num_groups
+        t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups
         print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | '
               f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, '
               f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s')
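# Illustrative sketch, not part of the patch: driving the k-grouped FP8 wgrad GEMM the
# same way the test above does. The entry point and its argument order come from the
# test; everything else (tensor shapes, the concatenated-K layout, the dummy per-128-channel
# FP32 scales, and the FP32 output of shape (num_groups, m, n)) is an assumption made for
# demonstration only; consult tests/test_core.py for the authoritative setup.
import torch
import deep_gemm

num_groups, m, n = 4, 1024, 2048
k_sizes = [1024] * num_groups                      # per-group K, concatenated along the K axis
total_k = sum(k_sizes)

# Each operand is an (FP8 data, FP32 scales) pair, matching the dense entry point's
# `lhs: Tuple[torch.Tensor, torch.Tensor]` signature shown in wgrad_gemm.py above.
x_fp8 = (torch.randn(m, total_k, device='cuda').to(torch.float8_e4m3fn),
         torch.ones(m, total_k // 128, device='cuda', dtype=torch.float32))
y_fp8 = (torch.randn(n, total_k, device='cuda').to(torch.float8_e4m3fn),
         torch.ones(n, total_k // 128, device='cuda', dtype=torch.float32))
out = torch.zeros(num_groups, m, n, device='cuda', dtype=torch.float32)

# One weight-gradient GEMM per group, each over its own K slice
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)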