Fully remove forwarders' and NVL receivers' code

This commit is contained in:
Chenggang Zhao
2025-06-19 13:48:07 +08:00
parent 3a3398f686
commit a0a6e22eff
3 changed files with 78 additions and 309 deletions

View File

@@ -2,6 +2,7 @@
#include "configs.cuh"
#include "exception.cuh"
#include "utils.cuh"
namespace deep_ep {
@@ -45,25 +46,26 @@ public:
int total_bytes;
__device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
int channel_id = 0, int num_channels = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
total_bytes = per_channel_bytes * num_channels;
ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * channel_id + num_bytes * offset;
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
int channel_id = 0, int num_channels = 1, int offset = 0) {
// TODO: use UR as much as possible
EP_STATIC_ASSERT(kNumRanks > 1, "");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
total_bytes = per_channel_bytes * num_channels;
for (int i = 0; i < kNumRanks; ++ i) {
ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * channel_id + num_bytes * offset;
gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
}
}
@@ -86,15 +88,22 @@ public:
return *this;
}
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
__device__ __forceinline__ dtype_t* buffer(const int& idx = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
__device__ __forceinline__ dtype_t* buffer_by(int rank_idx, const int& idx = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer_by_sync(int rank_idx, const int& idx = 0) {
// Different lanes store different pointers
// NOTES: this function requires the whole warp
EP_STATIC_ASSERT(kNumRanks == 1, "Invalid number of ranks");
return broadcast(reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx), rank_idx);
}
};
template <typename dtype_t, bool kDecoupled = true>