Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-05-31 18:48:16 +00:00
Add extra TMA checks
parent ca13ce0fab
commit 6da94d2d36
@@ -160,6 +160,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
     global includes, template
     num_sms = get_num_sms()
     block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+
+    # Extra checks for TMA store
+    if num_groups > 1 and m > block_m:
+        assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
+
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             masked_m, m,
             torch.cuda.current_stream(), num_sms, smem_size)
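For reference, below is a minimal standalone sketch of the shape check this commit introduces. The helper name and the example values are illustrative only, not part of DeepGEMM's API: when a masked grouped GEMM runs with more than one group and M spans more than one block, M must be an exact multiple of the block M selected by get_best_configs, presumably so that group boundaries stay aligned with the block-sized output tiles written by the TMA store.

# Illustrative sketch of the assertion added above (hypothetical helper,
# not DeepGEMM code); it mirrors the same condition and error message.
def check_masked_grouped_m(m: int, block_m: int, num_groups: int) -> None:
    # With several groups and an M larger than one block, M must be an
    # exact multiple of the block M so every group starts on a block boundary.
    if num_groups > 1 and m > block_m:
        assert m % block_m == 0, (
            f'For masked grouped GEMM, shape M should be multiple of '
            f'the block M (current block M: {block_m})'
        )

# Example with hypothetical values (block_m as chosen by the config heuristic):
check_masked_grouped_m(m=256, block_m=64, num_groups=4)    # passes: 256 % 64 == 0
# check_masked_grouped_m(m=200, block_m=64, num_groups=4)  # would raise AssertionError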