[core][distributed] add ep group and all2all interface (#18077)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-05-14 10:46:49 +08:00 committed by GitHub
parent 754b699cbe
commit 6266c57bae
8 changed files with 234 additions and 41 deletions

View File

@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.forward_context import get_forward_context
class All2AllBase:
def __init__(self, cpu_group, model):
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_ep_group,
get_tp_group,
in_the_same_node_as)
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
self.ep_group = get_ep_group()
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
pass
class NaiveAll2All(All2AllBase):
"""
A naive implementation of all2all communication.
It uses all-reduce under the hood, which is not
efficient at all. The main purpose is for testing and
debugging.
"""
def __init__(self, cpu_group, model):
super().__init__(cpu_group, model)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
buffer[start:end, :].copy_(x)
for idx in range(self.dp_world_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx]
self.dp_group.broadcast(buffer[start:end, :], idx)
return buffer
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_dp_cpu)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
all_hidden_states = self.dp_group.all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
return hidden_states
def destroy(self):
pass
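The arithmetic behind dispatch() and combine() can be traced on a single process. The following sketch is an illustration added for this write-up, not part of the diff; it assumes two DP ranks holding 3 and 5 tokens, so cu_tokens_across_dp_cpu is [3, 8]: dispatch multicasts every rank's rows into one 8-row buffer, and combine all-reduces the summed result and keeps only the caller's own slice.

# Single-process sketch of the NaiveAll2All bookkeeping (illustration only).
import torch

cu_tokens_across_dp_cpu = torch.tensor([3, 8])  # 2 DP ranks with 3 and 5 tokens
hidden_size = 4
# The multicast buffer holds every rank's tokens back to back.
buffer = torch.zeros(int(cu_tokens_across_dp_cpu[-1]), hidden_size)
for dp_rank in range(len(cu_tokens_across_dp_cpu)):
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    print(f"DP rank {dp_rank} owns buffer rows {start}:{end}")  # 0:3 and 3:8
# combine() all-reduces the full buffer across DP ranks and each rank keeps
# only its own start:end slice of the result.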

View File

@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
-from typing import Optional
+from typing import Optional, Tuple
import torch
import torch.distributed as dist
@ -149,3 +149,27 @@ class DeviceCommunicatorBase:
def destroy(self):
pass
def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
This is a no-op in the base class.
"""
pass
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
return hidden_states
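To make the contract concrete, here is a hedged sketch (not part of the diff) of how a backend that needs model-shaped buffers, e.g. a DeepEP-style all2all, could override these hooks. The subclass name and the hidden_size lookup are illustrative, and the import path is assumed from the package layout shown in this commit.

import torch

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)


class BufferedAll2AllCommunicator(DeviceCommunicatorBase):
    """Illustrative subclass: allocate a model-shaped buffer once, reuse it."""

    def prepare_communication_buffer_for_model(self,
                                               model: torch.nn.Module) -> None:
        # Assumption for the sketch: the model exposes a hidden size.
        hidden_size = getattr(model, "hidden_size", 4096)
        self._buffer = torch.empty(8192, hidden_size)

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        # A real backend would exchange tokens between EP ranks here; the
        # base class simply returns the inputs unchanged.
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states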

View File

@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
-from typing import Optional
+from typing import Optional, Tuple
import torch
from torch.distributed import ProcessGroup
import vllm.envs as envs
from .all2all import All2AllBase
from .base_device_communicator import DeviceCommunicatorBase
@ -23,9 +26,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE)
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
-use_pynccl = True
+# ep does not use pynccl
+use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl
self.use_all2all = "ep" in unique_name
self.all2all_impl: Optional[All2AllBase] = None
self.use_custom_allreduce = use_custom_allreduce
# lazy import to avoid documentation build error
@ -129,3 +136,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.all2all_impl is not None:
self.all2all_impl.destroy()
self.all2all_impl = None
def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
if not self.use_all2all:
return
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2All
self.all2all_impl = NaiveAll2All(self.cpu_group, model)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_impl is not None
hidden_states, router_logits = self.all2all_impl.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_impl is not None
hidden_states = self.all2all_impl.combine(hidden_states)
return hidden_states
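A quick sketch (not from the diff) of how the unique_name-based flags behave; the "tp:0"/"ep:0" strings are assumed to follow the group-name pattern used by parallel_state.

# Illustration of the flag derivation in __init__ above.
def derive_flags(unique_name: str) -> tuple[bool, bool]:
    use_pynccl = "ep" not in unique_name  # the ep group skips pynccl
    use_all2all = "ep" in unique_name     # only the ep group owns an all2all impl
    return use_pynccl, use_all2all

assert derive_flags("tp:0") == (True, False)
assert derive_flags("ep:0") == (False, True)
# With VLLM_ALL2ALL_BACKEND=naive (the only value recognized in this commit),
# prepare_communication_buffer_for_model lazily installs NaiveAll2All.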

View File

@ -757,6 +757,22 @@ class GroupCoordinator:
if self.mq_broadcaster is not None:
self.mq_broadcaster = None
def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(
model)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states,
router_logits)
def combine(self, hidden_states) -> torch.Tensor:
if self.device_communicator is not None:
return self.device_communicator.combine(hidden_states)
_WORLD: Optional[GroupCoordinator] = None
@ -816,6 +832,14 @@ def get_dp_group() -> GroupCoordinator:
return _DP
_EP: Optional[GroupCoordinator] = None
def get_ep_group() -> GroupCoordinator:
assert _EP is not None, ("expert parallel group is not initialized")
return _EP
def get_pp_group() -> GroupCoordinator:
assert _PP is not None, (
"pipeline model parallel group is not initialized")
@ -1001,10 +1025,21 @@ def initialize_model_parallel(
backend,
group_name="dp")
global _EP
assert _EP is None, ("expert parallel group is already initialized")
group_ranks = all_ranks.transpose(1, 2).reshape(
-1, data_parallel_size * tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_EP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="ep")
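# Worked example (not part of the diff) of the group_ranks computation above,
# assuming all_ranks is laid out as [outer, DP, PP, TP] the way
# initialize_model_parallel builds it. With DP=2, PP=2, TP=2:
#
#     all_ranks = torch.arange(8).reshape(-1, 2, 2, 2)
#     ep = all_ranks.transpose(1, 2).reshape(-1, 2 * 2).unbind(0)
#     [x.tolist() for x in ep]  # [[0, 1, 4, 5], [2, 3, 6, 7]]
#
# i.e. one EP group per PP stage, spanning all DP x TP ranks of that stage,
# which is the "merged from dp and tp group" layout All2AllBase expects.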
logger.info(
"rank %s in world size %s is assigned as "
-"DP rank %s, PP rank %s, TP rank %s", rank, world_size,
-_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
+"DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
+_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
+_EP.rank_in_group)
def ensure_model_parallel_initialized(
@ -1035,6 +1070,23 @@ def ensure_model_parallel_initialized(
f"{pipeline_model_parallel_size=}")
def prepare_communication_buffer_for_model(model: torch.nn.Module):
"""Prepare the communication buffer for the model.
Traditional communication libraries like NCCL are almost
model agnostic. However, emerging new communication libraries like
MoE all2all (DeepEP) usually allocate the communication buffer
based on the model shape for optimal performance.
"""
if _TP is not None:
_TP.prepare_communication_buffer_for_model(model)
if _PP is not None:
_PP.prepare_communication_buffer_for_model(model)
if _DP is not None:
_DP.prepare_communication_buffer_for_model(model)
if _EP is not None:
_EP.prepare_communication_buffer_for_model(model)
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return (_TP is not None and _PP is not None)
@ -1095,6 +1147,11 @@ def destroy_model_parallel():
_DP.destroy()
_DP = None
global _EP
if _EP:
_EP.destroy()
_EP = None
def destroy_distributed_environment():
global _WORLD

View File

@ -362,6 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
stateless_init_torch_distributed_process_group().
"""
# Lazy import for non-CUDA backends.
-from torch.distributed.distributed_c10d import _shutdown_backend
-_shutdown_backend(pg)
+try:
+    # pytorch <= 2.6
+    from torch.distributed.distributed_c10d import _shutdown_backend
+    _shutdown_backend(pg)
+except ImportError:
+    # pytorch >= 2.7
+    pg.shutdown()
_unregister_process_group(pg.group_name)

View File

@ -115,6 +115,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive"
def get_default_cache_root():
@ -764,6 +765,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Port used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_PORT":
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
# all2all backend for vllm's expert parallel communication
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
}
# end-env-vars-definition
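For reference, a usage sketch (not part of the diff): the variable is read lazily through vllm.envs, so it only needs to be set in the environment before the engine starts.

import os
os.environ["VLLM_ALL2ALL_BACKEND"] = "naive"  # the only backend wired up here

import vllm.envs as envs
assert envs.VLLM_ALL2ALL_BACKEND == "naive"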

View File

@ -10,7 +10,8 @@ from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm.config import get_current_vllm_config
-from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
+from vllm.distributed import (get_dp_group, get_ep_group,
+get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
@ -832,24 +833,6 @@ class FusedMoE(torch.nn.Module):
return topk_weights, topk_ids
-def naive_multicast(self, x: torch.Tensor,
-cu_tokens_across_dp_cpu: torch.Tensor):
-assert (len(x.shape) == 2)
-buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-device=x.device,
-dtype=x.dtype)
-start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-self.dp_rank - 1]
-end = cu_tokens_across_dp_cpu[self.dp_rank]
-buffer[start:end, :].copy_(x)
-for idx in range(get_dp_group().world_size):
-start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-end = cu_tokens_across_dp_cpu[idx]
-get_dp_group().broadcast(buffer[start:end, :], idx)
-return buffer
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
if self.use_direct_call:
@ -863,14 +846,8 @@ class FusedMoE(torch.nn.Module):
assert self.quant_method is not None
if self.dp_size > 1:
-cu_tokens_across_dp_cpu = get_forward_context(
-).dp_metadata.cu_tokens_across_dp_cpu
-hidden_states = self.naive_multicast(hidden_states,
-cu_tokens_across_dp_cpu)
-router_logits = self.naive_multicast(router_logits,
-cu_tokens_across_dp_cpu)
+hidden_states, router_logits = get_ep_group().dispatch(
+hidden_states, router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
@ -891,12 +868,7 @@ class FusedMoE(torch.nn.Module):
)
if self.dp_size > 1:
-start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-self.dp_rank - 1]
-end = cu_tokens_across_dp_cpu[self.dp_rank]
-all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
-final_hidden_states = all_hidden_states[start:end, :]
+final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
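# Sketch (not part of the diff) of the dp_size > 1 flow after this change,
# with the DP-aware gather/scatter moved behind the EP group:
#
#     hidden_states, router_logits = get_ep_group().dispatch(hidden_states,
#                                                            router_logits)
#     final_hidden_states = self.quant_method.apply(...)  # experts see all tokens
#     final_hidden_states = get_ep_group().combine(final_hidden_states)
#
# The base device communicator keeps dispatch/combine as identity functions,
# so only the "ep" communicator (NaiveAll2All for now) actually moves tokens.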

View File

@ -19,7 +19,8 @@ from vllm.config import (CompilationLevel, VllmConfig,
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
-from vllm.distributed.parallel_state import get_pp_group, graph_capture
+from vllm.distributed.parallel_state import (
+get_pp_group, graph_capture, prepare_communication_buffer_for_model)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@ -1457,6 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger.info("Model loading took %.4f GiB and %.6f seconds",
self.model_memory_usage / GiB_bytes,
time_after_load - time_before_load)
prepare_communication_buffer_for_model(self.model)
def _get_prompt_logprobs_dict(
self,