mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-08 20:34:01 +00:00
Simplify code
This commit is contained in:
parent
046fab64b7
commit
612dd57001
@ -308,20 +308,13 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||
|
||||
constexpr int kNumFormerIters = kMustUseUniformedScaleB ? WGMMA::kNumAccum / 4 : decltype(num_former_iters_type)::value;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumFormerIters; ++ i) {
|
||||
final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
|
||||
final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
|
||||
final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = kNumFormerIters; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
|
||||
final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
|
||||
final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
bool predicate = kMustUseUniformedScaleB or i < decltype(num_former_iters_type)::value;
|
||||
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user