Make index access granular

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-22 13:44:29 +00:00
parent f7b8992020
commit b679f66e58
3 changed files with 52 additions and 41 deletions

View File

@ -295,12 +295,11 @@ def _test_async_transfer_layer_without_mtp_worker(
for layer_idx in range(num_layers): for layer_idx in range(num_layers):
is_unchanged, is_received_locally, recv_metadata = asyncio.run( is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer( transfer_layer(
old_global_expert_indices=old_indices_cpu, old_layer_indices=old_indices_cpu[layer_idx],
new_global_expert_indices=new_indices_cpu, new_layer_indices=new_indices_cpu[layer_idx],
expert_weights=expert_weights, expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer, expert_weights_buffer=expert_buffer,
ep_group=ep_group, ep_group=ep_group,
layer=layer_idx,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
) )
) )

View File

@ -109,8 +109,7 @@ async def transfer_run_periodically(
for model_state in state.model_states.values(): for model_state in state.model_states.values():
if not model_state.is_async_enabled: if not model_state.is_async_enabled:
continue continue
# Rebalance experts is done once, only when the async worker wakes up. rebalancing_algorithm_executed = False
run_rebalance_experts(model_state, state)
logger.info( logger.info(
"Async worker computed new indices for model %s", "Async worker computed new indices for model %s",
model_state.model_name, model_state.model_name,
@ -121,28 +120,34 @@ async def transfer_run_periodically(
model_state.rebalanced model_state.rebalanced
and model_state.layer_to_transfer < current_num_layers and model_state.layer_to_transfer < current_num_layers
): ):
if ( if not model_state.ep_buffer_ready and model_state.rebalanced:
not model_state.ep_buffer_ready
and model_state.rebalanced
and model_state.new_physical_to_logical_map is not None
):
await asyncio.to_thread(model_state.buffer_lock.acquire) await asyncio.to_thread(model_state.buffer_lock.acquire)
try: try:
if model_state.layer_to_transfer >= current_num_layers: if model_state.layer_to_transfer >= current_num_layers:
break break
if not rebalancing_algorithm_executed:
run_rebalance_experts(model_state, state)
rebalancing_algorithm_executed = True
layer_idx = model_state.layer_to_transfer
old_layer_indices = model_state.old_physical_to_logical_map[
layer_idx
]
new_layer_indices = model_state.new_physical_to_logical_map[
layer_idx
]
( (
model_state.is_unchanged, model_state.is_unchanged,
model_state.is_received_locally, model_state.is_received_locally,
model_state.recv_metadata, model_state.recv_metadata,
) = await transfer_layer( ) = await transfer_layer(
old_global_expert_indices=model_state.physical_to_logical_map, old_layer_indices=old_layer_indices,
new_global_expert_indices=model_state.new_physical_to_logical_map, new_layer_indices=new_layer_indices,
expert_weights=model_state.model.expert_weights, expert_weights=model_state.model.expert_weights[layer_idx],
expert_weights_buffer=model_state.expert_buffer, expert_weights_buffer=model_state.expert_buffer,
ep_group=ep_group, ep_group=ep_group,
is_profile=is_profile, is_profile=is_profile,
layer=model_state.layer_to_transfer,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
rank_mapping=rank_mapping, rank_mapping=rank_mapping,
) )

View File

@ -434,13 +434,12 @@ def move_from_buffer(
async def transfer_layer( async def transfer_layer(
old_global_expert_indices: torch.Tensor, old_layer_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor, new_layer_indices: torch.Tensor,
expert_weights: Sequence[Iterable[torch.Tensor]], expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor], expert_weights_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup, ep_group: ProcessGroup,
is_profile: bool = False, is_profile: bool = False,
layer: int = 0,
cuda_stream: torch.cuda.Stream | None = None, cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
) -> MoveToBufferResult: ) -> MoveToBufferResult:
@ -451,55 +450,63 @@ async def transfer_layer(
while keys are physical. while keys are physical.
Args: Args:
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts). old_layer_indices: Shape (num_physical_experts,).
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts). new_layer_indices: Shape (num_physical_experts,).
expert_weights: A sequence of shape (num_moe_layers)(weight_count) expert_weights: Iterable of weight tensors for this layer, each with shape
of tensors of shape (num_local_physical_experts, hidden_size_i). (num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection, For example, a linear layer may have up and down projection.
so weight_count = 2. Each weight's hidden size can be different. expert_weights_buffer: Intermediate buffers (one per weight tensor).
ep_group: The device process group for expert parallelism. ep_group: The device process group for expert parallelism.
is_profile (bool): If `True`, do not perform any actual weight copy. is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers. communications to reserve enough memory for the buffers.
cuda_stream: CUDA stream for async copies (can be None for sync mode).
rank_mapping: Optional rank mapping for elastic expert parallelism.
Returns: Returns:
is_unchanged (np.ndarray): (1, num_local_experts), True where expert is_unchanged (np.ndarray): (num_local_experts,), True where expert
is left unchanged. is left unchanged.
is_received_locally (np.ndarray): (1, num_local_experts), True where expert is_received_locally (np.ndarray): (num_local_experts,), True where expert
can be received locally. can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers. RecvMetadata: Metadata needed for completing remote weight transfers.
""" """
ep_size = ep_group.size() ep_size = ep_group.size()
if rank_mapping is not None: if rank_mapping is not None:
# Add a layer dimension for compatibility with mapping functions
old_layer_indices_2d = old_layer_indices.unsqueeze(0)
new_layer_indices_2d = new_layer_indices.unsqueeze(0)
if len(rank_mapping) == ep_group.size(): if len(rank_mapping) == ep_group.size():
# scale down # scale down
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping( new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices, new_layer_indices_2d,
rank_mapping, rank_mapping,
) )
else: else:
# scale up # scale up
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping( old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices, old_layer_indices_2d,
rank_mapping, rank_mapping,
ep_group.size(), ep_group.size(),
) )
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1] # Remove the layer dimension
num_moe_layers, num_physical_experts = old_global_expert_indices.shape old_layer_indices = old_layer_indices_2d.squeeze(0)
assert len(expert_weights) == num_moe_layers new_layer_indices = new_layer_indices_2d.squeeze(0)
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) assert old_layer_indices.shape == new_layer_indices.shape
num_physical_experts = old_layer_indices.shape[0]
num_local_physical_experts = next(iter(expert_weights)).shape[0]
assert num_physical_experts == ep_size * num_local_physical_experts assert num_physical_experts == ep_size * num_local_physical_experts
old_global_expert_indices_np = old_global_expert_indices.cpu().numpy() old_layer_indices_np = old_layer_indices.cpu().numpy()
new_global_expert_indices_np = new_global_expert_indices.cpu().numpy() new_layer_indices_np = new_layer_indices.cpu().numpy()
is_unchanged, is_received_locally, recv_metadata = move_to_buffer( is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts, num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices_np[layer], old_indices=old_layer_indices_np,
new_indices=new_global_expert_indices_np[layer], new_indices=new_layer_indices_np,
expert_weights=expert_weights[layer], expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer, expert_weights_buffers=expert_weights_buffer,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
ep_group=ep_group, ep_group=ep_group,