mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 13:57:06 +08:00
Updates after review
Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
parent
cfac6b3f64
commit
6b2a1de500
@ -47,7 +47,11 @@ from vllm.model_executor.models.interfaces import MixtureOfExperts
|
||||
|
||||
from .async_worker import start_async_worker
|
||||
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
|
||||
from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
|
||||
from .rebalance_execute import (
|
||||
RecvMetadata,
|
||||
move_from_buffer,
|
||||
rearrange_expert_weights_inplace,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -175,14 +179,9 @@ class EplbModelState:
|
||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
||||
The size is same as the num of physical experts in the current layer.
|
||||
"""
|
||||
recv_metadata: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
||||
recv_metadata: RecvMetadata
|
||||
"""
|
||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
||||
The tuple contains:
|
||||
- recv_primary_mask: np.ndarray, shape (group_size, num_local_experts)
|
||||
- recv_counts: np.ndarray, shape (group_size,)
|
||||
- recv_expert_ids: np.ndarray, shape (group_size, num_local_experts)
|
||||
- recv_dst_rows: np.ndarray, shape (group_size, num_local_experts)
|
||||
"""
|
||||
is_async_enabled: bool
|
||||
"""
|
||||
@ -514,7 +513,12 @@ class EplbState:
|
||||
pending_global_ready_check=False,
|
||||
is_unchanged=np.array([]),
|
||||
is_received_locally=np.array([]),
|
||||
recv_metadata=(np.array([]), np.array([]), np.array([]), np.array([])),
|
||||
recv_metadata=RecvMetadata(
|
||||
recv_primary_mask=np.array([]),
|
||||
recv_counts=np.array([]),
|
||||
recv_expert_ids=np.array([]),
|
||||
recv_dst_rows=np.array([]),
|
||||
),
|
||||
is_async_enabled=self.is_async,
|
||||
cuda_device_index=self.cuda_device_index,
|
||||
new_physical_to_logical_map=new_physical_to_logical_map,
|
||||
@ -985,17 +989,20 @@ class EplbState:
|
||||
model_state.model.expert_weights[model_state.layer_to_transfer]
|
||||
]
|
||||
buffers_group = [model_state.expert_buffer]
|
||||
new_indices_group = (
|
||||
model_state.new_physical_to_logical_map[
|
||||
model_state.layer_to_transfer : model_state.layer_to_transfer + 1
|
||||
]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
move_from_buffer(
|
||||
weights_group=weights_group,
|
||||
buffers_group=buffers_group,
|
||||
is_unchanged=model_state.is_unchanged,
|
||||
is_received_locally=model_state.is_received_locally,
|
||||
recv_metadata=model_state.recv_metadata,
|
||||
new_indices_group=model_state.new_physical_to_logical_map[
|
||||
model_state.layer_to_transfer : model_state.layer_to_transfer + 1
|
||||
]
|
||||
.cpu()
|
||||
.numpy(),
|
||||
new_indices_group=new_indices_group,
|
||||
ep_group=ep_group,
|
||||
)
|
||||
transferred_layer = model_state.layer_to_transfer
|
||||
|
||||
@ -7,6 +7,7 @@ This involves the exchange of expert weights between GPUs.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -24,6 +25,26 @@ from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RecvMetadata:
    """Metadata describing remote receives during EPLB rebalancing.

    Returned by `move_to_buffer` (as the last element of its result) and
    passed to `move_from_buffer`, which reads all four arrays. The first
    axis of every array indexes the layers of the processed layer group.
    """

    recv_primary_mask: np.ndarray
    """Boolean mask of shape (layer_group_size, num_local_experts), True at
    the local rows that receive a primary expert from a remote rank."""
    recv_counts: np.ndarray
    """Array of shape (layer_group_size,) holding the number of experts
    received for each layer."""
    recv_expert_ids: np.ndarray
    """Expert ids of the remote primary experts, shape
    (layer_group_size, num_local_experts). For each layer only the first
    `recv_counts[layer]` entries are filled in by `move_to_buffer`."""
    recv_dst_rows: np.ndarray
    """Destination rows in the local expert tensors for the received
    experts, shape (layer_group_size, num_local_experts). Entry `i` of a
    layer is the row for `recv_expert_ids[layer, i]`; only the first
    `recv_counts[layer]` entries are filled in by `move_to_buffer`."""
|
||||
|
||||
|
||||
# Type alias for the result of move_to_buffer or transfer_layer
|
||||
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]
|
||||
|
||||
|
||||
def get_ep_ranks_with_experts_batch(
|
||||
expert_ids: np.ndarray,
|
||||
num_local_experts: int,
|
||||
@ -139,11 +160,28 @@ def move_to_buffer(
|
||||
buffers_group: Sequence[Sequence[torch.Tensor]],
|
||||
cuda_stream: torch.cuda.Stream | None,
|
||||
ep_group: ProcessGroup,
|
||||
) -> tuple[
|
||||
np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
||||
]:
|
||||
) -> MoveToBufferResult:
|
||||
"""
|
||||
Perform expert weights rearrangement of a group of layers.
|
||||
Rearranges expert weights across a group of layers
|
||||
during mixture-of-experts (MoE) expert parallel rebalancing.
|
||||
|
||||
Args:
|
||||
num_local_experts: Number of local experts.
|
||||
old_indices_group: (num_layers, num_experts_total) ndarray of current
|
||||
(old) global-to-local expert assignments.
|
||||
new_indices_group: (num_layers, num_experts_total) ndarray of desired
|
||||
(new) global-to-local assignments after rebalance.
|
||||
expert_weights_group: Original expert weights for each layer.
|
||||
buffers_group: List of per-layer intermediate buffers (one per tensor).
|
||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||
ep_group: Distributed process group for expert parallel comms.
|
||||
|
||||
Returns:
|
||||
is_unchanged (np.ndarray): (num_layers, num_local_experts), True where an
|
||||
expert row is unchanged after rebalance.
|
||||
is_received_locally (np.ndarray): (num_layers, num_local_experts), True
|
||||
where a row can be updated from local data.
|
||||
RecvMetadata: Metadata needed for completing remote weight transfers.
|
||||
"""
|
||||
assert len(old_indices_group) == len(new_indices_group) == len(expert_weights_group)
|
||||
group_size = len(old_indices_group)
|
||||
@ -214,10 +252,10 @@ def move_to_buffer(
|
||||
desired_experts, return_index=True
|
||||
)
|
||||
dst_rows = desired_dsts[uniq_indices]
|
||||
layer_send_count = int(uniq_recv_experts.shape[0])
|
||||
recv_counts[layer_idx] = layer_send_count
|
||||
recv_expert_ids[layer_idx, :layer_send_count] = uniq_recv_experts
|
||||
recv_dst_rows[layer_idx, :layer_send_count] = dst_rows
|
||||
layer_recv_count = int(uniq_recv_experts.shape[0])
|
||||
recv_counts[layer_idx] = layer_recv_count
|
||||
recv_expert_ids[layer_idx, :layer_recv_count] = uniq_recv_experts
|
||||
recv_dst_rows[layer_idx, :layer_recv_count] = dst_rows
|
||||
recv_primary_mask[layer_idx, dst_rows] = True
|
||||
else:
|
||||
recv_counts[layer_idx] = 0
|
||||
@ -367,7 +405,12 @@ def move_to_buffer(
|
||||
return (
|
||||
is_unchanged,
|
||||
is_received_locally,
|
||||
(recv_primary_mask, recv_counts, recv_expert_ids, recv_dst_rows),
|
||||
RecvMetadata(
|
||||
recv_primary_mask=recv_primary_mask,
|
||||
recv_counts=recv_counts,
|
||||
recv_expert_ids=recv_expert_ids,
|
||||
recv_dst_rows=recv_dst_rows,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -376,21 +419,42 @@ def move_from_buffer(
|
||||
buffers_group: Sequence[Sequence[torch.Tensor]],
|
||||
is_unchanged: np.ndarray,
|
||||
is_received_locally: np.ndarray,
|
||||
recv_metadata: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
|
||||
recv_metadata: RecvMetadata,
|
||||
new_indices_group: np.ndarray,
|
||||
ep_group: ProcessGroup,
|
||||
) -> None:
|
||||
"""
|
||||
Copies expert weights from communication buffers back to the target weight tensors,
|
||||
after EPLB rebalancing.
|
||||
|
||||
Args:
|
||||
weights_group: Groups of consecutive MoE layers, each containing one or more
|
||||
weight tensors.
|
||||
buffers_group: Intermediate buffers matching weights_group.
|
||||
is_unchanged: (num_layers, num_local_experts), True
|
||||
where an expert row is unchanged after rebalance.
|
||||
is_received_locally: (num_layers, num_local_experts), True
|
||||
where a row can be updated from local data.
|
||||
recv_metadata: RecvMetadata containing remote receive metadata.
|
||||
new_indices_group: np.ndarray giving for each layer the mapping from local rows
|
||||
to desired (possibly global) expert id, after rebalance.
|
||||
ep_group: torch.distributed.ProcessGroup for expert parallel communication
|
||||
domain.
|
||||
"""
|
||||
assert (
|
||||
len(weights_group)
|
||||
== len(buffers_group)
|
||||
== len(is_unchanged)
|
||||
== len(is_received_locally)
|
||||
== len(recv_metadata[0])
|
||||
== len(recv_metadata.recv_primary_mask)
|
||||
== len(new_indices_group)
|
||||
), "Unmatching layer group size"
|
||||
ep_rank = ep_group.rank()
|
||||
group_size = len(is_unchanged)
|
||||
recv_primary_mask, recv_counts, recv_expert_ids, recv_dst_rows = recv_metadata
|
||||
recv_primary_mask = recv_metadata.recv_primary_mask
|
||||
recv_counts = recv_metadata.recv_counts
|
||||
recv_expert_ids = recv_metadata.recv_expert_ids
|
||||
recv_dst_rows = recv_metadata.recv_dst_rows
|
||||
num_local_experts = is_unchanged.shape[1]
|
||||
# Mask for rows to copy back from buffers:
|
||||
# copy if locally received OR remote primary recv
|
||||
@ -468,9 +532,7 @@ async def transfer_layer(
|
||||
layer: int = 0,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> tuple[
|
||||
np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
||||
]:
|
||||
) -> MoveToBufferResult:
|
||||
"""
|
||||
Rearranges the expert weights in place according to the new expert indices.
|
||||
|
||||
@ -488,6 +550,13 @@ async def transfer_layer(
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
This is used during profile run, where we only perform dummy
|
||||
communications to reserve enough memory for the buffers.
|
||||
|
||||
Returns:
|
||||
is_unchanged (np.ndarray): (1, num_local_experts), True where expert
|
||||
is left unchanged.
|
||||
is_received_locally (np.ndarray): (1, num_local_experts), True where expert
|
||||
can be updated from local data.
|
||||
RecvMetadata: Metadata needed for completing remote weight transfers.
|
||||
"""
|
||||
ep_size = ep_group.size()
|
||||
if rank_mapping is not None:
|
||||
@ -733,4 +802,4 @@ def _map_new_expert_indices_with_rank_mapping(
|
||||
return mapped_expert_indices
|
||||
|
||||
|
||||
__all__ = ["transfer_layer", "move_from_buffer"]
|
||||
__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata", "MoveToBufferResult"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user