mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Performance: BlockTile 256x128 optimizations enable 1500+ TFLOPS FP8 performance on the H800-SXM platform
This commit is contained in:
parent
b4ecf9c3ff
commit
97575bf1c6
@ -257,7 +257,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum*2] = {0};
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](int s) {
|
||||
@ -306,9 +306,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
// Promote with scales
|
||||
// NOTES: making it as predicates is very important for performance, comparing to two loops
|
||||
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
||||
@ -325,6 +322,48 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
if constexpr (BLOCK_M == 256) {
|
||||
// Read A scales
|
||||
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
||||
auto scale_a_2 = ld_shared(smem_scales_a[s] + r_0 + 2 * WGMMA::M), scale_a_3 = ld_shared(smem_scales_a[s] + r_1 + 2 * WGMMA::M);
|
||||
// Promote with scales
|
||||
// NOTES: making it as predicates is very important for performance, comparing to two loops
|
||||
float scale_2_0 = scale_a_2 * scale_b_0, scale_3_0 = scale_a_3 * scale_b_0;
|
||||
float scale_2_1, scale_3_1;
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_2_1 = scale_a_2 * scale_b_1, scale_3_1 = scale_a_3 * scale_b_1;
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K + 2 * WGMMA::M * BLOCK_K , 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
// #pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
|
||||
bool predicate = kMustUseUniformedScaleB or (i + WGMMA::kNumAccum / 4) < num_former_iters;
|
||||
final_accum[i * 4 + 0 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 0];
|
||||
final_accum[i * 4 + 1 + WGMMA::kNumAccum] += (predicate ? scale_2_0 : scale_2_1) * accum[i * 4 + 1];
|
||||
final_accum[i * 4 + 2 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3 + WGMMA::kNumAccum] += (predicate ? scale_3_0 : scale_3_1) * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
@ -347,12 +386,34 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
|
||||
);
|
||||
}
|
||||
if constexpr (WGMMA::kNumAccum % 8 != 0) {
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
|
||||
);
|
||||
if constexpr (BLOCK_M == 256) {
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
|
||||
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 0 + WGMMA::kNumAccum], final_accum[i * 8 + 1 + WGMMA::kNumAccum]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 2 + WGMMA::kNumAccum], final_accum[i * 8 + 3 + WGMMA::kNumAccum]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 4 + WGMMA::kNumAccum], final_accum[i * 8 + 5 + WGMMA::kNumAccum]}),
|
||||
__float22bfloat162_rn({final_accum[i * 8 + 6 + WGMMA::kNumAccum], final_accum[i * 8 + 7 + WGMMA::kNumAccum]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16) + BLOCK_M / 2 * (BLOCK_N + BLOCK_N_PADDING)
|
||||
);
|
||||
}
|
||||
}
|
||||
if constexpr (BLOCK_M == 256) {
|
||||
if constexpr (WGMMA::kNumAccum * 2 % 8 != 0) {
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 0], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 2], final_accum[WGMMA::kNumAccum * 2 / 8 * 8 + 3]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum * 2 / 8 * 16
|
||||
);
|
||||
}
|
||||
} else {
|
||||
if constexpr (WGMMA::kNumAccum % 8 != 0) {
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
|
||||
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
|
||||
);
|
||||
}
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
@ -866,6 +866,94 @@ struct SM90_64x192x32_F32E4M3E4M3_SS {
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
struct SM90_64x256x32_F32E4M3E4M3_SS {
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
||||
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
||||
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
||||
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
||||
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
||||
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
||||
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
||||
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
||||
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
||||
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
|
||||
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
|
||||
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
|
||||
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
|
||||
float& d96, float& d97, float& d98, float& d99, float& d100,float& d101,float& d102,float& d103,
|
||||
float& d104,float& d105,float& d106,float& d107,float& d108,float& d109,float& d110,float& d111,
|
||||
float& d112,float& d113,float& d114,float& d115,float& d116,float& d117,float& d118,float& d119,
|
||||
float& d120,float& d121,float& d122,float& d123,float& d124,float& d125,float& d126,float& d127,
|
||||
bool scale_d) {
|
||||
asm volatile("{\n"
|
||||
".reg .pred p;\n"
|
||||
"setp.ne.b32 p, %130, 0;\n"
|
||||
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
||||
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
||||
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
||||
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
||||
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
||||
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
||||
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
||||
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
||||
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
||||
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
||||
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
||||
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
||||
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
||||
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
||||
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
||||
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
||||
" %128,"
|
||||
" %129,"
|
||||
" p , 1, 1;\n"
|
||||
"}\n"
|
||||
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
||||
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
||||
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
||||
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
||||
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
||||
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
||||
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
||||
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
||||
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
||||
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
||||
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
||||
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95),
|
||||
"+f"(d96), "+f"(d97), "+f"(d98), "+f"(d99), "+f"(d100),"+f"(d101),"+f"(d102),"+f"(d103),
|
||||
"+f"(d104),"+f"(d105),"+f"(d106),"+f"(d107),"+f"(d108),"+f"(d109),"+f"(d110),"+f"(d111),
|
||||
"+f"(d112),"+f"(d113),"+f"(d114),"+f"(d115),"+f"(d116),"+f"(d117),"+f"(d118),"+f"(d119),
|
||||
"+f"(d120),"+f"(d121),"+f"(d122),"+f"(d123),"+f"(d124),"+f"(d125),"+f"(d126),"+f"(d127)
|
||||
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
||||
}
|
||||
|
||||
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
wgmma(desc_a, desc_b,
|
||||
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
||||
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
||||
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
||||
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
||||
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
||||
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
||||
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
||||
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
||||
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
|
||||
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
|
||||
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
|
||||
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
|
||||
d[96], d[97], d[98], d[99], d[100],d[101],d[102],d[103],
|
||||
d[104],d[105],d[106],d[107],d[108],d[109],d[110],d[111],
|
||||
d[112],d[113],d[114],d[115],d[116],d[117],d[118],d[119],
|
||||
d[120],d[121],d[122],d[123],d[124],d[125],d[126],d[127],
|
||||
scale_d);
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = 256;
|
||||
static constexpr int K = 32;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
@ -1008,6 +1096,7 @@ struct FP8MMASelector {
|
||||
if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
|
||||
if constexpr (N == 256) return SM90_64x256x32_F32E4M3E4M3_SS();
|
||||
}
|
||||
|
||||
using type = decltype(select_type());
|
||||
|
@ -74,10 +74,16 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
Tuple[int, int, int, int, Tuple[int, bool], int]:
|
||||
if not is_grouped_contiguous:
|
||||
# TODO: for some cases, smaller M block is better, add them into tuning space
|
||||
block_ms = (64 if m <= 64 else 128, )
|
||||
# block_ms = (64 if m <= 64 else 128, )
|
||||
if m <= 64:
|
||||
block_ms = (64, )
|
||||
elif m <= 128:
|
||||
block_ms = (64, 128, )
|
||||
else:
|
||||
block_ms = (64, 128, 256, )
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
|
Loading…
Reference in New Issue
Block a user