Kalyan Kodela 2025-03-24 20:35:43 +08:00 committed by GitHub
commit 57919718be

eplb.py (67 additions)

@@ -161,4 +161,71 @@ def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int,
        torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1))
    return phy2log, log2phy, logcnt


def rebalance_with_migration_cost(
        current_mapping: torch.Tensor,
        weight: torch.Tensor,
        num_replicas: int,
        num_groups: int,
        num_nodes: int,
        num_gpus: int,
        migration_cost_factor: float = 0.5
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Rebalance experts while considering the cost of migrating experts from their current placement.
This method extends the basic rebalance_experts function by adding a penalty for
moving experts from their current location, which is useful for dynamic systems
where the cost of migration needs to be balanced against load distribution benefits.
Parameters:
current_mapping: [layers, num_replicas], the current expert mapping
weight: [layers, num_logical_experts], the load statistics for all logical experts
num_replicas: number of physical experts, must be a multiple of `num_gpus`
num_groups: number of expert groups
num_nodes: number of server nodes
num_gpus: number of GPUs, must be a multiple of `num_nodes`
migration_cost_factor: weight for the migration cost (0.0 to ignore migration costs)
Returns:
physical_to_logical_map: [layers, num_replicas], the expert index of each replica
logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
"""
    # First, get the ideal mapping without considering migration costs
    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, num_groups, num_nodes, num_gpus)

    # If the migration cost factor is zero or no current mapping exists, return the ideal mapping
    if migration_cost_factor == 0.0 or current_mapping is None:
        return phy2log, log2phy, logcnt

    num_layers, num_logical_experts = weight.shape
    experts_per_gpu = num_replicas // num_gpus

    # Adjust weights to account for migration costs
    adjusted_weight = weight.clone()
    for layer in range(num_layers):
        # Create a mapping from logical expert to its current physical placements (GPU indices)
        current_placements = {}
        for phys_idx, log_idx in enumerate(current_mapping[layer]):
            log_idx = log_idx.item()
            gpu_idx = phys_idx // experts_per_gpu
            if log_idx not in current_placements:
                current_placements[log_idx] = []
            current_placements[log_idx].append(gpu_idx)

        # Adjust weights based on current placements
        for log_idx in range(num_logical_experts):
            # If the expert is currently not placed, no adjustment is needed
            if log_idx not in current_placements:
                continue
            # Inflate the apparent load of an expert that already has replicas, so the
            # subsequent rebalancing pass is biased toward keeping (at least) its current
            # number of replicas rather than reshuffling them
            migration_benefit = weight[layer, log_idx] * migration_cost_factor
            adjusted_weight[layer, log_idx] += migration_benefit

    # Use the adjusted weights to rebalance
    return rebalance_experts(adjusted_weight, num_replicas, num_groups, num_nodes, num_gpus)
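
# Worked sketch of the adjustment above (illustrative numbers, not part of the original
# change): with migration_cost_factor = 0.5, an expert whose measured load is 10 and
# which already has replicas in `current_mapping` is handed to the second rebalancing
# pass with an adjusted load of 10 + 10 * 0.5 = 15, while an expert with no current
# replicas keeps its raw load, so already-placed experts look heavier and are favoured
# when replicas are reassigned.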
__all__ = ['rebalance_experts', 'rebalance_with_migration_cost']
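

# Illustrative usage sketch (hypothetical shapes and values, not part of the original
# change; `torch` is imported at the top of this module): rebalance once to obtain an
# initial placement, then rebalance again with a migration penalty so the new
# placement stays close to the old one.
if __name__ == "__main__":
    layers, logical_experts = 2, 64
    replicas, groups, nodes, gpus = 128, 8, 2, 16
    load = torch.rand(layers, logical_experts)  # per-expert load statistics
    phy2log, log2phy, logcnt = rebalance_experts(load, replicas, groups, nodes, gpus)
    new_phy2log, _, _ = rebalance_with_migration_cost(
        phy2log, load, replicas, groups, nodes, gpus, migration_cost_factor=0.5)
    print("replicas unchanged:", (new_phy2log == phy2log).sum().item(), "of", phy2log.numel())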