Mirror of https://github.com/deepseek-ai/DeepEP (synced 2025-06-26 18:28:11 +00:00)
Support CUDA graph for intranode normal kernels (#203)

Commit: a8299ca7c2
Parent: 8da2d7b38d
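This commit adds a `num_worst_tokens` argument to the intranode dispatch path: when it is non-zero, the host no longer busy-waits on the GPU counters, the receive tensors are allocated for the worst case, and the kernel pads the unused tail of `recv_topk_idx` with `-1`, which is what makes the normal intranode kernels capturable in a CUDA graph. The sketch below is not part of the commit; it assumes an already-initialized process group, a constructed `deep_ep.Buffer` named `buffer`, and the usual dispatch inputs (`x`, `topk_idx`, `topk_weights`, `num_tokens_per_rank`, `is_token_in_rank`, `num_tokens_per_expert`), and only illustrates how the new flag could be combined with `torch.cuda.graph` capture.

import torch
import torch.distributed as dist

# Worst-case bound used by the test in this commit: every token from every rank
# could land on this rank, so `num_tokens * num_ranks` rows always suffice.
num_ranks = dist.get_world_size()
num_worst_tokens = x.size(0) * num_ranks

def run_dispatch():
    # With `num_worst_tokens > 0` there is no CPU<->GPU sync inside dispatch,
    # so the call can be replayed from a CUDA graph (intranode only).
    recv_x, recv_topk_idx, recv_topk_weights, _, _, _ = buffer.dispatch(
        x, topk_idx=topk_idx, topk_weights=topk_weights,
        num_tokens_per_rank=num_tokens_per_rank, is_token_in_rank=is_token_in_rank,
        num_tokens_per_expert=num_tokens_per_expert,
        num_worst_tokens=num_worst_tokens)
    return recv_x, recv_topk_idx, recv_topk_weights

# Standard capture recipe: warm up on a side stream, then capture and replay.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    run_dispatch()
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    padded_outputs = run_dispatch()
graph.replay()  # rows past the real token count come back with recv_topk_idx == -1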
@@ -162,6 +162,7 @@ def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
  allocate_on_comm_stream=previous_event is not None)
  # Do MoE dispatch
  # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
+ # Unless you specify `num_worst_tokens`, but this flag is for intranode only
  # For more advanced usages, please refer to the docs of the `dispatch` function
  recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
  _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
@@ -284,7 +284,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
  const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
  const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
  int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
- int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
+ int expert_alignment, int num_worst_tokens, const Config& config,
+ std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
  bool cached_mode = cached_rank_prefix_matrix.has_value();

  // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving.
@@ -412,6 +413,14 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
  buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank,
  comm_stream, num_channels);

+ if (num_worst_tokens > 0) {
+ // No CPU sync, just allocate the worst case
+ num_recv_tokens = num_worst_tokens;
+
+ // Must be forward with top-k stuffs
+ EP_HOST_ASSERT(topk_idx.has_value());
+ EP_HOST_ASSERT(topk_weights.has_value());
+ } else {
  // Synchronize total received tokens and tokens per expert
  auto start_time = std::chrono::high_resolution_clock::now();
  while (true) {
@@ -432,6 +441,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
  }
  num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
+ }
  }

  // Allocate new tensors
  auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
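The branch above is the heart of the host-side change: with `num_worst_tokens > 0`, the polling loop that waits for the GPU to report the received-token count is skipped entirely and `num_recv_tokens` is simply set to the caller-supplied bound. The caller must therefore pick a bound that can never be exceeded; the test added at the end of this diff uses `num_tokens * num_ranks`. A tiny sizing sketch with hypothetical figures (not from the commit):

# Hypothetical figures: 4 intranode ranks, 4096 tokens per rank, hidden size 7168.
num_ranks, num_tokens, hidden = 4, 4096, 7168
# Worst case: every token on every rank is routed to this rank.
num_worst_tokens = num_tokens * num_ranks     # 16384 rows, allocated without any CPU sync
recv_x_shape = (num_worst_tokens, hidden)     # what the `torch::empty({num_recv_tokens, hidden})` above becomes
assert num_worst_tokens >= num_tokens         # always an upper bound on the real receive count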
@@ -472,7 +482,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
  send_head.data_ptr<int>(),
  x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
  is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
- num_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
+ num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
  buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
  config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);

@@ -108,7 +108,8 @@ public:
  const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
  const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
  int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
- int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
+ int expert_alignment, int num_worst_tokens, const Config& config,
+ std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);

  std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
  intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
@@ -45,7 +45,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
  void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
  int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
  const bool* is_token_in_rank, const int* channel_prefix_matrix,
- int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+ int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
  void** buffer_ptrs, int rank, int num_ranks,
  cudaStream_t stream, int num_sms,
  int num_max_send_tokens, int num_recv_buffer_tokens);
@@ -25,7 +25,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,

  int *per_rank_buffer, *per_expert_buffer;
  if (thread_id < kNumRanks) {
- per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
+ per_rank_buffer = static_cast<int*>(buffer_ptrs[thread_id]);
  per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
  }

@@ -48,7 +48,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,

  // Sum per-rank counts and return to CPU
  // Also pre-compute the prefix sum for data sending
- auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
+ auto local_per_rank_buffer = static_cast<int*>(buffer_ptrs[rank]);
  if (thread_id < kNumRanks) {
  #pragma unroll
  for (int i = 1; i < kNumRanks; ++ i)
@@ -141,7 +141,7 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,

  // Copy and clean
  auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
- auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
+ auto ptr = static_cast<int*>(buffer_ptrs[rank]);
  #pragma unroll
  for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
  ptr[i] = rank_prefix_matrix[i];
@@ -173,7 +173,7 @@ __global__ void __launch_bounds__(kNumThreads, 1)
  dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
  int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
  const bool* is_token_in_rank, const int* channel_prefix_matrix,
- int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+ int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
  void** buffer_ptrs, int rank,
  int num_max_send_tokens, int num_recv_buffer_tokens) {
  const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
@@ -196,7 +196,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to

  // Calculate pointers by the specific layout
  // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
- auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
+ auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
  int target_rank = is_sender ? rank : responsible_rank;
  auto num_channels_total = num_channels * kNumRanks;
  auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
@@ -286,7 +286,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to

  int chunk_token_idx = 0;
  while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
- // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
+ // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send the following data
  if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
  send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;

@@ -349,7 +349,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
  EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);

  // Calculate offset first
- auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
+ auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
  int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0;

  // Receive channel offset
@@ -372,7 +372,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
  auto start_time = clock64();
  int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
  while (num_tokens_to_recv > 0) {
- // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same
+ // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are the same
  while (recv_thread_id_in_rank == 0) {
  cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());;

@@ -450,12 +450,25 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
  if (lane_id == 0)
  tma_store_wait();
  }

+ // Clean unused `recv_topk_idx` as -1
+ if (num_worst_tokens > 0) {
+ auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
+ const auto num_recv_tokens = rank_prefix_matrix[(kNumRanks - 1) * kNumRanks + rank];
+ const auto clean_start = num_recv_tokens * num_topk + sm_id * kNumThreads;
+ const auto clean_end = num_worst_tokens * num_topk;
+ const auto clean_stride = num_sms * kNumThreads;
+ #pragma unroll
+ for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
+ recv_topk_idx[i] = -1;
+ }
  }

  void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
  int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
  const bool* is_token_in_rank, const int* channel_prefix_matrix,
- int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+ int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
  void** buffer_ptrs, int rank, int num_ranks,
  cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
  constexpr int kNumThreads = 768;
@@ -470,7 +483,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
  reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
  send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
  is_token_in_rank, channel_prefix_matrix, \
- num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
+ num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
  buffer_ptrs, rank, \
  num_max_send_tokens, num_recv_buffer_tokens); \
  } break
@@ -493,7 +506,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int

  // Clean
  auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
- auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
+ auto ptr = static_cast<int*>(buffer_ptrs[rank]);
  #pragma unroll
  for (int i = thread_id; i < num_memset_int; i += num_threads)
  ptr[i] = 0;
@@ -590,7 +603,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
  EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");

  // Calculate pointers by the specific layout
- auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
+ auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[send_rank_id]));
  auto num_channels_total = num_channels * kNumRanks;
  auto channel_rank_offset = responsible_channel * kNumRanks + rank;

@@ -682,7 +695,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
  asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));

  if (thread_id < 32) {
- int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
+ int* channel_head_idx_ptr = static_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
  int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;

  // Queue head updater
@@ -720,13 +733,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
  auto channel_rank_offset = responsible_channel * kNumRanks + i;
  auto num_channels_total = num_channels * kNumRanks;
  // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
- auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
+ auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));

  // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
  channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);

  // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
- ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
+ ptr = reinterpret_cast<void*>(static_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));

  // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
  channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
@@ -249,7 +249,8 @@ class Buffer:
  handle: Optional[Tuple] = None,
  num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
  is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
- topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
+ topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None,
+ expert_alignment: int = 1, num_worst_tokens: int = 0,
  config: Optional[Config] = None,
  previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
  allocate_on_comm_stream: bool = False) -> \
@@ -276,6 +277,8 @@ class Buffer:
  `-1` means no selections.
  topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
  expert_alignment: align the number of tokens received by each local expert to this variable.
+ num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it
+ will be CUDA-graph compatible. Please also notice that this flag is for intranode only.
  config: the performance tuning config.
  previous_event: the event to wait before actually executing the kernel.
  async_finish: the current stream will not wait for the communication kernels to be finished if set.
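Because the padded mode never reports the real received-token count to the CPU, callers that need the valid region have to recover it from the outputs themselves. The test at the bottom of this diff verifies that every row beyond the real count has `recv_topk_idx == -1`; the sketch below is not part of the commit and additionally assumes that every genuinely received token keeps at least one non-negative local expert index in its top-k row.

import torch

# `recv_x`, `recv_topk_idx` are the padded outputs of a dispatch call with num_worst_tokens > 0.
# Padding rows are all -1 (see the cleaning loop added to the dispatch kernel in this diff);
# real rows are assumed to contain at least one non-negative local expert index.
valid_rows = ~(recv_topk_idx == -1).all(dim=1)
num_valid = int(valid_rows.sum())   # note: this forces a host sync, so keep it outside
                                    # any CUDA-graph-captured region
recv_x_valid = recv_x[:num_valid]   # real tokens are contiguous at the front, padding after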
@@ -296,6 +299,7 @@ class Buffer:

  # Internode
  if self.runtime.get_num_rdma_ranks() > 1:
+ assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`'
  return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
  topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)

@@ -308,14 +312,16 @@ class Buffer:
  recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
  x, x_scales, None, None,
  None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
- expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+ expert_alignment, num_worst_tokens, config,
+ getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
  return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
  else:
  assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
  recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
  self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
  num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
- expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+ expert_alignment, num_worst_tokens, config,
+ getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
  handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
  return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)

@@ -100,6 +100,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
  assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
  if current_x is not x_pure_rand:
  check_data(recv_x, rank_prefix_matrix)
+ recv_topk_weights_clone = None
  if with_topk:
  # Check `topk_idx`
  assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
@@ -107,10 +108,26 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
  assert recv_topk_idx.eq(i).sum().item() == count

  # Check `topk_weights`
+ recv_topk_weights_clone = recv_topk_weights.clone()
  if current_x is not x_pure_rand:
  recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
  check_data(recv_topk_weights, rank_prefix_matrix)

+ # Test `num_worst_tokens != 0`
+ if with_topk:
+ num_worst_tokens = num_tokens * num_ranks
+ dispatch_args.update({'num_worst_tokens': num_worst_tokens})
+ recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, _, _, event = buffer.dispatch(**dispatch_args)
+ event.current_stream_wait() if async_mode else ()
+ recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
+ assert num_worst_tokens == recv_worst_x.size(0)
+ assert num_worst_tokens == recv_worst_topk_idx.size(0)
+ assert num_worst_tokens == recv_worst_topk_weights.size(0)
+ assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])
+ assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])
+ assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])
+ assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()
+
  # Test cached dispatch (must without top-k staffs)
  if not with_topk:
  dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}