mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Several cleanups
This commit is contained in:
@@ -36,19 +36,20 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
|
||||
|
||||
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
|
||||
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
|
||||
assert block_k == 128
|
||||
|
||||
# Try swizzle first, as it does not waste shared memory
|
||||
swizzle_mode = get_swizzle_mode(block_n)
|
||||
block_n_padding = get_block_n_padding_for_smem_d(
|
||||
block_n) if swizzle_mode == 0 else 0
|
||||
|
||||
# NOTES: `scales_b` in a total manner or per-stage manner
|
||||
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
if is_wgrad:
|
||||
smem_scales_b_per_stage = ceil_div(block_n * 4, 128) * 128
|
||||
else:
|
||||
smem_scales_b = ceil_div(k, block_k) * 4
|
||||
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
|
||||
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
|
||||
smem_barrier = num_stages * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
@@ -56,11 +57,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
if is_wgrad:
|
||||
smem_size += num_stages * smem_scales_b_per_stage
|
||||
else:
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k %
|
||||
block_n == 0 else 2), 8) * 8
|
||||
smem_size += num_stages * smem_scales_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
smem_size += smem_barrier
|
||||
|
||||
# Swizzle and padding are not compatible
|
||||
@@ -80,7 +78,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
|
||||
|
||||
# Avoid bank conflicts for fp32 output
|
||||
# Avoid bank conflicts for FP32 output
|
||||
if is_fp32_out:
|
||||
block_ns = [x for x in block_ns if x % 16 == 8]
|
||||
|
||||
|
||||
@@ -269,7 +269,7 @@ static void __instantiate_kernel() {{
|
||||
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
|
||||
|
||||
|
||||
class FP8WgradGemmRuntime(Runtime):
|
||||
class FP8WGradGemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
@@ -320,7 +320,7 @@ static void __instantiate_kernel() {{
|
||||
}};
|
||||
'''
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated FP8 Wgrad GEMM code:\n{code}')
|
||||
print(f'Generated FP8 WGrad GEMM code:\n{code}')
|
||||
return code
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import math
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
|
||||
from .runtime import (
|
||||
FP8WgradGemmRuntime, GemmType,
|
||||
FP8WGradGemmRuntime, GemmType,
|
||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||
make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
|
||||
from .gemm import get_best_configs
|
||||
@@ -122,7 +121,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
runtime_cls=FP8WgradGemmRuntime,
|
||||
runtime_cls=FP8WGradGemmRuntime,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
|
||||
Reference in New Issue
Block a user