fix some bugs

Wangzheee 2025-06-20 06:53:24 +00:00
parent d29b20cd16
commit 26a603f518
5 changed files with 6 additions and 29 deletions


@@ -851,18 +851,6 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N);
}
}
else if constexpr (SchedulerType::gemm_type == GemmType::StridedBatched)
{
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
__nv_bfloat16* gmem_d_this_block;
auto m_global_idx = scheduler.get_global_m_idx(m_block_idx);
gmem_d_this_block = gmem_d + scheduler.curr_group_idx * problem_input.stride_d
+ (m_block_idx * BLOCK_M) * problem_input.ld_d;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_M, BLOCK_N, NUM_WARPS>(gmem_d_this_block, smem_d, m_global_idx,
scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, problem_input.ld_d);
}
else
{
cute::tma_store_fence();
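For context, the removed StridedBatched epilogue derived its per-block output pointer from the batch (group) index and the output leading dimension before writing a BLOCK_M x BLOCK_N tile. A minimal Python sketch of that addressing arithmetic, with hypothetical values standing in for the kernel's problem_input fields:

# Offset computed by the removed StridedBatched branch (element offset into D).
# BLOCK_M, ld_d and stride_d are hypothetical stand-ins for the kernel's parameters.
BLOCK_M = 128
ld_d = 4096                  # leading dimension of the output D
stride_d = 1024 * ld_d       # elements between consecutive batches of D

def gmem_d_block_offset(curr_group_idx: int, m_block_idx: int) -> int:
    # gmem_d_this_block = gmem_d + curr_group_idx * stride_d + (m_block_idx * BLOCK_M) * ld_d
    return curr_group_idx * stride_d + (m_block_idx * BLOCK_M) * ld_d

print(gmem_d_block_offset(curr_group_idx=2, m_block_idx=3))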


@@ -391,21 +391,9 @@ struct SchedulerSelector
{
static constexpr auto select_type()
{
if constexpr (GT == GemmType::Normal)
return NormalScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumNBlocks,
kNumNBlocksPerGroup>();
if constexpr (GT == GemmType::GroupedContiguous)
return GroupedContiguousScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumNBlocks,
kNumNBlocksPerGroup>();
if constexpr (GT == GemmType::GroupedMasked)
return GroupedMaskedScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumTMAMulticast,
kNumNBlocks, kNumNBlocksPerGroup>();
if constexpr (GT == GemmType::GroupedWithOffset)
return GroupedWithOffsetScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumNBlocks,
kNumNBlocksPerGroup>();
if constexpr (GT == GemmType::StridedBatched)
return StridedBatchedScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumTMAMulticast,
kNumNBlocks, kNumNBlocksPerGroup>();
}
using type = decltype(select_type());


@@ -104,8 +104,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
#block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
block_ns = tuple(range(16, 129, 8))
# Avoid bank conflicts for FP32 output
if is_fp32_out:
block_ns = [x for x in block_ns if x % 16 == 8]
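For reference, the narrowed search space keeps only the multiples of 8 in [16, 128], and the FP32 bank-conflict filter then keeps only values congruent to 8 mod 16. A standalone Python sketch of the resulting candidates (no DeepGEMM imports needed):

# BLOCK_N candidates after this change; 136/144/152/160 are no longer tried.
block_ns = tuple(range(16, 129, 8))                    # (16, 24, ..., 128)
fp32_block_ns = [x for x in block_ns if x % 16 == 8]   # bank-conflict-safe subset for FP32 output

print(block_ns)
print(fp32_block_ns)                                   # [24, 40, 56, 72, 88, 104, 120]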


@@ -260,7 +260,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
else:
m_per_expert_threshold = 32 # H100
if expected_m>= m_per_expert_threshold:
if expected_m> m_per_expert_threshold:
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
expected_m, n, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=False)
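The comparison change only affects the boundary case: with the H100 threshold of 32, an expected_m of exactly 32 previously selected the grouped-contiguous configs and now falls through to the other path. A tiny standalone Python check of that behaviour (threshold value taken from the branch above):

m_per_expert_threshold = 32   # H100 value from the code above

for expected_m in (31, 32, 33):
    old_path = expected_m >= m_per_expert_threshold   # before: 32 selected the contiguous configs
    new_path = expected_m > m_per_expert_threshold    # after: 32 takes the other path
    print(expected_m, old_path, new_path)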


@@ -119,10 +119,9 @@ def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int,
def test_m_grouped_gemm_offset() -> None:
print('Testing grouped contiguous GEMM:')
for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),):
for num_groups, expected_m_per_group, k, n in ((9, 32, 7168, 4096),):
# NOTES: we should mask the unfilled part before calculating difference
'''
x_fp8_mask, y_fp8_mask, masked_m_mask, out_mask, ref_out_mask = construct_masked_grouped(num_groups, expected_m_per_group, expected_m_per_group, k, n)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group)
@@ -160,6 +159,8 @@ def test_m_grouped_gemm_offset() -> None:
print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_offset: Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
'''
print()
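For reference, the perf print above reports 2 * valid_m * n * k / t as compute throughput and counts one byte per FP8 LHS/RHS element plus two bytes per BF16 output element for bandwidth. A standalone Python sketch of that arithmetic with hypothetical valid_m and timing (the test measures both at runtime):

# Throughput/bandwidth arithmetic matching the perf print; valid_m and t are hypothetical.
num_groups, k, n = 9, 7168, 4096
valid_m = num_groups * 32          # assuming about expected_m_per_group valid rows per group
t = 50e-6                          # hypothetical kernel time in seconds

tflops = 2 * valid_m * n * k / t / 1e12
gbps = (valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t
print(f'{tflops:4.0f} TFLOPS, {gbps:4.0f} GB/s')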