Move rebalance algo to async thread

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-12 13:18:28 +00:00
parent 6014dc26d3
commit 7ebd46fe76
2 changed files with 103 additions and 37 deletions

View File

@ -17,7 +17,7 @@ from vllm.logger import init_logger
from .rebalance_execute import transfer_layer from .rebalance_execute import transfer_layer
if TYPE_CHECKING: if TYPE_CHECKING:
from .eplb_state import EplbState from .eplb_state import EplbModelState, EplbState
logger = init_logger(__name__) logger = init_logger(__name__)
@ -57,6 +57,42 @@ def start_async_worker(
return thread return thread
def run_rebalance_experts(
model_state: "EplbModelState",
eplb_state: "EplbState",
) -> None:
assert model_state.eplb_stats is not None
eplb_stats = model_state.eplb_stats
# Move the global expert load window to CPU for computation.
global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
# Compute new expert mappings for the model
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = eplb_state.policy.rebalance_experts(
global_expert_load_window,
eplb_stats.num_replicas,
eplb_stats.num_groups,
eplb_stats.num_nodes,
eplb_stats.num_gpus,
model_state.physical_to_logical_map,
)
# Move map to cpu
model_state.new_physical_to_logical_map = new_physical_to_logical_map
max_slots = model_state.logical_to_physical_map.shape[-1]
padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
value=-1,
).to(model_state.logical_to_physical_map.device)
new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
model_state.new_logical_to_physical_map = padded_logical
model_state.new_logical_replica_count = new_replica
async def transfer_run_periodically( async def transfer_run_periodically(
state: "EplbState", state: "EplbState",
ep_group: ProcessGroup, ep_group: ProcessGroup,
@ -71,6 +107,9 @@ async def transfer_run_periodically(
for model_state in state.model_states.values(): for model_state in state.model_states.values():
if not model_state.is_async_enabled: if not model_state.is_async_enabled:
continue continue
if not model_state.new_indices_computed:
run_rebalance_experts(model_state, state)
current_num_layers = model_state.model.num_moe_layers current_num_layers = model_state.model.num_moe_layers
while ( while (
model_state.rebalanced model_state.rebalanced

View File

@ -55,6 +55,35 @@ from .rebalance_execute import (
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class EplbStats:
"""
Model stats used in EPLB rebalanding algorithm.
"""
global_expert_load_window: torch.Tensor
"""
Experts load window.
Shape: (window_size, num_moe_layers, num_physical_experts)
"""
num_replicas: int
"""
Number of physical experts.
"""
num_groups: int
"""
Number of expert groups.
"""
num_nodes: int
"""
Number of nodes.
"""
num_gpus: int
"""
Number of GPUs.
"""
@dataclass @dataclass
class EplbModelState: class EplbModelState:
"""EPLB metrics.""" """EPLB metrics."""
@ -168,6 +197,14 @@ class EplbModelState:
""" """
Whether the async EPLB needs to poll peers for buffer readiness. Whether the async EPLB needs to poll peers for buffer readiness.
""" """
new_indices_computed: bool
"""
The flag indicates whether the new indices have been computed.
"""
eplb_stats: EplbStats | None
"""
EPLB stats for the model.
"""
is_unchanged: np.ndarray is_unchanged: np.ndarray
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
@ -510,6 +547,8 @@ class EplbState:
layer_to_transfer=0, layer_to_transfer=0,
rebalanced=False, rebalanced=False,
pending_global_ready_check=False, pending_global_ready_check=False,
new_indices_computed=False,
eplb_stats=None,
is_unchanged=np.array([]), is_unchanged=np.array([]),
is_received_locally=np.array([]), is_received_locally=np.array([]),
recv_metadata=RecvMetadata( recv_metadata=RecvMetadata(
@ -806,21 +845,21 @@ class EplbState:
for eplb_model_state, global_expert_load_window in zip( for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows self.model_states.values(), global_expert_load_windows
): ):
# Get new expert mappings for the model
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load_window,
num_replicas,
num_groups,
num_nodes,
num_gpus,
eplb_model_state.physical_to_logical_map,
)
if not eplb_model_state.is_async_enabled or is_profile: if not eplb_model_state.is_async_enabled or is_profile:
# Get new expert mappings for the model
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load_window,
num_replicas,
num_groups,
num_nodes,
num_gpus,
eplb_model_state.physical_to_logical_map,
)
# Update expert weights # Update expert weights
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
eplb_model_state.physical_to_logical_map, eplb_model_state.physical_to_logical_map,
@ -877,27 +916,17 @@ class EplbState:
gpu_elapsed, gpu_elapsed,
) )
else: else:
max_slots = eplb_model_state.logical_to_physical_map.shape[-1] eplb_model_state.eplb_stats = EplbStats(
padded_logical = torch.nn.functional.pad( global_expert_load_window=global_expert_load_window,
new_logical_to_physical_map, num_replicas=num_replicas,
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])), num_groups=num_groups,
value=-1, num_nodes=num_nodes,
).to(eplb_model_state.logical_to_physical_map.device) num_gpus=num_gpus,
new_replica = new_logical_replica_count.to(
eplb_model_state.logical_replica_count.device
) )
# Move map to cpu in advance
eplb_model_state.new_physical_to_logical_map = (
new_physical_to_logical_map.cpu()
)
eplb_model_state.new_logical_to_physical_map = padded_logical
eplb_model_state.new_logical_replica_count = new_replica
eplb_model_state.rebalanced = True eplb_model_state.rebalanced = True
eplb_model_state.layer_to_transfer = 0 eplb_model_state.layer_to_transfer = 0
eplb_model_state.pending_global_ready_check = True eplb_model_state.pending_global_ready_check = True
eplb_model_state.new_indices_computed = False
# Signal async thread to start transferring layers # Signal async thread to start transferring layers
if self.is_async and (not is_profile): if self.is_async and (not is_profile):
self.rearrange_event.set() self.rearrange_event.set()
@ -993,11 +1022,9 @@ class EplbState:
model_state.layer_to_transfer model_state.layer_to_transfer
] ]
expert_weights_buffer = model_state.expert_buffer expert_weights_buffer = model_state.expert_buffer
new_indices = ( new_indices = model_state.new_physical_to_logical_map[
model_state.new_physical_to_logical_map[model_state.layer_to_transfer] model_state.layer_to_transfer
.cpu() ].numpy()
.numpy()
)
move_from_buffer( move_from_buffer(
expert_weights=expert_weights, expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer, expert_weights_buffers=expert_weights_buffer,