#pragma once

#include "configs.cuh"
#include "exception.cuh"

namespace deep_ep {

// A typed view over a slice of a pre-allocated region. Constructing one
// advances the caller's pointer, so consecutive constructions carve
// consecutive slices out of the same arena.
template <typename dtype_t>
struct Buffer {
private:
    uint8_t* ptr;

public:
    int total_bytes;

    __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}

    __device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
        total_bytes = num_elems * sizeof(dtype_t);
        ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
    }

    __device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
        return *this;
    }

    __device__ __forceinline__ dtype_t* buffer() {
        return reinterpret_cast<dtype_t*>(ptr);
    }

    __device__ __forceinline__ dtype_t& operator[](int idx) {
        return buffer()[idx];
    }
};

// A per-rank, per-channel (SM) view: one pointer per rank, where each SM's
// slice spans `num_bytes * num_ranks` bytes.
template <typename dtype_t, int kNumRanks = 1>
struct AsymBuffer {
private:
    uint8_t* ptrs[kNumRanks];
    int num_bytes;

public:
    int total_bytes;

    __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
                                          int sm_id = 0, int num_sms = 1, int offset = 0) {
        EP_STATIC_ASSERT(kNumRanks == 1, "");
        num_bytes = num_elems * sizeof(dtype_t);

        int per_channel_bytes = num_bytes * num_ranks;
        total_bytes = per_channel_bytes * num_sms;
        ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
    }

    __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
                                          int sm_id = 0, int num_sms = 1, int offset = 0) {
        EP_STATIC_ASSERT(kNumRanks > 1, "");
        num_bytes = num_elems * sizeof(dtype_t);

        int per_channel_bytes = num_bytes * num_ranks;
        total_bytes = per_channel_bytes * num_sms;
        for (int i = 0; i < kNumRanks; ++ i) {
            ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
            gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
        }
    }

    __device__ __forceinline__ void advance(int shift) {
        #pragma unroll
        for (int i = 0; i < kNumRanks; ++ i)
            ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
    }

    __device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
        return *this;
    }

    template <int kNumAlsoRanks>
    __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
        for (int i = 0; i < kNumAlsoRanks; ++ i)
            gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
        return *this;
    }

    __device__ __forceinline__ dtype_t* buffer(int idx = 0) {
        EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for the single-rank case");
        return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
    }

    __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
        EP_STATIC_ASSERT(kNumRanks > 1, "`buffer_by` is only available for the multi-rank case");
        return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
    }
};
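// Usage sketch (illustrative, not part of the original header): a single
// device-side arena pointer is threaded through consecutive `Buffer`
// constructors, so each view claims the next slice. The function name and
// element counts below are hypothetical.
//
//   __device__ void carve_views(void* arena) {
//       auto counts = Buffer<int>(arena, 8);      // first 8 * sizeof(int) bytes
//       auto values = Buffer<float>(arena, 256);  // next 256 * sizeof(float) bytes
//       counts[0] = 0;                            // via operator[]
//       values[0] = 1.0f;
//   }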
// A symmetric send/recv view. In the decoupled case the region holds separate
// send and receive halves; otherwise only the send half exists.
template <typename dtype_t, bool kDecoupled = true>
struct SymBuffer {
private:
    // NOTES: for the non-decoupled case, `recv_ptr` is not used
    uint8_t* send_ptr;
    uint8_t* recv_ptr;
    int num_bytes;

public:
    int total_bytes;

    __device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
                                         int sm_id = 0, int num_sms = 1) {
        num_bytes = num_elems * sizeof(dtype_t);

        int per_channel_bytes = num_bytes * num_ranks;
        total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
        send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
        recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
    }

    __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
        EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for the decoupled case");
        return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
    }

    __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {
        EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for the decoupled case");
        return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);
    }

    __device__ __forceinline__ dtype_t* buffer(int idx = 0) {
        EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for the non-decoupled case");
        return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
    }
};

} // namespace deep_ep
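// Layout sketch for the decoupled case (kDecoupled == true), illustrative:
// the region holds `num_sms` send slices followed by `num_sms` receive
// slices, each slice spanning `num_elems * sizeof(dtype_t) * num_ranks` bytes:
//
//   [ send: sm 0 | send: sm 1 | ... | recv: sm 0 | recv: sm 1 | ... ]
//
// so `send_ptr` indexes into the first half and `recv_ptr` into the second.
// With kDecoupled == false only the send half is allocated, and `buffer()`
// returns it directly.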