mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Improve EP2/4 performance
This commit is contained in:
@@ -13,7 +13,6 @@ import test_low_latency
|
||||
|
||||
def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
|
||||
# Settings
|
||||
- # TODO: fix EP2/4/8 performance
|
||||
num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
|
||||
assert num_experts % num_ranks == 0
|
||||
if local_rank == 0:
|
||||
@@ -182,7 +181,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
- for nvl_chunk_size in range(1, 5, 1):
|
||||
+ for nvl_chunk_size in range(1, 7, 1):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
|
||||
Reference in New Issue
Block a user