mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-22 22:50:11 +08:00
Move rebalance algo to async thread
Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
parent
6014dc26d3
commit
7ebd46fe76
@ -17,7 +17,7 @@ from vllm.logger import init_logger
|
||||
from .rebalance_execute import transfer_layer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .eplb_state import EplbState
|
||||
from .eplb_state import EplbModelState, EplbState
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -57,6 +57,42 @@ def start_async_worker(
|
||||
return thread
|
||||
|
||||
|
||||
def run_rebalance_experts(
    model_state: "EplbModelState",
    eplb_state: "EplbState",
) -> None:
    """Compute new expert mappings for one model and stage them on its state.

    Reads the statistics stored in ``model_state.eplb_stats``, runs
    ``eplb_state.policy.rebalance_experts`` on a CPU copy of the expert load
    window, and stores the three resulting maps on ``model_state`` as
    ``new_physical_to_logical_map``, ``new_logical_to_physical_map`` and
    ``new_logical_replica_count``.

    Args:
        model_state: Per-model EPLB state. ``eplb_stats`` must already be
            populated; the ``new_*`` attributes are overwritten.
        eplb_state: EPLB state object providing the rebalancing ``policy``.

    Raises:
        AssertionError: If ``model_state.eplb_stats`` is ``None``.

    Note:
        This function does not modify ``model_state.new_indices_computed``;
        updating that flag is left to the caller.
    """
    stats = model_state.eplb_stats
    assert stats is not None

    # Move the global expert load window to CPU for computation.
    load_window_cpu = stats.global_expert_load_window.cpu()

    # Compute new expert mappings for the model.
    (
        new_physical_to_logical_map,
        new_logical_to_physical_map,
        new_logical_replica_count,
    ) = eplb_state.policy.rebalance_experts(
        load_window_cpu,
        stats.num_replicas,
        stats.num_groups,
        stats.num_nodes,
        stats.num_gpus,
        model_state.physical_to_logical_map,
    )

    # Keep the physical->logical map on CPU: it is later converted with
    # ``.numpy()``, which only works on CPU tensors. ``.cpu()`` is a no-op
    # when the policy already returned a CPU tensor.
    model_state.new_physical_to_logical_map = new_physical_to_logical_map.cpu()

    # Pad the last dimension with -1 up to the slot count of the current
    # logical->physical map, then place the result on that map's device.
    max_slots = model_state.logical_to_physical_map.shape[-1]
    missing_slots = max(0, max_slots - new_logical_to_physical_map.shape[-1])
    padded_logical = torch.nn.functional.pad(
        new_logical_to_physical_map,
        (0, missing_slots),
        value=-1,
    ).to(model_state.logical_to_physical_map.device)
    new_replica = new_logical_replica_count.to(
        model_state.logical_replica_count.device
    )

    model_state.new_logical_to_physical_map = padded_logical
    model_state.new_logical_replica_count = new_replica
|
||||
|
||||
|
||||
async def transfer_run_periodically(
|
||||
state: "EplbState",
|
||||
ep_group: ProcessGroup,
|
||||
@ -71,6 +107,9 @@ async def transfer_run_periodically(
|
||||
for model_state in state.model_states.values():
|
||||
if not model_state.is_async_enabled:
|
||||
continue
|
||||
if not model_state.new_indices_computed:
|
||||
run_rebalance_experts(model_state, state)
|
||||
|
||||
current_num_layers = model_state.model.num_moe_layers
|
||||
while (
|
||||
model_state.rebalanced
|
||||
|
||||
@ -55,6 +55,35 @@ from .rebalance_execute import (
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class EplbStats:
    """Model stats used in the EPLB rebalancing algorithm."""

    global_expert_load_window: torch.Tensor
    """Expert load window, shape (window_size, num_moe_layers, num_physical_experts)."""

    num_replicas: int
    """Number of physical experts."""

    num_groups: int
    """Number of expert groups."""

    num_nodes: int
    """Number of nodes."""

    num_gpus: int
    """Number of GPUs."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class EplbModelState:
|
||||
"""EPLB metrics."""
|
||||
@ -168,6 +197,14 @@ class EplbModelState:
|
||||
"""
|
||||
Whether the async EPLB needs to poll peers for buffer readiness.
|
||||
"""
|
||||
new_indices_computed: bool
|
||||
"""
|
||||
The flag indicates whether the new indices have been computed.
|
||||
"""
|
||||
eplb_stats: EplbStats | None
|
||||
"""
|
||||
EPLB stats for the model.
|
||||
"""
|
||||
is_unchanged: np.ndarray
|
||||
"""
|
||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
||||
@ -510,6 +547,8 @@ class EplbState:
|
||||
layer_to_transfer=0,
|
||||
rebalanced=False,
|
||||
pending_global_ready_check=False,
|
||||
new_indices_computed=False,
|
||||
eplb_stats=None,
|
||||
is_unchanged=np.array([]),
|
||||
is_received_locally=np.array([]),
|
||||
recv_metadata=RecvMetadata(
|
||||
@ -806,21 +845,21 @@ class EplbState:
|
||||
for eplb_model_state, global_expert_load_window in zip(
|
||||
self.model_states.values(), global_expert_load_windows
|
||||
):
|
||||
# Get new expert mappings for the model
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = self.policy.rebalance_experts(
|
||||
global_expert_load_window,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
)
|
||||
|
||||
if not eplb_model_state.is_async_enabled or is_profile:
|
||||
# Get new expert mappings for the model
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = self.policy.rebalance_experts(
|
||||
global_expert_load_window,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
)
|
||||
|
||||
# Update expert weights
|
||||
rearrange_expert_weights_inplace(
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
@ -877,27 +916,17 @@ class EplbState:
|
||||
gpu_elapsed,
|
||||
)
|
||||
else:
|
||||
max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
|
||||
padded_logical = torch.nn.functional.pad(
|
||||
new_logical_to_physical_map,
|
||||
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
|
||||
value=-1,
|
||||
).to(eplb_model_state.logical_to_physical_map.device)
|
||||
new_replica = new_logical_replica_count.to(
|
||||
eplb_model_state.logical_replica_count.device
|
||||
eplb_model_state.eplb_stats = EplbStats(
|
||||
global_expert_load_window=global_expert_load_window,
|
||||
num_replicas=num_replicas,
|
||||
num_groups=num_groups,
|
||||
num_nodes=num_nodes,
|
||||
num_gpus=num_gpus,
|
||||
)
|
||||
|
||||
# Move map to cpu in advance
|
||||
eplb_model_state.new_physical_to_logical_map = (
|
||||
new_physical_to_logical_map.cpu()
|
||||
)
|
||||
eplb_model_state.new_logical_to_physical_map = padded_logical
|
||||
eplb_model_state.new_logical_replica_count = new_replica
|
||||
|
||||
eplb_model_state.rebalanced = True
|
||||
eplb_model_state.layer_to_transfer = 0
|
||||
eplb_model_state.pending_global_ready_check = True
|
||||
|
||||
eplb_model_state.new_indices_computed = False
|
||||
# Signal async thread to start transferring layers
|
||||
if self.is_async and (not is_profile):
|
||||
self.rearrange_event.set()
|
||||
@ -993,11 +1022,9 @@ class EplbState:
|
||||
model_state.layer_to_transfer
|
||||
]
|
||||
expert_weights_buffer = model_state.expert_buffer
|
||||
new_indices = (
|
||||
model_state.new_physical_to_logical_map[model_state.layer_to_transfer]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
new_indices = model_state.new_physical_to_logical_map[
|
||||
model_state.layer_to_transfer
|
||||
].numpy()
|
||||
move_from_buffer(
|
||||
expert_weights=expert_weights,
|
||||
expert_weights_buffers=expert_weights_buffer,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user