Make index access granular

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-22 13:44:29 +00:00
parent f7b8992020
commit b679f66e58
3 changed files with 52 additions and 41 deletions

View File

@@ -295,12 +295,11 @@ def _test_async_transfer_layer_without_mtp_worker(
for layer_idx in range(num_layers):
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer(
old_global_expert_indices=old_indices_cpu,
new_global_expert_indices=new_indices_cpu,
expert_weights=expert_weights,
old_layer_indices=old_indices_cpu[layer_idx],
new_layer_indices=new_indices_cpu[layer_idx],
expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer,
ep_group=ep_group,
layer=layer_idx,
cuda_stream=cuda_stream,
)
)

View File

@@ -109,8 +109,7 @@ async def transfer_run_periodically(
for model_state in state.model_states.values():
if not model_state.is_async_enabled:
continue
# Rebalance experts is done once, only when the async worker wakes up.
run_rebalance_experts(model_state, state)
rebalancing_algorithm_executed = False
logger.info(
"Async worker computed new indices for model %s",
model_state.model_name,
@@ -121,28 +120,34 @@ async def transfer_run_periodically(
model_state.rebalanced
and model_state.layer_to_transfer < current_num_layers
):
if (
not model_state.ep_buffer_ready
and model_state.rebalanced
and model_state.new_physical_to_logical_map is not None
):
if not model_state.ep_buffer_ready and model_state.rebalanced:
await asyncio.to_thread(model_state.buffer_lock.acquire)
try:
if model_state.layer_to_transfer >= current_num_layers:
break
if not rebalancing_algorithm_executed:
run_rebalance_experts(model_state, state)
rebalancing_algorithm_executed = True
layer_idx = model_state.layer_to_transfer
old_layer_indices = model_state.old_physical_to_logical_map[
layer_idx
]
new_layer_indices = model_state.new_physical_to_logical_map[
layer_idx
]
(
model_state.is_unchanged,
model_state.is_received_locally,
model_state.recv_metadata,
) = await transfer_layer(
old_global_expert_indices=model_state.physical_to_logical_map,
new_global_expert_indices=model_state.new_physical_to_logical_map,
expert_weights=model_state.model.expert_weights,
old_layer_indices=old_layer_indices,
new_layer_indices=new_layer_indices,
expert_weights=model_state.model.expert_weights[layer_idx],
expert_weights_buffer=model_state.expert_buffer,
ep_group=ep_group,
is_profile=is_profile,
layer=model_state.layer_to_transfer,
cuda_stream=cuda_stream,
rank_mapping=rank_mapping,
)

View File

@@ -434,13 +434,12 @@ def move_from_buffer(
async def transfer_layer(
old_global_expert_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Iterable[torch.Tensor]],
old_layer_indices: torch.Tensor,
new_layer_indices: torch.Tensor,
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup,
is_profile: bool = False,
layer: int = 0,
cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None,
) -> MoveToBufferResult:
@@ -451,55 +450,63 @@ async def transfer_layer(
while keys are physical.
Args:
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
expert_weights: A sequence of shape (num_moe_layers)(weight_count)
of tensors of shape (num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection,
so weight_count = 2. Each weight's hidden size can be different.
old_layer_indices: Shape (num_physical_experts,).
new_layer_indices: Shape (num_physical_experts,).
expert_weights: Iterable of weight tensors for this layer, each with shape
(num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection.
expert_weights_buffer: Intermediate buffers (one per weight tensor).
ep_group: The device process group for expert parallelism.
is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers.
cuda_stream: CUDA stream for async copies (can be None for sync mode).
rank_mapping: Optional rank mapping for elastic expert parallelism.
Returns:
is_unchanged (np.ndarray): (1, num_local_experts), True where expert
is_unchanged (np.ndarray): (num_local_experts,), True where expert
is left unchanged.
is_received_locally (np.ndarray): (1, num_local_experts), True where expert
is_received_locally (np.ndarray): (num_local_experts,), True where expert
can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers.
"""
ep_size = ep_group.size()
if rank_mapping is not None:
# Add a layer dimension for compatibility with mapping functions
old_layer_indices_2d = old_layer_indices.unsqueeze(0)
new_layer_indices_2d = new_layer_indices.unsqueeze(0)
if len(rank_mapping) == ep_group.size():
# scale down
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices,
new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping(
new_layer_indices_2d,
rank_mapping,
)
else:
# scale up
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices,
old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping(
old_layer_indices_2d,
rank_mapping,
ep_group.size(),
)
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
# Remove the layer dimension
old_layer_indices = old_layer_indices_2d.squeeze(0)
new_layer_indices = new_layer_indices_2d.squeeze(0)
assert old_layer_indices.shape == new_layer_indices.shape
num_physical_experts = old_layer_indices.shape[0]
num_local_physical_experts = next(iter(expert_weights)).shape[0]
assert num_physical_experts == ep_size * num_local_physical_experts
old_global_expert_indices_np = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_np = new_global_expert_indices.cpu().numpy()
old_layer_indices_np = old_layer_indices.cpu().numpy()
new_layer_indices_np = new_layer_indices.cpu().numpy()
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices_np[layer],
new_indices=new_global_expert_indices_np[layer],
expert_weights=expert_weights[layer],
old_indices=old_layer_indices_np,
new_indices=new_layer_indices_np,
expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer,
cuda_stream=cuda_stream,
ep_group=ep_group,