mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
fix some bug
This commit is contained in:
@@ -104,8 +104,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
|
||||
|
||||
#block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
# Avoid bank conflicts for FP32 output
|
||||
if is_fp32_out:
|
||||
block_ns = [x for x in block_ns if x % 16 == 8]
|
||||
|
||||
@@ -260,7 +260,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
else:
|
||||
m_per_expert_threshold = 32 # H100
|
||||
|
||||
if expected_m>= m_per_expert_threshold:
|
||||
if expected_m> m_per_expert_threshold:
|
||||
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
expected_m, n, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=False)
|
||||
|
||||
Reference in New Issue
Block a user