diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 439d17d..ff4634c 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -290,7 +290,8 @@ class Buffer: recv_topk_idx: received expert indices. recv_topk_weights: received expert weights. num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count by - each local expert, aligned to the input `expert_alignment`. + each local expert, aligned to the input `expert_alignment`. If `num_worst_tokens` is specified, the list + will be empty. handle: the returned communication handle. event: the event after executing the kernel (valid only if `async_finish` is set). """ diff --git a/tests/test_intranode.py b/tests/test_intranode.py index c59dc46..fb8a573 100644 --- a/tests/test_intranode.py +++ b/tests/test_intranode.py @@ -117,9 +117,10 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: if with_topk: num_worst_tokens = num_tokens * num_ranks dispatch_args.update({'num_worst_tokens': num_worst_tokens}) - recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, _, _, event = buffer.dispatch(**dispatch_args) + recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args) event.current_stream_wait() if async_mode else () recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x + assert len(empty_list) == 0 assert num_worst_tokens == recv_worst_x.size(0) assert num_worst_tokens == recv_worst_topk_idx.size(0) assert num_worst_tokens == recv_worst_topk_weights.size(0)