mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-10 18:55:28 +00:00
refine
This commit is contained in:
parent
6e53c6613d
commit
094d0421ec
@ -179,8 +179,6 @@ fp8_gemm_bw_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_b[s], n_block_idx * BLOCK_N, scheduler.get_global_idx(0, 1, k_idx / BLOCK_K));
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
|
||||
// full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
@ -268,7 +266,7 @@ fp8_gemm_bw_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // WGMMA::kNumAccum = 64, loop 16 steps
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
int src_lane_id = threadIdx.x % 4 + (i % 8) * 4;
|
||||
float scale_b_0 = __shfl_sync(0xffffffff, scale_b[i/8].x, src_lane_id);
|
||||
float scale_b_1 = __shfl_sync(0xffffffff, scale_b[i/8].y, src_lane_id);
|
||||
|
@ -75,7 +75,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
block_ms = (64 if m <= 64 else 128, )
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(32, 129, 32))
|
||||
block_ns = (128, )
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
|
@ -109,7 +109,6 @@ def test_gemm_backward_w() -> None:
|
||||
diff = calc_diff(out, ref_out)
|
||||
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
||||
torch.cuda.synchronize()
|
||||
print(diff)
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
@ -195,5 +194,5 @@ if __name__ == '__main__':
|
||||
|
||||
test_gemm_backward_w()
|
||||
test_gemm()
|
||||
# test_m_grouped_gemm_contiguous()
|
||||
# test_m_grouped_gemm_masked()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
|
Loading…
Reference in New Issue
Block a user