mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-22 16:34:22 +00:00
Simplify code
This commit is contained in:
parent
046fab64b7
commit
612dd57001
@ -308,20 +308,13 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
if constexpr (not kMustUseUniformedScaleB)
|
if constexpr (not kMustUseUniformedScaleB)
|
||||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||||
|
|
||||||
constexpr int kNumFormerIters = kMustUseUniformedScaleB ? WGMMA::kNumAccum / 4 : decltype(num_former_iters_type)::value;
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < kNumFormerIters; ++ i) {
|
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||||
final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
|
bool predicate = kMustUseUniformedScaleB or i < decltype(num_former_iters_type)::value;
|
||||||
final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
|
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||||
final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
|
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||||
final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
|
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||||
}
|
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||||
#pragma unroll
|
|
||||||
for (int i = kNumFormerIters; i < WGMMA::kNumAccum / 4; ++ i) {
|
|
||||||
final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
|
|
||||||
final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
|
|
||||||
final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
|
|
||||||
final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user