Use TMA instead of LD/ST for intra-node normal kernels (#191)

* Update CMake files

* Use TMA instead of LD/ST for intranode dispatch

* Use TMA instead of LD/ST for intranode combine

* Adjust configs

* Test default configs as well

* More warps for combine

* Add inter-thread fence

* Enable more warps

* Do not use TMA for senders

* Update configs

* Remove useless wait
This commit is contained in:
Chenggang Zhao
2025-06-06 15:40:17 +08:00
committed by GitHub
parent df4debe30c
commit c8dceba110
6 changed files with 230 additions and 87 deletions

View File

@@ -266,6 +266,67 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
}
__device__ __forceinline__ void fence_view_async_shared() {
asm volatile("fence.proxy.async.shared::cta; \n" :: );
}
__device__ __forceinline__ void fence_barrier_init() {
asm volatile("fence.mbarrier_init.release.cluster; \n" :: );
}
__device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arrive_count) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr));
}
__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra DONE; \n\t"
"bra LAB_WAIT; \n\t"
"DONE: \n\t"
"}" :: "r"(mbar_int_ptr), "r"(phase), "r"(0x989680));
phase ^= 1;
}
__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr));
}
__device__ __forceinline__ void tma_store_fence() {
asm volatile ("fence.proxy.async.shared::cta;");
}
constexpr uint64_t kEvictFirst = 0x12f0000000000000;
constexpr uint64_t kEvictNormal = 0x1000000000000000;
__device__ __forceinline__ void tma_load_1d(const void* smem_ptr, const void* gmem_ptr, uint64_t* mbar_ptr, int num_bytes,
bool evict_first = true) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n"
:: "r"(smem_int_ptr), "l"(gmem_ptr), "r"(num_bytes), "r"(mbar_int_ptr), "l"(cache_hint) : "memory");
}
__device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* gmem_ptr, int num_bytes,
bool evict_first = true) {
auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n"
:: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(num_bytes), "l"(cache_hint) : "memory");
asm volatile("cp.async.bulk.commit_group;");
}
template <int N = 0>
__device__ __forceinline__ void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
}
template <typename dtype_t>
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;