diff --git a/csrc/fp8_transpose_v.h b/csrc/fp8_transpose_v.h
index f1a1919..c566066 100644
--- a/csrc/fp8_transpose_v.h
+++ b/csrc/fp8_transpose_v.h
@@ -5,10 +5,11 @@ struct SmemTransposeFp8_64x64 {
   static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
   using Element = cutlass::float_e4m3_t;
 
-  using SmemLayoutV = decltype(composition(
-      SmemLayoutK{},
-      Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>>{}));
   using TransposeShapeAtomV = Shape<_64, _64>;
+  using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
+  using SmemLayoutV =
+      decltype(tile_to_shape(SmemLayoutAtomV{},
+                             Shape<Int<kBlockN>, Int<kHeadDim>>{}));
   // for fp8 in-kernel transpose -- src layout
   using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
 
@@ -18,15 +19,15 @@
   // For fp8, this is the memory transpose.
   using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
-  using SmemLayoutVt = decltype(tile_to_shape(
-      SmemLayoutAtomVt{},
-      Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+  using SmemLayoutVt =
+      decltype(tile_to_shape(SmemLayoutAtomVt{},
+                             Shape<Int<kHeadDim>, Int<kBlockN>>{}));
 
   // for fp8 in-kernel transpose -- dst layout
   using SmemLayoutVtTrans = decltype(composition(
       SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
   using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
-  using SmemShapeSTSM = Shape<Shape<_8, _8>, Shape<_8, _8>>;
+  using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
   using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{},
       shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
   using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
 
@@ -40,8 +41,8 @@
   using stsm_thread_shape = Shape<_4, _1, _8, _4>;
   // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
-  using stsm_value_shape = Shape<_4, _4, _1, _2>;
-  using stsm_value_stride = Stride<_1, _8, _0, _4>;
+  using stsm_value_shape = Shape<_4, _4, _2, _1>;
+  using stsm_value_stride = Stride<_1, _8, _4, _0>;
 
   using TiledCopySTSM =
       decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
                                Layout<stsm_value_shape, stsm_value_stride>{}));
@@ -51,7 +52,7 @@
   CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
     using namespace cute;
 
-    auto tid = threadIdx.x;
+    auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
     auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
     auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
 
@@ -64,11 +65,11 @@
     auto data = tXrX.data();
     CUTLASS_PRAGMA_UNROLL
     for (int n = 0; n < size(tXrX); n += 8) {
-      uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
-      auto upper = data_32bit[0];
-      auto lower = data_32bit[1];
-      data_32bit[0] = __byte_perm(upper, lower, 0x6420);
-      data_32bit[1] = __byte_perm(upper, lower, 0x7531);
+      uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
+      auto upper = data_32bit[0];
+      auto lower = data_32bit[1];
+      data_32bit[0] = __byte_perm(upper, lower, 0x6420);
+      data_32bit[1] = __byte_perm(upper, lower, 0x7531);
     }
 
     cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index 91bd6f1..6cfd466 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -33,7 +33,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
     cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
     amax_diff = (x - y).abs().max().item()
     # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
-    #assert cos_diff < 1e-5
+    assert cos_diff < 1e-5
 
 
 @torch.inference_mode()
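
Note on the register shuffle in `transpose()`: the loop reinterprets eight fp8 values as a pair of 32-bit registers and de-interleaves them with `__byte_perm`: selector `0x6420` gathers the even-indexed bytes of the pair and `0x7531` the odd-indexed ones (presumably the column permute needed so the STSM store writes V^T in the element order the FP8 GMMA expects). Below is a minimal standalone sketch of just that byte shuffle; it is not part of the patch, and the kernel name and test values are made up for illustration.

```cuda
#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>

// Standalone illustration (not part of the patch): apply the same selectors
// used in SmemTransposeFp8_64x64::transpose() to two registers that pack
// eight consecutive fp8 bytes v0..v7 (little-endian: v0 is the LSB of `upper`).
__global__ void byte_perm_demo() {
    uint32_t upper = 0x03020100u;  // bytes v0=0x00, v1=0x01, v2=0x02, v3=0x03
    uint32_t lower = 0x07060504u;  // bytes v4=0x04, v5=0x05, v6=0x06, v7=0x07
    uint32_t evens = __byte_perm(upper, lower, 0x6420);  // -> v0,v2,v4,v6 = 0x06040200
    uint32_t odds  = __byte_perm(upper, lower, 0x7531);  // -> v1,v3,v5,v7 = 0x07050301
    printf("evens=0x%08x odds=0x%08x\n", evens, odds);
}

int main() {
    byte_perm_demo<<<1, 1>>>();  // one thread is enough for the demo
    cudaDeviceSynchronize();
    return 0;
}
```

In the kernel this shuffle runs once per register pair (the loop steps through `tXrX` eight bytes at a time), and the permuted registers are then written back to shared memory by the `tiled_copy_stsm` copy.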