mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Use TMA instead of LD/ST for intra-node normal kernels (#191)
* Update CMake files * Use TMA instead of LD/ST for intranode dispatch * Use TMA instead of LD/ST for intranode combine * Adjust configs * Test default configs as well * More warps for combine * Add inter-thread fence * Enable more warps * Do not use TMA for senders * Update configs * Remove useless wait
This commit is contained in:
@@ -266,6 +266,67 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
|
||||
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void fence_view_async_shared() {
|
||||
asm volatile("fence.proxy.async.shared::cta; \n" :: );
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void fence_barrier_init() {
|
||||
asm volatile("fence.mbarrier_init.release.cluster; \n" :: );
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arrive_count) {
|
||||
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
|
||||
asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) {
|
||||
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
|
||||
asm volatile("{\n\t"
|
||||
".reg .pred P1; \n\t"
|
||||
"LAB_WAIT: \n\t"
|
||||
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
|
||||
"@P1 bra DONE; \n\t"
|
||||
"bra LAB_WAIT; \n\t"
|
||||
"DONE: \n\t"
|
||||
"}" :: "r"(mbar_int_ptr), "r"(phase), "r"(0x989680));
|
||||
phase ^= 1;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {
|
||||
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
|
||||
asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void tma_store_fence() {
|
||||
asm volatile ("fence.proxy.async.shared::cta;");
|
||||
}
|
||||
|
||||
constexpr uint64_t kEvictFirst = 0x12f0000000000000;
|
||||
constexpr uint64_t kEvictNormal = 0x1000000000000000;
|
||||
|
||||
__device__ __forceinline__ void tma_load_1d(const void* smem_ptr, const void* gmem_ptr, uint64_t* mbar_ptr, int num_bytes,
|
||||
bool evict_first = true) {
|
||||
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
|
||||
auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
|
||||
asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n"
|
||||
:: "r"(smem_int_ptr), "l"(gmem_ptr), "r"(num_bytes), "r"(mbar_int_ptr), "l"(cache_hint) : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* gmem_ptr, int num_bytes,
|
||||
bool evict_first = true) {
|
||||
auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
|
||||
asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n"
|
||||
:: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(num_bytes), "l"(cache_hint) : "memory");
|
||||
asm volatile("cp.async.bulk.commit_group;");
|
||||
}
|
||||
|
||||
template <int N = 0>
|
||||
__device__ __forceinline__ void tma_store_wait() {
|
||||
asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
|
||||
}
|
||||
|
||||
template <typename dtype_t>
|
||||
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
|
||||
return (a + b - 1) / b;
|
||||
|
||||
Reference in New Issue
Block a user