[EPLB] Optimize EPLB for Async Rearrange Experts (#22179)

Signed-off-by: David Chen <530634352@qq.com>
Co-authored-by: SunChenxiang123 <1291824390@qq.com>

Parent: 4de87866a8
Commit: 2601f18a82
@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import random

import pytest
import torch
import torch.distributed

from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.eplb.rebalance_execute import (
    move_from_buffer,
    rearrange_expert_weights_inplace,
    transfer_layer,
)
from vllm.distributed.parallel_state import (
    ensure_model_parallel_initialized,
    get_tp_group,

@@ -231,6 +236,100 @@ def verify_redundant_experts_have_same_weights(
    )


def _test_async_transfer_layer_without_mtp_worker(
    env,
    world_size: int,
    num_layers: int,
    num_local_experts: int,
    num_logical_experts: int,
) -> None:
    set_env_vars_and_device(env)
    ensure_model_parallel_initialized(
        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
    )

    tp_group = get_tp_group()
    ep_group = tp_group.device_group
    ep_rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{ep_rank}")

    total_physical_experts = world_size * num_local_experts
    hidden_sizes = [16, 32]

    redundancy_config = create_redundancy_config(
        num_logical_experts,
        total_physical_experts,
    )
    old_indices = create_expert_indices_with_redundancy(
        num_layers,
        num_logical_experts,
        total_physical_experts,
        redundancy_config,
    )

    new_redundancy_config = create_redundancy_config(
        num_logical_experts,
        total_physical_experts,
    )
    new_indices = create_expert_indices_with_redundancy(
        num_layers,
        num_logical_experts,
        total_physical_experts,
        new_redundancy_config,
    )

    expert_weights = create_expert_weights(
        num_layers,
        num_local_experts,
        hidden_sizes,
        ep_rank,
        device,
        old_indices,
    )

    expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
    cuda_stream = torch.cuda.Stream(device=device)

    for layer_idx in range(num_layers):
        is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
            transfer_layer(
                old_global_expert_indices=old_indices,
                new_global_expert_indices=new_indices,
                expert_weights=expert_weights,
                expert_weights_buffer=expert_buffer,
                ep_group=ep_group,
                layer=layer_idx,
                cuda_stream=cuda_stream,
            )
        )

        cuda_stream.synchronize()
        move_from_buffer(
            expert_weights=expert_weights[layer_idx],
            expert_weights_buffer=expert_buffer,
            is_unchanged=is_unchanged,
            is_received_locally=is_received_locally,
            experts_recv_loc=experts_recv_loc,
            new_indices=new_indices[layer_idx].tolist(),
            ep_group=ep_group,
        )

    verify_expert_weights_after_shuffle(
        expert_weights,
        new_indices,
        hidden_sizes,
        ep_rank,
        num_local_experts,
    )
    verify_redundant_experts_have_same_weights(
        expert_weights,
        new_indices,
        hidden_sizes,
        world_size,
        num_local_experts,
    )


def _test_rearrange_expert_weights_with_redundancy(
    env, world_size, num_layers, num_local_experts, num_logical_experts
) -> None:

@@ -399,6 +498,32 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
    )


@pytest.mark.parametrize(
    "world_size,num_layers,num_local_experts,num_logical_experts",
    [
        (2, 2, 2, 3),
    ],
)
def test_async_transfer_layer_without_mtp(
    world_size: int,
    num_layers: int,
    num_local_experts: int,
    num_logical_experts: int,
):
    """Exercise async EPLB transfer path without MTP/spec decode."""

    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    distributed_run(
        _test_async_transfer_layer_without_mtp_worker,
        world_size,
        num_layers,
        num_local_experts,
        num_logical_experts,
    )


@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_no_change(world_size):
    """
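
A minimal, CPU-only sketch of the two-phase idea the new test above exercises: first stage each needed expert into a scratch buffer (as move_to_buffer does), then copy it back into the live weights at its new slot (as move_from_buffer does). The buffer is what lets slots be overwritten safely even when source and destination slots overlap. This is a single-process, local-moves-only illustration; the helper name stage_and_consume is hypothetical, not a vLLM API.

import torch


def stage_and_consume(weights: torch.Tensor, old: list[int], new: list[int]) -> None:
    # weights: (num_local_experts, hidden); old/new: logical expert id per slot.
    buffer = torch.empty_like(weights)
    src_of = {logical: slot for slot, logical in enumerate(old)}
    # Phase 1: stage every expert the new placement can source locally.
    for dst, logical in enumerate(new):
        if logical in src_of:
            buffer[dst].copy_(weights[src_of[logical]])
    # Phase 2: consume the buffer into the live weights at the new slots.
    for dst, logical in enumerate(new):
        if logical in src_of:
            weights[dst].copy_(buffer[dst])


w = torch.arange(6.0).reshape(3, 2)  # three experts, each with a 2-dim weight
stage_and_consume(w, old=[0, 1, 2], new=[2, 0, 1])
assert torch.equal(w, torch.tensor([[4.0, 5.0], [0.0, 1.0], [2.0, 3.0]]))
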
@@ -10,10 +10,11 @@ from tests.utils import large_gpu_mark

def get_model_args(
    model_name: str,
    spec_model_name: str,
    spec_model_name: str | None,
    spec_method: str,
    tp_size: int,
    model_max_len: int,
    use_async: bool = False,
) -> dict:
    speculative_config = {
        "method": spec_method,

@@ -37,6 +38,8 @@ def get_model_args(
        "enable_eplb": True,
        "max_model_len": model_max_len,
    }
    if use_async:
        model_args["eplb_config"] = {"use_async": True}
    return model_args


@@ -94,3 +97,37 @@ def test_eplb_spec_decode(
        measured_value - RTOL < expected_gsm8k_value
        and measured_value + RTOL > expected_gsm8k_value
    ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"


@large_gpu_mark(min_gb=80)
def test_eplb_spec_decode_qwen3_next_mtp_async() -> None:
    """
    Ensure async EPLB works with MTP speculative decoding for Qwen3-Next.
    """

    TASK = "gsm8k"
    FILTER = "exact_match,strict-match"
    RTOL = 0.03
    expected_gsm8k_value = 0.86

    model_args = get_model_args(
        model_name="Qwen/Qwen3-Next-80B-A3B-Instruct",
        spec_model_name=None,
        spec_method="mtp",
        tp_size=4,
        model_max_len=4096,
        use_async=True,
    )

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=TASK,
        batch_size=64,
        num_fewshot=8,
    )
    measured_value = results["results"][TASK][FILTER]
    assert (
        measured_value - RTOL < expected_gsm8k_value
        and measured_value + RTOL > expected_gsm8k_value
    ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
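
The accuracy gate above is an absolute-tolerance window around the expected GSM8K score. A hedged, essentially equivalent formulation using pytest.approx is shown below for clarity; the measured value here is a made-up placeholder, not a benchmark result.

import pytest

RTOL = 0.03
expected_gsm8k_value = 0.86
measured_value = 0.87  # placeholder for illustration only

assert measured_value == pytest.approx(expected_gsm8k_value, abs=RTOL)
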
@@ -60,6 +60,10 @@ class EPLBConfig:
    Log the balancedness each step of expert parallelism.
    This is turned off by default since it will cause communication overhead.
    """
    use_async: bool = False
    """
    Whether to use non-blocking EPLB.
    """


@config
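
A minimal sketch of how the new flag maps onto the config object: a dict like {"use_async": True}, as passed through model_args["eplb_config"] in the test above, would populate EPLBConfig.use_async. The import path below is an assumption and may differ across vLLM versions.

from vllm.config import EPLBConfig  # assumed import path; may vary by version

eplb_config = EPLBConfig(use_async=True)
assert eplb_config.use_async is True
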
vllm/distributed/eplb/async_worker.py (new file, 115 lines)

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
The async worker that transfers experts in the background.
"""

import asyncio
import threading
from typing import TYPE_CHECKING

import torch
from torch.distributed import ProcessGroup

from vllm.distributed.parallel_state import get_ep_group
from vllm.logger import init_logger

from .rebalance_execute import transfer_layer

if TYPE_CHECKING:
    from .eplb_state import EplbState

logger = init_logger(__name__)


def start_async_worker(
    state: "EplbState",
    rank_mapping: dict[int, int] | None = None,
    is_profile: bool = False,
) -> threading.Thread:
    ep_group = get_ep_group().device_group
    rank = ep_group.rank()
    device_index = state.cuda_device_index

    def thread_target() -> None:
        assert device_index is not None
        torch.cuda.set_device(device_index)
        cuda_stream = torch.cuda.Stream(device=device_index)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(
                transfer_run_periodically(
                    state=state,
                    ep_group=ep_group,
                    is_profile=is_profile,
                    rank_mapping=rank_mapping,
                    cuda_stream=cuda_stream,
                )
            )
        except Exception as exc:  # pragma: no cover - diagnostic path
            logger.exception("async loop error (Rank %d): %s", rank, str(exc))
        finally:
            loop.close()

    thread = threading.Thread(target=thread_target, daemon=True)
    thread.start()
    return thread


async def transfer_run_periodically(
    state: "EplbState",
    ep_group: ProcessGroup,
    is_profile: bool = False,
    rank_mapping: dict[int, int] | None = None,
    cuda_stream: torch.cuda.Stream | None = None,
) -> None:
    while True:
        await asyncio.to_thread(state.rearrange_event.wait)
        logger.info("async worker woke up for EPLB transfer")

        for model_state in state.model_states.values():
            if not model_state.is_async_enabled:
                continue
            current_num_layers = model_state.model.num_moe_layers
            while (
                model_state.rebalanced
                and model_state.layer_to_transfer < current_num_layers
            ):
                if (
                    not model_state.ep_buffer_ready
                    and model_state.rebalanced
                    and model_state.new_physical_to_logical_map is not None
                ):
                    await asyncio.to_thread(model_state.buffer_lock.acquire)
                    try:
                        if model_state.layer_to_transfer >= current_num_layers:
                            break

                        (
                            model_state.is_unchanged,
                            model_state.is_received_locally,
                            model_state.experts_recv_loc,
                        ) = await transfer_layer(
                            old_global_expert_indices=model_state.physical_to_logical_map,
                            new_global_expert_indices=model_state.new_physical_to_logical_map,
                            expert_weights=model_state.model.expert_weights,
                            expert_weights_buffer=model_state.expert_buffer,
                            ep_group=ep_group,
                            is_profile=is_profile,
                            layer=model_state.layer_to_transfer,
                            cuda_stream=cuda_stream,
                            rank_mapping=rank_mapping,
                        )
                        event = torch.cuda.Event(blocking=False)
                        cuda_stream.record_event(event)
                        model_state.buffer_ready_event = event
                        model_state.ep_buffer_ready = 1
                    finally:
                        model_state.buffer_lock.release()
                else:
                    if not model_state.rebalanced:
                        break
                    await asyncio.sleep(0.001)

        state.rearrange_event.clear()
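
A minimal, self-contained sketch of the wake-up pattern used by start_async_worker and transfer_run_periodically above: a daemon thread runs its own asyncio event loop, blocks on a threading.Event without stalling that loop (via asyncio.to_thread), does one unit of work, then clears the event. The names demo_worker, wake, and stop are illustrative only, not vLLM APIs.

import asyncio
import threading
import time


def demo_worker(wake: threading.Event, stop: threading.Event) -> threading.Thread:
    async def run() -> None:
        while not stop.is_set():
            # Wait in a helper thread so the event loop itself stays responsive.
            await asyncio.to_thread(wake.wait, 0.1)
            if wake.is_set():
                print("worker: would transfer one layer here")
                wake.clear()

    def thread_target() -> None:
        # Each worker thread gets its own event loop, as in start_async_worker.
        asyncio.run(run())

    thread = threading.Thread(target=thread_target, daemon=True)
    thread.start()
    return thread


if __name__ == "__main__":
    wake, stop = threading.Event(), threading.Event()
    worker = demo_worker(wake, stop)
    wake.set()       # main thread signals, like EplbState.rearrange_event.set()
    time.sleep(0.5)  # give the worker time to run once
    stop.set()
    worker.join(timeout=1.0)
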
@@ -26,6 +26,7 @@ MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
physical experts.
"""

import threading
import time
from collections.abc import Sequence
from dataclasses import dataclass

@@ -43,8 +44,9 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts

from .async_worker import start_async_worker
from .rebalance_algo import rebalance_experts
from .rebalance_execute import rearrange_expert_weights_inplace
from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace

logger = init_logger(__name__)

@@ -132,6 +134,74 @@ class EplbModelState:
    """
    model_name: str
    model: MixtureOfExperts
    expert_buffer: list[torch.Tensor]
    """
    The buffer to store the expert weights during transfer.
    """
    buffer_lock: threading.Lock
    """
    The lock to protect the expert buffer.
    """
    buffer_ready_event: torch.cuda.Event | None
    """
    CUDA event recorded when the async worker finishes filling the buffer.
    The main thread waits on this before consuming the buffer.
    """
    ep_buffer_ready: int
    """
    Flag (0 or 1) indicating whether the expert buffer is ready to be consumed.
    """
    layer_to_transfer: int
    """
    The index of the layer to transfer next in async mode.
    """
    rebalanced: bool
    """
    Flag indicating whether a new expert rebalancing has been computed.
    """
    pending_global_ready_check: bool
    """
    Whether the async EPLB needs to poll peers for buffer readiness.
    """
    is_unchanged: list[bool]
    """
    Intermediate state passed from `move_to_buffer` to `move_to_workspace`.
    Its length equals the number of local physical experts in the current layer.
    """
    is_received_locally: list[bool]
    """
    Intermediate state passed from `move_to_buffer` to `move_to_workspace`.
    Its length equals the number of local physical experts in the current layer.
    """
    experts_recv_loc: dict[int, int]
    """
    Intermediate state passed from `move_to_buffer` to `move_to_workspace`.
    Its size is at most the number of local physical experts in the current layer.
    """
    is_async_enabled: bool
    """
    Flag indicating whether EPLB is running in async mode.
    """
    cuda_device_index: int | None
    """
    CUDA device index for the async EPLB worker thread.
    """
    new_physical_to_logical_map: torch.Tensor | None = None
    """
    Intermediate state held until the async consume path applies it.
    Same shape as `physical_to_logical_map`.
    """
    new_logical_to_physical_map: torch.Tensor | None = None
    """
    Intermediate state held until the async consume path applies it.
    Same shape as `logical_to_physical_map`.
    """
    new_logical_replica_count: torch.Tensor | None = None
    """
    Intermediate state held until the async consume path applies it.
    Same shape as `logical_replica_count`.
    """

class EplbState:

@@ -164,12 +234,31 @@ class EplbState:
        Otherwise, the rearrangement will hang at collective
        communication calls.
        """
        self.expert_rearrangement_step: int = 0
        self.expert_rearrangement_step_interval: int = 0
        """
        Interval for expert rearrangement steps.
        This is a constant and is taken from the config.
        """
        self.expert_rearrangement_step_interval: int = 0
        self.is_async: bool = False
        """
        Flag indicating whether EPLB is running in async mode.
        """
        self.rearrange_event = threading.Event()
        """
        Event to signal the async thread that a new rearrangement is needed.
        """
        self.async_worker: threading.Thread | None = None
        """
        Background thread handling async transfers.
        """
        self.cuda_device_index: int | None = None
        """
        CUDA device index for the async EPLB worker thread.
        """
        if self.device.type == "cuda":
            self.cuda_device_index = self.device.index
        if self.cuda_device_index is None and torch.cuda.is_available():
            self.cuda_device_index = torch.cuda.current_device()

    @staticmethod
    def build_initial_global_physical_to_logical_map(

@@ -239,6 +328,8 @@ class EplbState:
        Build the initial EPLB state.
        """
        self.validate_ep_configuration(model)
        self.is_async = self.parallel_config.eplb_config.use_async

        physical_to_logical_map_list = (
            EplbState.build_initial_global_physical_to_logical_map(
                model.num_routed_experts,

@@ -368,7 +459,12 @@ class EplbState:
            physical_to_logical_map = new_physical_to_logical_map.to(self.device)
            logical_to_physical_map.copy_(new_logical_to_physical_map)
            logical_replica_count.copy_(new_logical_replica_count)
        else:
            new_physical_to_logical_map = None
            new_logical_to_physical_map = None
            new_logical_replica_count = None

        model.set_eplb_state(
            expert_load_pass,
            logical_to_physical_map,

@@ -385,15 +481,33 @@ class EplbState:
        )
        self.expert_rearrangement_step = 0

        self.model_states[model_config.compute_hash()] = EplbModelState(
            physical_to_logical_map,
            logical_to_physical_map,
            logical_replica_count,
            expert_load_pass,
            expert_load_window,
            model_config.model,
            model,
        expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]

        model_state = EplbModelState(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            expert_load_pass=expert_load_pass,
            expert_load_window=expert_load_window,
            model_name=model_config.model,
            model=model,
            expert_buffer=expert_buffer,
            buffer_lock=threading.Lock(),
            buffer_ready_event=None,
            ep_buffer_ready=0,
            layer_to_transfer=0,
            rebalanced=False,
            pending_global_ready_check=False,
            is_unchanged=[],
            is_received_locally=[],
            experts_recv_loc={},
            is_async_enabled=self.is_async,
            cuda_device_index=self.cuda_device_index,
            new_physical_to_logical_map=new_physical_to_logical_map,
            new_logical_to_physical_map=new_logical_to_physical_map,
            new_logical_replica_count=new_logical_replica_count,
        )
        self.model_states[model_config.compute_hash()] = model_state

    def step(
        self,

@@ -420,7 +534,7 @@ class EplbState:
        - `max_tokens`: The maximum load across ranks.
        - `balancedness`: The ratio of average load to maximum load.
        """

        ep_group = get_ep_group().device_group
        if is_profile:
            self.rearrange(is_profile=True)
            return

@@ -488,7 +602,49 @@ class EplbState:
        # rearrangement step and perform rearrangement to ensure all ranks are
        # performing collective communication.
        self.expert_rearrangement_step += 1

        if self.is_async:
            for eplb_model_state in self.model_states.values():
                if not eplb_model_state.is_async_enabled:
                    continue

                all_ranks_buffer_ready = False
                if eplb_model_state.pending_global_ready_check:
                    all_ranks_buffer_ready = self._all_ranks_buffer_ready(
                        eplb_model_state
                    )
                if (
                    eplb_model_state.is_async_enabled
                    and eplb_model_state.ep_buffer_ready
                    and all_ranks_buffer_ready
                ):
                    self.move_to_workspace(
                        model_state=eplb_model_state,
                        ep_group=ep_group,
                        is_profile=is_profile,
                    )
                if (
                    eplb_model_state.layer_to_transfer
                    >= eplb_model_state.model.num_moe_layers
                ):
                    self.post_eplb(eplb_model_state, is_profile)
                    eplb_model_state.rebalanced = False
                    eplb_model_state.layer_to_transfer = 0
                    eplb_model_state.pending_global_ready_check = False
                    logger.info(
                        "finish async transfer for model %s rank %d layer %d",
                        eplb_model_state.model_name,
                        ep_group.rank(),
                        eplb_model_state.model.num_moe_layers,
                    )

        if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
            if any(
                eplb_model_state.is_async_enabled and eplb_model_state.rebalanced
                for eplb_model_state in self.model_states.values()
            ):
                # Still performing asynchronous rearrangement
                return
            self.expert_rearrangement_step = 0
            self.rearrange()

@@ -524,7 +680,11 @@ class EplbState:
        if is_main_rank:
            torch.cuda.synchronize()
            time_start = time.perf_counter()
            logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
            logger.info(
                "Rearranging experts %s %s...",
                "(async mode)" if self.is_async else "sync mode",
                "(profile)" if is_profile else "",
            )

        if global_expert_loads is None:
            # Map the physical expert load to global logical experts

@@ -593,6 +753,7 @@ class EplbState:
            model = eplb_model_state.model
            num_replicas = model.num_physical_experts
            num_groups = model.num_expert_groups

            if rank_mapping is not None and len(rank_mapping) == ep_group.size():
                # NOTE(yongji): scale down, we need to rebalance the experts on
                # remaining GPUs, transfer the experts while we haven't shutdown

@@ -608,7 +769,7 @@ class EplbState:
            num_gpus = ep_group.size()

            if num_gpus % num_nodes != 0:
                self.num_nodes = 1
                num_nodes = 1
                logger.warning_once(
                    f"num_gpus % num_nodes != 0, "
                    "not using hierarchical rearrangement algorithm.\n"

@@ -631,60 +792,216 @@ class EplbState:
                num_gpus,
            )

            # Update expert weights
            rearrange_expert_weights_inplace(
                eplb_model_state.physical_to_logical_map,
                new_physical_to_logical_map,
                eplb_model_state.model.expert_weights,
                ep_group,
                is_profile,
                rank_mapping,
            )
            if not eplb_model_state.is_async_enabled or is_profile:
                # Update expert weights
                rearrange_expert_weights_inplace(
                    eplb_model_state.physical_to_logical_map,
                    new_physical_to_logical_map,
                    eplb_model_state.model.expert_weights,
                    ep_group,
                    is_profile,
                    rank_mapping,
                )

            if not is_profile:
                if (
                    eplb_model_state.physical_to_logical_map.shape[1]
                    != new_physical_to_logical_map.shape[1]
                ):
                    eplb_model_state.physical_to_logical_map = (
                        new_physical_to_logical_map.to(
                            eplb_model_state.physical_to_logical_map.device
                if not is_profile:
                    if (
                        eplb_model_state.physical_to_logical_map.shape[1]
                        != new_physical_to_logical_map.shape[1]
                    ):
                        eplb_model_state.physical_to_logical_map = (
                            new_physical_to_logical_map.to(
                                eplb_model_state.physical_to_logical_map.device
                            )
                        )
                    else:
                        eplb_model_state.physical_to_logical_map.copy_(
                            new_physical_to_logical_map
                        )
                    max_physical_slots = new_logical_to_physical_map.shape[-1]
                    assert (
                        max_physical_slots
                        <= eplb_model_state.logical_to_physical_map.shape[-1]
                    )
                else:
                    eplb_model_state.physical_to_logical_map.copy_(
                        new_physical_to_logical_map
                    new_logical_to_physical_map = torch.nn.functional.pad(
                        new_logical_to_physical_map,
                        (
                            0,
                            eplb_model_state.logical_to_physical_map.shape[-1]
                            - max_physical_slots,
                        ),
                        value=-1,
                    )
                max_physical_slots = new_logical_to_physical_map.shape[-1]
                assert (
                    max_physical_slots
                    <= eplb_model_state.logical_to_physical_map.shape[-1]
                )
                new_logical_to_physical_map = torch.nn.functional.pad(
                    eplb_model_state.logical_to_physical_map.copy_(
                        new_logical_to_physical_map
                    )
                    eplb_model_state.logical_replica_count.copy_(
                        new_logical_replica_count
                    )
                if is_main_rank:
                    assert time_start is not None
                    torch.cuda.synchronize()
                    time_end = time.perf_counter()
                    logger.info(
                        "Rearranged experts%sin %.2f seconds.",
                        " (profile) " if is_profile else " ",
                        time_end - time_start,
                    )
            else:
                device = eplb_model_state.physical_to_logical_map.device
                new_physical = new_physical_to_logical_map.to(device)
                max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
                padded_logical = torch.nn.functional.pad(
                    new_logical_to_physical_map,
                    (
                        0,
                        eplb_model_state.logical_to_physical_map.shape[-1]
                        - max_physical_slots,
                    ),
                    (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
                    value=-1,
                ).to(eplb_model_state.logical_to_physical_map.device)
                new_replica = new_logical_replica_count.to(
                    eplb_model_state.logical_replica_count.device
                )
                eplb_model_state.logical_to_physical_map.copy_(
                    new_logical_to_physical_map
                )
                eplb_model_state.logical_replica_count.copy_(new_logical_replica_count)

                if is_main_rank:
                    assert time_start is not None
                    torch.cuda.synchronize()
                    time_end = time.perf_counter()
                    logger.info(
                        "Rearranged experts%sin %.2f seconds.",
                        " (profile) " if is_profile else " ",
                        time_end - time_start,
                    )
                eplb_model_state.new_physical_to_logical_map = new_physical
                eplb_model_state.new_logical_to_physical_map = padded_logical
                eplb_model_state.new_logical_replica_count = new_replica

                eplb_model_state.rebalanced = True
                eplb_model_state.layer_to_transfer = 0
                eplb_model_state.pending_global_ready_check = True

        # Signal async thread to start transferring layers
        if self.is_async and (not is_profile):
            self.rearrange_event.set()
        return None

    def start_async_loop(
        self,
        rank_mapping: dict[int, int] | None = None,
        is_profile: bool = False,
    ):
        if not self.is_async:
            return
        if self.async_worker is None:
            self.async_worker = start_async_worker(
                self,
                rank_mapping=rank_mapping,
                is_profile=is_profile,
            )

    def _update_layer_mapping_from_new(
        self, model_state: EplbModelState, layer: int
    ) -> None:
        if (
            model_state.new_physical_to_logical_map is None
            or model_state.new_logical_to_physical_map is None
            or model_state.new_logical_replica_count is None
        ):
            return

        target_device = model_state.physical_to_logical_map.device
        new_physical = model_state.new_physical_to_logical_map
        if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
            model_state.physical_to_logical_map = new_physical.to(target_device)
        else:
            model_state.physical_to_logical_map[layer].copy_(
                new_physical[layer].to(target_device)
            )

        logical_device = model_state.logical_to_physical_map.device
        new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
        max_slots = model_state.logical_to_physical_map.shape[-1]
        slot_delta = max_slots - new_logical.shape[-1]
        if slot_delta > 0:
            new_logical = torch.nn.functional.pad(
                new_logical, (0, slot_delta), value=-1
            )
        model_state.logical_to_physical_map[layer].copy_(new_logical)

        replica_device = model_state.logical_replica_count.device
        model_state.logical_replica_count[layer].copy_(
            model_state.new_logical_replica_count[layer].to(replica_device)
        )

    def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
        parallel_state = get_ep_group()
        cpu_group = getattr(parallel_state, "cpu_group", None)
        if cpu_group is not None and cpu_group.size() > 1:
            flag = torch.tensor(
                (int(model_state.ep_buffer_ready),), dtype=torch.int32, device="cpu"
            )
            all_reduce(flag, group=cpu_group)
            return int(flag.item()) == cpu_group.size()

        device_group = parallel_state.device_group
        if device_group.size() <= 1:
            return bool(model_state.ep_buffer_ready)

        device = getattr(
            parallel_state, "device", model_state.physical_to_logical_map.device
        )
        flag = torch.tensor(
            (int(model_state.ep_buffer_ready),), dtype=torch.int32, device=device
        )
        all_reduce(flag, group=device_group)
        return int(flag.item()) == device_group.size()

    def move_to_workspace(
        self,
        model_state: EplbModelState,
        ep_group: ProcessGroup,
        is_profile: bool = False,
    ):
        if not model_state.buffer_lock.acquire(blocking=False):
            return
        try:
            assert model_state.new_physical_to_logical_map is not None
            device_index = model_state.cuda_device_index or self.cuda_device_index
            if model_state.buffer_ready_event is not None and device_index is not None:
                stream = torch.cuda.current_stream(device=device_index)
                stream.wait_event(model_state.buffer_ready_event)
                model_state.buffer_ready_event = None
            move_from_buffer(
                expert_weights=model_state.model.expert_weights[
                    model_state.layer_to_transfer
                ],
                expert_weights_buffer=model_state.expert_buffer,
                is_unchanged=model_state.is_unchanged,
                is_received_locally=model_state.is_received_locally,
                experts_recv_loc=model_state.experts_recv_loc,
                new_indices=model_state.new_physical_to_logical_map[
                    model_state.layer_to_transfer
                ].tolist(),
                ep_group=ep_group,
            )
            transferred_layer = model_state.layer_to_transfer
            self._update_layer_mapping_from_new(model_state, transferred_layer)
            # After the main thread consumes, advance layer_to_transfer
            model_state.layer_to_transfer += 1
            model_state.ep_buffer_ready = 0
            logger.info(
                "model %s successfully move_to_workspace layer %d",
                model_state.model_name,
                transferred_layer,
            )
        finally:
            try:
                model_state.buffer_lock.release()
            except Exception as e:
                logger.error(
                    "Rank %d: buffer_lock release failed in move_to_workspace: %s",
                    ep_group.rank(),
                    str(e),
                )

    def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> None:
        assert model_state.new_physical_to_logical_map is not None
        assert model_state.new_logical_to_physical_map is not None
        assert model_state.new_logical_replica_count is not None
        if not is_profile:
            for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
                self._update_layer_mapping_from_new(model_state, layer_idx)
        model_state.new_physical_to_logical_map = None
        model_state.new_logical_to_physical_map = None
        model_state.new_logical_replica_count = None

    @staticmethod
    def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """
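
A hedged, standalone illustration of the readiness barrier used by _all_ranks_buffer_ready above: each rank contributes a 0/1 flag, the flags are summed with all_reduce, and the sum is compared against the group size. It runs as a single-process gloo group purely for demonstration; the names all_ranks_ready and local_ready are illustrative, not vLLM APIs.

import os

import torch
import torch.distributed as dist


def all_ranks_ready(local_ready: bool, group: dist.ProcessGroup | None = None) -> bool:
    # Sum one flag per rank; all ranks are ready only if every flag was 1.
    flag = torch.tensor([int(local_ready)], dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM, group=group)
    return int(flag.item()) == dist.get_world_size(group)


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29509")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(all_ranks_ready(True))   # True: the only rank is ready
    print(all_ranks_ready(False))  # False
    dist.destroy_process_group()
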
@@ -100,18 +100,19 @@ def get_ep_ranks_with_expert(
    return ranks_to_send, ranks_to_recv_actual


def shuffle_layer(
def move_to_buffer(
    num_local_experts: int,
    ep_rank: int,
    old_indices: Sequence[int],
    new_indices: Sequence[int],
    expert_weights: Iterable[torch.Tensor],
    expert_weights_buffer: Sequence[torch.Tensor],
    cuda_stream: torch.cuda.Stream | None,
    ep_group: ProcessGroup,
) -> None:
) -> tuple[list[bool], list[bool], dict[int, int]]:
    """
    Perform the buffer-staging phase of expert weight rearrangement for one layer.
    """
    ep_rank = ep_group.rank()
    local2global = partial(
        idx_local_to_global,
        local_cnt=num_local_experts,

@@ -137,7 +138,8 @@ def shuffle_layer(
        if old_indices[src_global] == new_indices[dst_global]:
            is_received_locally[dst] = True
            for weight, buffer in zip(expert_weights, expert_weights_buffer):
                buffer[dst].copy_(weight[src])
                with torch.cuda.stream(cuda_stream):
                    buffer[dst].copy_(weight[src], non_blocking=True)

    p2p_ops: list[P2POp] = []

@@ -225,25 +227,115 @@ def shuffle_layer(
    ]

    # 4. Execute the P2P operations. The real communication happens here.
    if p2p_ops:
    if p2p_ops and cuda_stream is not None:
        with torch.cuda.stream(cuda_stream):
            reqs = batch_isend_irecv(p2p_ops)
            for req in reqs:
                req.wait()
    elif p2p_ops:
        reqs = batch_isend_irecv(p2p_ops)
        for req in reqs:
            req.wait()
    # wait for the communication to finish
    return is_unchanged, is_received_locally, experts_recv_loc


def move_from_buffer(
    expert_weights: Iterable[torch.Tensor],
    expert_weights_buffer: list[torch.Tensor],
    is_unchanged: list[bool],
    is_received_locally: list[bool],
    experts_recv_loc: dict[int, int],
    new_indices: Sequence[int],
    ep_group: ProcessGroup,
) -> None:
    ep_rank = ep_group.rank()
    num_local_experts = len(is_unchanged)

    local2global = partial(
        idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
    )

    # 5. Copy the weights from the buffer back to the original weights.
    for dst in range(num_local_experts):
        if is_unchanged[dst]:
            continue
        if is_received_locally[dst]:
            for weight, buffer in zip(expert_weights, expert_weights_buffer):
                weight[dst].copy_(buffer[dst])
                weight[dst].copy_(buffer[dst], non_blocking=True)
        else:
            expert = new_indices[local2global(dst)]
            if expert == -1:
                continue
            src = experts_recv_loc[expert]
            for weight, buffer in zip(expert_weights, expert_weights_buffer):
                weight[dst].copy_(buffer[src])
                weight[dst].copy_(buffer[src], non_blocking=True)


async def transfer_layer(
    old_global_expert_indices: torch.Tensor,
    new_global_expert_indices: torch.Tensor,
    expert_weights: Sequence[Iterable[torch.Tensor]],
    expert_weights_buffer: Sequence[torch.Tensor],
    ep_group: ProcessGroup,
    is_profile: bool = False,
    layer: int = 0,
    cuda_stream: torch.cuda.Stream | None = None,
    rank_mapping: dict[int, int] | None = None,
) -> tuple[list[bool], list[bool], dict[int, int]]:
    """
    Stage one layer's expert weights into the buffer according to the new
    expert indices; the copy back into the live weights happens later in
    `move_from_buffer`.

    The values of the indices arguments are logical indices of the experts,
    while the positions (keys) are physical.

    Args:
        old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        expert_weights: A sequence of shape (num_moe_layers)(weight_count)
            of tensors of shape (num_local_physical_experts, hidden_size_i).
            For example, a linear layer may have up and down projection,
            so weight_count = 2. Each weight's hidden size can be different.
        ep_group: The device process group for expert parallelism.
        is_profile (bool): If `True`, do not perform any actual weight copy.
            This is used during profile run, where we only perform dummy
            communications to reserve enough memory for the buffers.
    """
    ep_size = ep_group.size()
    if rank_mapping is not None:
        if len(rank_mapping) == ep_group.size():
            # scale down
            new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
                new_global_expert_indices,
                rank_mapping,
            )
        else:
            # scale up
            old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
                old_global_expert_indices,
                rank_mapping,
                ep_group.size(),
            )

    assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
    num_moe_layers, num_physical_experts = old_global_expert_indices.shape
    assert len(expert_weights) == num_moe_layers
    num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
    assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
    assert num_physical_experts == ep_size * num_local_physical_experts
    # A buffer to hold the expert weights in one layer during the exchange.
    # NOTE: Currently we assume the same weights across different layers
    # have the same shape.

    is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
        num_local_experts=num_local_physical_experts,
        old_indices=old_global_expert_indices[layer].tolist(),
        new_indices=new_global_expert_indices[layer].tolist(),
        expert_weights=expert_weights[layer],
        expert_weights_buffer=expert_weights_buffer,
        cuda_stream=cuda_stream,
        ep_group=ep_group,
    )
    return is_unchanged, is_received_locally, experts_recv_loc


def rearrange_expert_weights_inplace(

@@ -296,7 +388,6 @@ def rearrange_expert_weights_inplace(
    num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
    assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)

    ep_rank = ep_group.rank()
    ep_size = ep_group.size()
    assert num_physical_experts == ep_size * num_local_physical_experts

@@ -329,14 +420,24 @@ def rearrange_expert_weights_inplace(
        torch.cuda.synchronize()

    for layer in range(num_moe_layers):
        shuffle_layer(
            num_local_physical_experts,
            ep_rank,
            old_global_expert_indices_cpu[layer].tolist(),
            new_global_expert_indices_cpu[layer].tolist(),
            expert_weights[layer],
            expert_weights_buffer,
            ep_group,
        is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
            num_local_experts=num_local_physical_experts,
            old_indices=old_global_expert_indices_cpu[layer].tolist(),
            new_indices=new_global_expert_indices_cpu[layer].tolist(),
            expert_weights=expert_weights[layer],
            expert_weights_buffer=expert_weights_buffer,
            cuda_stream=None,
            ep_group=ep_group,
        )

        move_from_buffer(
            expert_weights=expert_weights[layer],
            expert_weights_buffer=expert_weights_buffer,
            is_unchanged=is_unchanged,
            is_received_locally=is_received_locally,
            experts_recv_loc=experts_recv_loc,
            new_indices=new_global_expert_indices[layer].tolist(),
            ep_group=ep_group,
        )


@@ -428,4 +529,4 @@ def _map_new_expert_indices_with_rank_mapping(
    return mapped_expert_indices


__all__ = ["rearrange_expert_weights_inplace"]
__all__ = ["transfer_layer", "move_from_buffer"]
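
A hedged sketch of the stream/event handoff the code above relies on: stage copies on a side CUDA stream with non_blocking=True (as move_to_buffer does), record a CUDA event on that stream (as the async worker does), and have the consumer stream wait on that event before reading the buffer (as move_to_workspace does). Purely illustrative and requires a CUDA device; the variable names are not vLLM APIs.

import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    side_stream = torch.cuda.Stream(device=device)

    weights = torch.randn(4, 1024, device=device)
    buffer = torch.empty_like(weights)

    # Order the side stream after the allocations/initialization above.
    side_stream.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.stream(side_stream):
        buffer.copy_(weights, non_blocking=True)   # producer: stage into buffer
        ready = torch.cuda.Event()
        side_stream.record_event(ready)

    consumer = torch.cuda.current_stream(device)
    consumer.wait_event(ready)                     # consumer: order after staging
    out = buffer * 2.0                             # safe to read the buffer now
    torch.cuda.synchronize()
    assert torch.allclose(out, weights * 2.0)
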
@@ -3370,6 +3370,8 @@ class GPUModelRunner(
                old_global_expert_indices,
                rank_mapping,
            )
            if self.eplb_state.is_async:
                self.eplb_state.start_async_loop(rank_mapping=rank_mapping)

        if (
            self.vllm_config.compilation_config.mode