use fa'3 transv

This commit is contained in:
chenhongmin.will 2025-02-28 14:35:07 +08:00
parent 0337732dc1
commit 061af5fc56
2 changed files with 17 additions and 16 deletions

View File

@ -5,10 +5,11 @@ struct SmemTransposeFp8_64x64 {
static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
using Element = cutlass::float_e4m3_t;
using SmemLayoutV = decltype(composition(
SmemLayoutK{},
Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>>{}));
using TransposeShapeAtomV = Shape<_64, _64>;
using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
// for fp8 in-kernel transpose -- src layout
using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
@ -18,15 +19,15 @@ struct SmemTransposeFp8_64x64 {
// For fp8, this is the memory transpose.
using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
using SmemLayoutVt = decltype(tile_to_shape(
SmemLayoutAtomVt{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
using SmemLayoutVt =
decltype(tile_to_shape(SmemLayoutAtomVt{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// for fp8 in-kernel transpose -- dst layout
using SmemLayoutVtTrans = decltype(composition(
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
@ -40,8 +41,8 @@ struct SmemTransposeFp8_64x64 {
using stsm_thread_shape = Shape<_4, _1, _8, _4>;
// using stsm_thread_stride = Stride<_1, _0, _4, _32>;
using stsm_value_shape = Shape<_4, _4, _1, _2>;
using stsm_value_stride = Stride<_1, _8, _0, _4>;
using stsm_value_shape = Shape<_4, _4, _2, _1>;
using stsm_value_stride = Stride<_1, _8, _4, _0>;
using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
Layout<stsm_value_shape, stsm_value_stride>{}));
@ -51,7 +52,7 @@ struct SmemTransposeFp8_64x64 {
CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
using namespace cute;
auto tid = threadIdx.x;
auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
@ -64,11 +65,11 @@ struct SmemTransposeFp8_64x64 {
auto data = tXrX.data();
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size(tXrX); n += 8) {
uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
auto upper = data_32bit[0];
auto lower = data_32bit[1];
data_32bit[0] = __byte_perm(upper, lower, 0x6420);
data_32bit[1] = __byte_perm(upper, lower, 0x7531);
uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
auto upper = data_32bit[0];
auto lower = data_32bit[1];
data_32bit[0] = __byte_perm(upper, lower, 0x6420);
data_32bit[1] = __byte_perm(upper, lower, 0x7531);
}
cute::copy(tiled_copy_stsm, tXrX, tXsX_out);

View File

@ -33,7 +33,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
#assert cos_diff < 1e-5
assert cos_diff < 1e-5
@torch.inference_mode()