Minor fixes

This commit is contained in:
Chenggang Zhao
2025-04-23 13:32:19 +08:00
parent 55d1d01c43
commit f4b205bfa3

View File

@@ -226,8 +226,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->arrive();
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
}, 0);
}
@@ -249,9 +249,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Preload TMA multicast validity, encouraged to use unified registers
bool is_tma_multicast_valid = __shfl_sync(0xffffffff, scheduler.is_tma_multicast_valid(m_block_idx), 0);
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
@@ -279,8 +276,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
if (kNumTMAMulticast == 1 or not is_tma_multicast_valid) {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive() : void();
if (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
}