Check the empty list

This commit is contained in:
Chenggang Zhao 2025-06-11 11:14:30 +08:00
parent a8299ca7c2
commit dd13c7145c
2 changed files with 4 additions and 2 deletions

View File

@ -290,7 +290,8 @@ class Buffer:
recv_topk_idx: received expert indices.
recv_topk_weights: received expert weights.
num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count by
each local expert, aligned to the input `expert_alignment`.
each local expert, aligned to the input `expert_alignment`. If `num_worst_tokens` is specified, the list
will be empty.
handle: the returned communication handle.
event: the event after executing the kernel (valid only if `async_finish` is set).
"""

View File

@ -117,9 +117,10 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
if with_topk:
num_worst_tokens = num_tokens * num_ranks
dispatch_args.update({'num_worst_tokens': num_worst_tokens})
recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, _, _, event = buffer.dispatch(**dispatch_args)
recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
assert len(empty_list) == 0
assert num_worst_tokens == recv_worst_x.size(0)
assert num_worst_tokens == recv_worst_topk_idx.size(0)
assert num_worst_tokens == recv_worst_topk_weights.size(0)