Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00
Commit 26a603f518: fix some bug
Parent: d29b20cd16
@@ -851,18 +851,6 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
                                                               scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N);
             }
         }
-        else if constexpr (SchedulerType::gemm_type == GemmType::StridedBatched)
-        {
-            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
-            __nv_bfloat16* gmem_d_this_block;
-            auto m_global_idx = scheduler.get_global_m_idx(m_block_idx);
-            gmem_d_this_block = gmem_d + scheduler.curr_group_idx * problem_input.stride_d
-                + (m_block_idx * BLOCK_M) * problem_input.ld_d;
-            constexpr int NUM_WARPS
-                = (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
-            write_result_to_gmem<BLOCK_M, BLOCK_N, NUM_WARPS>(gmem_d_this_block, smem_d, m_global_idx,
-                                                              scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, problem_input.ld_d);
-        }
         else
         {
             cute::tma_store_fence();
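For reference, the removed StridedBatched epilogue computed a per-block output pointer from the current group index and m-block index. A minimal Python sketch of that address arithmetic; BLOCK_M, ld_d and stride_d are assumed example values, while the formula itself is the one visible in the deleted lines:

```python
# Illustrative sketch only: element offset into D for one (group, m-block), mirroring
#   gmem_d + scheduler.curr_group_idx * problem_input.stride_d + (m_block_idx * BLOCK_M) * problem_input.ld_d
BLOCK_M = 64               # assumed tile height
ld_d = 4096                # assumed leading dimension of D (elements per row)
stride_d = 1024 * ld_d     # assumed distance between batched D matrices, in elements

def d_block_offset(curr_group_idx: int, m_block_idx: int) -> int:
    return curr_group_idx * stride_d + (m_block_idx * BLOCK_M) * ld_d

print(d_block_offset(curr_group_idx=2, m_block_idx=3))  # 9175040 elements past gmem_d
```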
@@ -391,21 +391,9 @@ struct SchedulerSelector
 {
     static constexpr auto select_type()
     {
-        if constexpr (GT == GemmType::Normal)
-            return NormalScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumNBlocks,
-                                   kNumNBlocksPerGroup>();
-        if constexpr (GT == GemmType::GroupedContiguous)
-            return GroupedContiguousScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumNBlocks,
-                                              kNumNBlocksPerGroup>();
-        if constexpr (GT == GemmType::GroupedMasked)
-            return GroupedMaskedScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumTMAMulticast,
-                                          kNumNBlocks, kNumNBlocksPerGroup>();
         if constexpr (GT == GemmType::GroupedWithOffset)
             return GroupedWithOffsetScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumNBlocks,
                                               kNumNBlocksPerGroup>();
-        if constexpr (GT == GemmType::StridedBatched)
-            return StridedBatchedScheduler<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K, kNumGroups, kNumTMAMulticast,
-                                           kNumNBlocks, kNumNBlocksPerGroup>();
     }
 
     using type = decltype(select_type());
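The net effect of this hunk is that SchedulerSelector::select_type() now only resolves the GroupedWithOffset case. A rough Python stand-in for the trimmed compile-time dispatch; the GemmType and scheduler names mirror the diff, but the dict-based lookup is purely illustrative, not repository code:

```python
# Conceptual sketch: which GemmType still gets a scheduler after this change.
SCHEDULERS = {
    'GroupedWithOffset': 'GroupedWithOffsetScheduler',   # only surviving branch
}

def select_type(gemm_type: str) -> str:
    # Normal, GroupedContiguous, GroupedMasked and StridedBatched no longer match,
    # mirroring the removed `if constexpr` branches above.
    return SCHEDULERS.get(gemm_type, '<no scheduler selected>')

print(select_type('GroupedWithOffset'))   # GroupedWithOffsetScheduler
print(select_type('StridedBatched'))      # <no scheduler selected>
```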
@@ -104,8 +104,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
         block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
     else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
-    block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
+    #block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
+    block_ns = tuple(range(16, 129, 8))
     # Avoid bank conflicts for FP32 output
     if is_fp32_out:
         block_ns = [x for x in block_ns if x % 16 == 8]
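After this change the block_n candidates are just the plain 16..128 range; the wgrad-specific extras (136/152 and 144/160) are commented out. A small sketch of what then survives the FP32 bank-conflict filter shown in the context lines; is_fp32_out is set by hand here purely for illustration:

```python
# Candidate block_n values after the change: 16, 24, ..., 128.
block_ns = tuple(range(16, 129, 8))

# FP32 output path keeps only values congruent to 8 mod 16 (see the context lines above).
is_fp32_out = True   # assumed for this sketch
if is_fp32_out:
    block_ns = [x for x in block_ns if x % 16 == 8]

print(block_ns)      # [24, 40, 56, 72, 88, 104, 120]
```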
@@ -260,7 +260,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
     else:
         m_per_expert_threshold = 32 # H100
 
-    if expected_m>= m_per_expert_threshold:
+    if expected_m> m_per_expert_threshold:
 
         num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
             expected_m, n, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=False)
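The fix tightens the dispatch from >= to >, so the boundary case expected_m == 32 no longer takes the grouped-contiguous configuration path on H100. A minimal sketch of just that comparison; the threshold value comes from the diff, the helper function is illustrative only:

```python
m_per_expert_threshold = 32  # H100, as in the diff

def uses_contiguous_configs(expected_m: int) -> bool:
    # New behaviour: strictly greater than the threshold.
    return expected_m > m_per_expert_threshold

print(uses_contiguous_configs(32))  # False (the old `>=` comparison returned True here)
print(uses_contiguous_configs(33))  # True
```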
@@ -119,10 +119,9 @@ def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int,
 def test_m_grouped_gemm_offset() -> None:
     print('Testing grouped contiguous GEMM:')
 
-    for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),):
+    for num_groups, expected_m_per_group, k, n in ((9, 32, 7168, 4096),):
         # NOTES: we should mask the unfilled part before calculating difference
 
-        '''
         x_fp8_mask, y_fp8_mask, masked_m_mask, out_mask, ref_out_mask = construct_masked_grouped(num_groups, expected_m_per_group, expected_m_per_group, k, n)
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group)
 
@@ -160,6 +159,8 @@ def test_m_grouped_gemm_offset() -> None:
         print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_offset: Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
               f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
               f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
 
+    '''
     print()
 
+
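For context, the printed throughput numbers come from the two formulas in the f-strings above. A worked example with the test shape from this commit (num_groups=9, expected_m_per_group=32, k=7168, n=4096); valid_m is assumed to be num_groups * expected_m_per_group and t is a made-up runtime, since both are produced at runtime by the real test:

```python
num_groups, expected_m_per_group, k, n = 9, 32, 7168, 4096
valid_m = num_groups * expected_m_per_group   # assumption: every group fully filled
t = 100e-6                                    # hypothetical runtime in seconds

flops = 2 * valid_m * n * k                                         # one multiply-add per output element per k
bytes_moved = valid_m * k + num_groups * k * n + valid_m * n * 2    # FP8 LHS + FP8 RHS + BF16 output

print(f'{flops / t / 1e12:4.0f} TFLOPS')     # ~169 TFLOPS at the made-up t
print(f'{bytes_moved / 1e9 / t:4.0f} GB/s')  # ~2687 GB/s at the made-up t
```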