kavioyu 2025-03-13 07:04:56 +00:00
parent 6e53c6613d
commit 094d0421ec
3 changed files with 4 additions and 7 deletions

View File

@@ -179,8 +179,6 @@ fp8_gemm_bw_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
          smem_scales_b[s], n_block_idx * BLOCK_N, scheduler.get_global_idx(0, 1, k_idx / BLOCK_K));
 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
-// full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
 }
 // Wait unaligned cases
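The hunk above deletes a stale commented-out call: once the producer also issues a TMA copy of the per-block scales_b tile on the same barrier, the expected transaction byte count must grow by SMEM_SCALES_B_SIZE_PER_STAGE, because an mbarrier's expect-tx has to equal the total bytes of every async copy it tracks. Below is a minimal standalone sketch of that producer pattern using CUDA 12's <cuda/barrier> rather than DeepGEMM's own tma_copy/barrier helpers; the buffer names and sizes are illustrative only. Compiles with nvcc -std=c++17 -arch=sm_90.

#include <cuda/barrier>
#include <cuda/std/utility>  // cuda::std::move
#include <cstdio>

using block_barrier = cuda::barrier<cuda::thread_scope_block>;

__global__ void expect_tx_demo(const float* gmem_a, const float* gmem_scales_b) {
    // Stand-ins for the kernel's per-stage smem buffers (sizes are made up).
    __shared__ alignas(16) float smem_a[256];
    __shared__ alignas(16) float smem_scales_b[32];
    #pragma nv_diag_suppress static_var_with_dynamic_init
    __shared__ block_barrier full_barrier;

    if (threadIdx.x == 0)
        init(&full_barrier, blockDim.x);
    __syncthreads();

    block_barrier::arrival_token token;
    if (threadIdx.x == 0) {
        // Two async bulk copies tracked by the same barrier ...
        cuda::device::memcpy_async_tx(smem_a, gmem_a,
                                      cuda::aligned_size_t<16>(sizeof(smem_a)), full_barrier);
        cuda::device::memcpy_async_tx(smem_scales_b, gmem_scales_b,
                                      cuda::aligned_size_t<16>(sizeof(smem_scales_b)), full_barrier);
        // ... so the producer must expect the SUM of their bytes; undercounting
        // (e.g. forgetting the scales_b bytes) releases consumers before the data lands.
        token = cuda::device::barrier_arrive_tx(full_barrier, 1,
                                                sizeof(smem_a) + sizeof(smem_scales_b));
    } else {
        token = full_barrier.arrive();
    }
    full_barrier.wait(cuda::std::move(token));

    if (threadIdx.x == 0)
        printf("loaded: %f %f\n", smem_a[0], smem_scales_b[0]);
}

int main() {
    float *a, *s;
    cudaMalloc(&a, 256 * sizeof(float));
    cudaMalloc(&s, 32 * sizeof(float));
    cudaMemset(a, 0, 256 * sizeof(float));
    cudaMemset(s, 0, 32 * sizeof(float));
    expect_tx_demo<<<1, 128>>>(a, s);
    cudaDeviceSynchronize();
    return 0;
}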
@@ -268,7 +266,7 @@ fp8_gemm_bw_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 empty_barrier_arrive(s);
 #pragma unroll
-for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // WGMMA::kNumAccum = 64, loop 16 steps
+for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
     int src_lane_id = threadIdx.x % 4 + (i % 8) * 4;
     float scale_b_0 = __shfl_sync(0xffffffff, scale_b[i/8].x, src_lane_id);
     float scale_b_1 = __shfl_sync(0xffffffff, scale_b[i/8].y, src_lane_id);
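This hunk only drops a hard-coded note (kNumAccum = 64 / 16 steps holds for one tile shape, not in general), but the surrounding shuffle idiom is worth spelling out: the per-block scales_b values sit in registers, spread across lanes as float2s, and each thread fetches the scale for accumulator group i by shuffling from the owning lane instead of re-reading shared memory. A self-contained sketch of that lane arithmetic follows, with made-up scale values and the real kernel's scale_b[i/8] array collapsed to a single float2:

#include <cstdio>

__global__ void shfl_scale_broadcast_demo() {
    int lane = threadIdx.x % 32;
    // Pretend each lane owns one float2 of per-block scales (made-up values).
    float2 scale_b = make_float2(2.0f * lane, 2.0f * lane + 1.0f);

    for (int i = 0; i < 16; ++i) {  // e.g. kNumAccum / 4 == 16 when kNumAccum == 64
        // Lane l reads from lane (l % 4) + (i % 8) * 4: every 4-lane column quad
        // pulls its scales straight from the registers of the lane that owns them.
        int src_lane_id = threadIdx.x % 4 + (i % 8) * 4;
        float scale_b_0 = __shfl_sync(0xffffffff, scale_b.x, src_lane_id);
        float scale_b_1 = __shfl_sync(0xffffffff, scale_b.y, src_lane_id);
        if (lane < 4)
            printf("lane %d, i=%2d -> src_lane %2d, scales (%g, %g)\n",
                   lane, i, src_lane_id, scale_b_0, scale_b_1);
    }
}

int main() {
    shfl_scale_broadcast_demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}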

View File

@@ -75,7 +75,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
         block_ms = (64 if m <= 64 else 128, )
     else:
         block_ms = (get_m_alignment_for_contiguous_layout(), )
-    block_ns = tuple(range(32, 129, 32))
+    block_ns = (128, )
     fix_wave_saturate = lambda x: num_sms if x == 0 else x
     get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
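Pinning block_ns to (128, ) removes the N-tile sweep: the search previously scored every BLOCK_N in (32, 64, 96, 128) and now only evaluates 128. The score it minimizes is the wave count computed by the get_num_waves lambda above; the sketch below restates that arithmetic as host-side C++ (a hypothetical translation, with assumed example sizes m = n = 4096, num_groups = 1, num_sms = 132) to show what the narrowed search no longer compares:

#include <cstdio>
#include <initializer_list>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Mirrors the Python lambda: waves = ceil_div(m_blocks * n_blocks * num_groups, num_sms).
int get_num_waves(int m, int n, int num_groups, int num_sms, int bm, int bn) {
    return ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms);
}

int main() {
    const int m = 4096, n = 4096, num_groups = 1, num_sms = 132, bm = 128;
    // The old sweep ranked all of these candidates; the hunk above keeps only bn = 128.
    for (int bn : {32, 64, 96, 128})
        printf("bm=%d bn=%3d -> %d waves\n", bm, bn,
               get_num_waves(m, n, num_groups, num_sms, bm, bn));
    return 0;
}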

View File

@@ -109,7 +109,6 @@ def test_gemm_backward_w() -> None:
     diff = calc_diff(out, ref_out)
     assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
     torch.cuda.synchronize()
-    print(diff)
     # noinspection PyShadowingNames
     def test_func():
@@ -195,5 +194,5 @@ if __name__ == '__main__':
     test_gemm_backward_w()
     test_gemm()
-    # test_m_grouped_gemm_contiguous()
-    # test_m_grouped_gemm_masked()
+    test_m_grouped_gemm_contiguous()
+    test_m_grouped_gemm_masked()