mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
use FA3's transpose-V
This commit is contained in:
parent
0337732dc1
commit
061af5fc56
@ -5,10 +5,11 @@ struct SmemTransposeFp8_64x64 {
|
|||||||
static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
|
static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
|
||||||
|
|
||||||
using Element = cutlass::float_e4m3_t;
|
using Element = cutlass::float_e4m3_t;
|
||||||
using SmemLayoutV = decltype(composition(
|
|
||||||
SmemLayoutK{},
|
|
||||||
Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>>{}));
|
|
||||||
using TransposeShapeAtomV = Shape<_64, _64>;
|
using TransposeShapeAtomV = Shape<_64, _64>;
|
||||||
|
using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
|
||||||
|
using SmemLayoutV =
|
||||||
|
decltype(tile_to_shape(SmemLayoutAtomV{},
|
||||||
|
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||||
|
|
||||||
// for fp8 in-kernel transpose -- src layout
|
// for fp8 in-kernel transpose -- src layout
|
||||||
using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
|
using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
|
||||||
@ -18,15 +19,15 @@ struct SmemTransposeFp8_64x64 {
|
|||||||
|
|
||||||
// For fp8, this is the memory transpose.
|
// For fp8, this is the memory transpose.
|
||||||
using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
|
using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
|
||||||
using SmemLayoutVt = decltype(tile_to_shape(
|
using SmemLayoutVt =
|
||||||
SmemLayoutAtomVt{},
|
decltype(tile_to_shape(SmemLayoutAtomVt{},
|
||||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||||
|
|
||||||
// for fp8 in-kernel transpose -- dst layout
|
// for fp8 in-kernel transpose -- dst layout
|
||||||
using SmemLayoutVtTrans = decltype(composition(
|
using SmemLayoutVtTrans = decltype(composition(
|
||||||
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
|
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
|
||||||
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
|
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
|
||||||
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
|
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
|
||||||
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
|
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
|
||||||
using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
|
using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
|
||||||
|
|
||||||
@ -40,8 +41,8 @@ struct SmemTransposeFp8_64x64 {
|
|||||||
|
|
||||||
using stsm_thread_shape = Shape<_4, _1, _8, _4>;
|
using stsm_thread_shape = Shape<_4, _1, _8, _4>;
|
||||||
// using stsm_thread_stride = Stride<_1, _0, _4, _32>;
|
// using stsm_thread_stride = Stride<_1, _0, _4, _32>;
|
||||||
using stsm_value_shape = Shape<_4, _4, _1, _2>;
|
using stsm_value_shape = Shape<_4, _4, _2, _1>;
|
||||||
using stsm_value_stride = Stride<_1, _8, _0, _4>;
|
using stsm_value_stride = Stride<_1, _8, _4, _0>;
|
||||||
|
|
||||||
using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
|
using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
|
||||||
Layout<stsm_value_shape, stsm_value_stride>{}));
|
Layout<stsm_value_shape, stsm_value_stride>{}));
|
||||||
@ -51,7 +52,7 @@ struct SmemTransposeFp8_64x64 {
|
|||||||
CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
|
CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
auto tid = threadIdx.x;
|
auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||||
auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
|
auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
|
||||||
auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
|
auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
|
|||||||
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
||||||
amax_diff = (x - y).abs().max().item()
|
amax_diff = (x - y).abs().max().item()
|
||||||
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
|
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
|
||||||
#assert cos_diff < 1e-5
|
assert cos_diff < 1e-5
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
Loading…
Reference in New Issue
Block a user