[core][distributed] add ep group and all2all interface (#18077)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 754b699cbe
commit 6266c57bae
vllm/distributed/device_communicators/all2all.py (new file, +93 lines)
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+from vllm.forward_context import get_forward_context
+
+
+class All2AllBase:
+
+    def __init__(self, cpu_group, model):
+        self.cpu_group = cpu_group
+
+        # compute some common properties
+        from vllm.distributed.parallel_state import (get_dp_group,
+                                                     get_ep_group,
+                                                     get_tp_group,
+                                                     in_the_same_node_as)
+
+        # all2all lives in ep group, which is merged from dp and tp group
+        self.dp_group = get_dp_group()
+        self.tp_group = get_tp_group()
+        self.ep_group = get_ep_group()
+        self.dp_rank = self.dp_group.rank_in_group
+        self.dp_world_size = self.dp_group.world_size
+
+        # all2all communication often has separate implementations for
+        # intra-node and inter-node communication
+        self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
+        self.internode = not self.intranode
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        raise NotImplementedError
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    def destroy(self):
+        pass
+
+
+class NaiveAll2All(All2AllBase):
+    """
+    A naive implementation of all2all communication.
+    It uses all-reduce under the hood, which is not
+    efficient at all. The main purpose is for testing and
+    debugging.
+    """
+
+    def __init__(self, cpu_group, model):
+        super().__init__(cpu_group, model)
+
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            self.dp_group.broadcast(buffer[start:end, :], idx)
+
+        return buffer
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_dp_cpu)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_dp_cpu)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+        all_hidden_states = self.dp_group.all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
+        return hidden_states
+
+    def destroy(self):
+        pass
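To make the bookkeeping in NaiveAll2All concrete, here is a single-process sketch of the dispatch/combine round trip. The rank count, token counts, and the identity stand-in for the all-reduce are invented for illustration; the real methods run one copy per DP rank and use dp_group.broadcast and dp_group.all_reduce.

import torch

# Hypothetical sizes: 3 DP ranks contributed 2, 3 and 1 tokens respectively.
tokens_per_rank = torch.tensor([2, 3, 1])
cu_tokens_across_dp_cpu = torch.cumsum(tokens_per_rank, dim=0)  # tensor([2, 5, 6])

dp_rank = 1                     # pretend this process is DP rank 1
hidden_size = 4
x = torch.randn(int(tokens_per_rank[dp_rank]), hidden_size)

# dispatch: every rank ends up holding the full (6, hidden) buffer; the
# local tokens occupy the slice [cu[rank-1], cu[rank]) and the other
# slices are filled by broadcasts from their owning ranks.
buffer = torch.zeros(int(cu_tokens_across_dp_cpu[-1]), hidden_size)
start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
end = int(cu_tokens_across_dp_cpu[dp_rank])
buffer[start:end, :].copy_(x)

# combine: after the expert-sharded MoE runs over the full buffer, an
# all-reduce sums the partial expert outputs and each rank slices its own
# tokens back out. Here the collective is faked with an identity.
all_hidden_states = buffer      # stand-in for dp_group.all_reduce(...)
local_out = all_hidden_states[start:end, :]
assert local_out.shape == x.shape

Every rank ends up holding the full token buffer, which is why the docstring labels this backend as suitable only for testing and debugging.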
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -149,3 +149,27 @@ class DeviceCommunicatorBase:
 
     def destroy(self):
         pass
+
+    def prepare_communication_buffer_for_model(self,
+                                               model: torch.nn.Module) -> None:
+        """
+        Prepare the communication buffer for the model.
+        This is a no-op in the base class.
+        """
+        pass
+
+    def dispatch(
+        self, hidden_states: torch.Tensor,
+        router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Dispatch the hidden states and router logits to the appropriate device.
+        This is a no-op in the base class.
+        """
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Combine the hidden states and router logits from the appropriate device.
+        This is a no-op in the base class.
+        """
+        return hidden_states
@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from torch.distributed import ProcessGroup
 
+import vllm.envs as envs
+
+from .all2all import All2AllBase
 from .base_device_communicator import DeviceCommunicatorBase
 
 
@@ -23,9 +26,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
         from vllm.distributed.parallel_state import (
             _ENABLE_CUSTOM_ALL_REDUCE)
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
-        use_pynccl = True
+
+        # ep does not use pynccl
+        use_pynccl = "ep" not in unique_name
 
         self.use_pynccl = use_pynccl
+        self.use_all2all = "ep" in unique_name
+        self.all2all_impl: Optional[All2AllBase] = None
         self.use_custom_allreduce = use_custom_allreduce
 
         # lazy import to avoid documentation build error
@@ -129,3 +136,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
             self.pynccl_comm = None
         if self.ca_comm is not None:
             self.ca_comm = None
+        if self.all2all_impl is not None:
+            self.all2all_impl.destroy()
+            self.all2all_impl = None
+
+    def prepare_communication_buffer_for_model(self,
+                                               model: torch.nn.Module) -> None:
+        """
+        Prepare the communication buffer for the model.
+        """
+        if not self.use_all2all:
+            return
+        all2all_backend = envs.VLLM_ALL2ALL_BACKEND
+        if all2all_backend == "naive":
+            from .all2all import NaiveAll2All
+            self.all2all_impl = NaiveAll2All(self.cpu_group, model)
+
+    def dispatch(
+        self, hidden_states: torch.Tensor,
+        router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_impl is not None
+        hidden_states, router_logits = self.all2all_impl.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_impl is not None
+        hidden_states = self.all2all_impl.combine(hidden_states)
+        return hidden_states
@@ -757,6 +757,22 @@ class GroupCoordinator:
         if self.mq_broadcaster is not None:
             self.mq_broadcaster = None
 
+    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
+        if self.device_communicator is not None:
+            self.device_communicator.prepare_communication_buffer_for_model(
+                model)
+
+    def dispatch(
+        self, hidden_states: torch.Tensor,
+        router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.device_communicator is not None:
+            return self.device_communicator.dispatch(hidden_states,
+                                                     router_logits)
+
+    def combine(self, hidden_states) -> torch.Tensor:
+        if self.device_communicator is not None:
+            return self.device_communicator.combine(hidden_states)
+
 
 _WORLD: Optional[GroupCoordinator] = None
 
@@ -816,6 +832,14 @@ def get_dp_group() -> GroupCoordinator:
     return _DP
 
 
+_EP: Optional[GroupCoordinator] = None
+
+
+def get_ep_group() -> GroupCoordinator:
+    assert _EP is not None, ("expert parallel group is not initialized")
+    return _EP
+
+
 def get_pp_group() -> GroupCoordinator:
     assert _PP is not None, (
         "pipeline model parallel group is not initialized")
@@ -1001,10 +1025,21 @@ def initialize_model_parallel(
                                     backend,
                                     group_name="dp")
 
+    global _EP
+    assert _EP is None, ("expert parallel group is already initialized")
+    group_ranks = all_ranks.transpose(1, 2).reshape(
+        -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _EP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="ep")
+
     logger.info(
         "rank %s in world size %s is assigned as "
-        "DP rank %s, PP rank %s, TP rank %s", rank, world_size,
-        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
+        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
+        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
+        _EP.rank_in_group)
 
 
 def ensure_model_parallel_initialized(
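As a sanity check of the new EP grouping, this small sketch replays the transpose/reshape from the hunk above. It assumes all_ranks is laid out as (external_dp, DP, PP, TP), the way the surrounding initialize_model_parallel builds it (not shown in this diff); with DP=2, PP=1, TP=2 the expression merges the DP and TP dimensions into one expert-parallel group of 4 ranks, matching the "ep group is merged from dp and tp group" comment in all2all.py.

import torch

# Assumed layout (matching the code around this hunk, not shown here):
# all_ranks has shape (external_dp, DP, PP, TP).
data_parallel_size = 2
pipeline_model_parallel_size = 1
tensor_model_parallel_size = 2
world_size = (data_parallel_size * pipeline_model_parallel_size *
              tensor_model_parallel_size)

all_ranks = torch.arange(world_size).reshape(
    -1, data_parallel_size, pipeline_model_parallel_size,
    tensor_model_parallel_size)

# Same expression as the hunk: swap the DP and PP axes, then flatten
# DP x TP into a single expert-parallel group per (external_dp, PP) pair.
group_ranks = all_ranks.transpose(1, 2).reshape(
    -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
print(group_ranks)  # [[0, 1, 2, 3]] -- one EP group spanning all DP*TP ranks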
@@ -1035,6 +1070,23 @@ def ensure_model_parallel_initialized(
         f"{pipeline_model_parallel_size=}")
 
 
+def prepare_communication_buffer_for_model(model: torch.nn.Module):
+    """Prepare the communication buffer for the model.
+    Traditional communication libraries like NCCL are almost
+    model agnostic. However, emerging new communication libraries like
+    MoE all2all (DeepEP) usually allocate the communication buffer
+    based on the model shape for optimal performance.
+    """
+    if _TP is not None:
+        _TP.prepare_communication_buffer_for_model(model)
+    if _PP is not None:
+        _PP.prepare_communication_buffer_for_model(model)
+    if _DP is not None:
+        _DP.prepare_communication_buffer_for_model(model)
+    if _EP is not None:
+        _EP.prepare_communication_buffer_for_model(model)
+
+
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
     return (_TP is not None and _PP is not None)
@@ -1095,6 +1147,11 @@ def destroy_model_parallel():
         _DP.destroy()
     _DP = None
 
+    global _EP
+    if _EP:
+        _EP.destroy()
+    _EP = None
+
 
 def destroy_distributed_environment():
     global _WORLD
@@ -362,6 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
     stateless_init_torch_distributed_process_group().
     """
-    # Lazy import for non-CUDA backends.
-    from torch.distributed.distributed_c10d import _shutdown_backend
-    _shutdown_backend(pg)
+    try:
+        # pytorch <= 2.6
+        from torch.distributed.distributed_c10d import _shutdown_backend
+        _shutdown_backend(pg)
+    except ImportError:
+        # pytorch >= 2.7
+        pg.shutdown()
     _unregister_process_group(pg.group_name)
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
    VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
+    VLLM_ALL2ALL_BACKEND: str = "naive"
 
 
 def get_default_cache_root():
@@ -764,6 +765,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Port used for NIXL handshake between remote agents.
     "VLLM_NIXL_SIDE_CHANNEL_PORT":
     lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
+
+    # all2all backend for vllm's expert parallel communication
+    "VLLM_ALL2ALL_BACKEND":
+    lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
 }
 
 # end-env-vars-definition
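The new knob is read lazily through vllm.envs, so it can be chosen per process before vLLM creates its communicators. A minimal sketch, assuming the environment variable is set before the engine starts; "naive" is the only value wired up in this commit, and any other value currently leaves CudaCommunicator.all2all_impl unset.

import os

# Assumed usage: choose the all2all backend before vLLM spins up its
# distributed groups. "naive" is the default and the only backend
# registered in this commit.
os.environ["VLLM_ALL2ALL_BACKEND"] = "naive"

import vllm.envs as envs

# vllm.envs evaluates the lambda on attribute access, so the override
# set above is what CudaCommunicator will see.
assert envs.VLLM_ALL2ALL_BACKEND == "naive"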
@@ -10,7 +10,8 @@ from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
 from vllm.config import get_current_vllm_config
-from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -832,24 +833,6 @@ class FusedMoE(torch.nn.Module):
 
         return topk_weights, topk_ids
 
-    def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
-        assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-                             device=x.device,
-                             dtype=x.dtype)
-
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
-        buffer[start:end, :].copy_(x)
-        for idx in range(get_dp_group().world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            get_dp_group().broadcast(buffer[start:end, :], idx)
-
-        return buffer
-
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         if self.use_direct_call:
@@ -863,14 +846,8 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
 
         if self.dp_size > 1:
-            cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
-
-            hidden_states = self.naive_multicast(hidden_states,
-                                                 cu_tokens_across_dp_cpu)
-            router_logits = self.naive_multicast(router_logits,
-                                                 cu_tokens_across_dp_cpu)
-
+            hidden_states, router_logits = get_ep_group().dispatch(
+                hidden_states, router_logits)
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -891,12 +868,7 @@ class FusedMoE(torch.nn.Module):
         )
 
         if self.dp_size > 1:
-            start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-                self.dp_rank - 1]
-            end = cu_tokens_across_dp_cpu[self.dp_rank]
-
-            all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
-            final_hidden_states = all_hidden_states[start:end, :]
+            final_hidden_states = get_ep_group().combine(final_hidden_states)
 
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
@@ -19,7 +19,8 @@ from vllm.config import (CompilationLevel, VllmConfig,
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
-from vllm.distributed.parallel_state import get_pp_group, graph_capture
+from vllm.distributed.parallel_state import (
+    get_pp_group, graph_capture, prepare_communication_buffer_for_model)
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -1457,6 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         logger.info("Model loading took %.4f GiB and %.6f seconds",
                     self.model_memory_usage / GiB_bytes,
                     time_after_load - time_before_load)
+        prepare_communication_buffer_for_model(self.model)
 
     def _get_prompt_logprobs_dict(
         self,