mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-06 04:04:22 +00:00
Optimize performance
This commit is contained in:
parent
59884211ea
commit
07ef809d82
@ -249,6 +249,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Preload TMA multicast validity, encouraged to use unified registers
|
||||
bool is_tma_multicast_valid = __shfl_sync(0xffffffff, scheduler.is_tma_multicast_valid(m_block_idx), 0);
|
||||
|
||||
// Decide the number of scales B to load
|
||||
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
|
||||
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
|
||||
@ -276,7 +279,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](int s) {
|
||||
if (kNumTMAMulticast == 1 or not scheduler.is_tma_multicast_valid(m_block_idx)) {
|
||||
if (kNumTMAMulticast == 1 or not is_tma_multicast_valid) {
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
|
||||
|
Loading…
Reference in New Issue
Block a user