diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
new file mode 100644
index 000000000000..b69647b00586
--- /dev/null
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+from vllm.forward_context import get_forward_context
+
+
+class All2AllBase:
+
+    def __init__(self, cpu_group, model):
+        self.cpu_group = cpu_group
+
+        # compute some common properties
+        from vllm.distributed.parallel_state import (get_dp_group,
+                                                     get_ep_group,
+                                                     get_tp_group,
+                                                     in_the_same_node_as)
+
+        # all2all lives in ep group, which is merged from dp and tp group
+        self.dp_group = get_dp_group()
+        self.tp_group = get_tp_group()
+        self.ep_group = get_ep_group()
+        self.dp_rank = self.dp_group.rank_in_group
+        self.dp_world_size = self.dp_group.world_size
+
+        # all2all communication often has separate implementations for
+        # intra-node and inter-node communication
+        self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
+        self.internode = not self.intranode
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        raise NotImplementedError
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    def destroy(self):
+        pass
+
+
+class NaiveAll2All(All2AllBase):
+    """
+    A naive implementation of all2all communication.
+    It uses all-reduce under the hood, which is not
+    efficient at all. The main purpose is for testing and
+    debugging.
+    """
+
+    def __init__(self, cpu_group, model):
+        super().__init__(cpu_group, model)
+
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            self.dp_group.broadcast(buffer[start:end, :], idx)
+
+        return buffer
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_dp_cpu)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_dp_cpu)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+        all_hidden_states = self.dp_group.all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
+        return hidden_states
+
+    def destroy(self):
+        pass
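For reference, the indexing that `NaiveAll2All` relies on can be exercised without any process group. The sketch below is a single-process stand-in (made-up token counts and hidden size; the broadcast and all-reduce calls are replaced by plain tensor copies) showing how `cu_tokens_across_dp_cpu` maps each DP rank to a contiguous slice of the shared buffer during dispatch, and how combine recovers only that rank's rows:

```python
import torch

# Made-up per-rank token counts; in vLLM these come from
# forward_context.dp_metadata.cu_tokens_across_dp_cpu.
num_tokens_per_rank = [3, 5, 2]
cu_tokens_across_dp_cpu = torch.tensor(num_tokens_per_rank).cumsum(dim=0)

hidden_size = 4
# Pretend each DP rank holds its own hidden_states of shape (num_tokens, hidden).
per_rank_states = [
    torch.full((n, hidden_size), float(rank))
    for rank, n in enumerate(num_tokens_per_rank)
]

# dispatch: after the per-rank broadcasts, every rank holds the concatenation
# of all ranks' tokens. Here each broadcast(idx) becomes a local copy.
buffer = torch.empty((int(cu_tokens_across_dp_cpu[-1]), hidden_size))
for idx in range(len(num_tokens_per_rank)):
    start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
    end = cu_tokens_across_dp_cpu[idx]
    buffer[start:end, :] = per_rank_states[idx]

# combine: after the all-reduce, rank `dp_rank` keeps only its own slice.
dp_rank = 1
start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
end = cu_tokens_across_dp_cpu[dp_rank]
assert torch.equal(buffer[start:end, :], per_rank_states[dp_rank])
```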
diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py
index 240313b98c88..c313b66ed8b2 100644
--- a/vllm/distributed/device_communicators/base_device_communicator.py
+++ b/vllm/distributed/device_communicators/base_device_communicator.py
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -149,3 +149,27 @@ class DeviceCommunicatorBase:
 
     def destroy(self):
         pass
+
+    def prepare_communication_buffer_for_model(self,
+                                               model: torch.nn.Module) -> None:
+        """
+        Prepare the communication buffer for the model.
+        This is a no-op in the base class.
+        """
+        pass
+
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Dispatch the hidden states and router logits to the appropriate device.
+        This is a no-op in the base class.
+        """
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Combine the hidden states and router logits from the appropriate device.
+        This is a no-op in the base class.
+        """
+        return hidden_states
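The base-class hooks are deliberately no-ops, so backends without an expert-parallel all2all implementation keep working unchanged. A minimal sketch of that contract follows; the `DummyCommunicator` class is hypothetical and only mirrors the base-class defaults:

```python
from typing import Tuple

import torch


class DummyCommunicator:
    """Hypothetical communicator that keeps the base-class defaults."""

    def dispatch(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # No expert-parallel all2all: hand the tensors back untouched.
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Likewise an identity op.
        return hidden_states


comm = DummyCommunicator()
hidden = torch.randn(8, 16)
logits = torch.randn(8, 4)
out_hidden, out_logits = comm.dispatch(hidden, logits)
assert out_hidden is hidden and out_logits is logits
assert comm.combine(out_hidden) is out_hidden
```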
+ """ + if not self.use_all2all: + return + all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend == "naive": + from .all2all import NaiveAll2All + self.all2all_impl = NaiveAll2All(self.cpu_group, model) + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.all2all_impl is not None + hidden_states, router_logits = self.all2all_impl.dispatch( + hidden_states, router_logits) + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert self.all2all_impl is not None + hidden_states = self.all2all_impl.combine(hidden_states) + return hidden_states diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index cb9658ce1004..4a2a95d94b54 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -757,6 +757,22 @@ class GroupCoordinator: if self.mq_broadcaster is not None: self.mq_broadcaster = None + def prepare_communication_buffer_for_model(self, model: torch.nn.Module): + if self.device_communicator is not None: + self.device_communicator.prepare_communication_buffer_for_model( + model) + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.device_communicator is not None: + return self.device_communicator.dispatch(hidden_states, + router_logits) + + def combine(self, hidden_states) -> torch.Tensor: + if self.device_communicator is not None: + return self.device_communicator.combine(hidden_states) + _WORLD: Optional[GroupCoordinator] = None @@ -816,6 +832,14 @@ def get_dp_group() -> GroupCoordinator: return _DP +_EP: Optional[GroupCoordinator] = None + + +def get_ep_group() -> GroupCoordinator: + assert _EP is not None, ("expert parallel group is not initialized") + return _EP + + def get_pp_group() -> GroupCoordinator: assert _PP is not None, ( "pipeline model parallel group is not initialized") @@ -1001,10 +1025,21 @@ def initialize_model_parallel( backend, group_name="dp") + global _EP + assert _EP is None, ("expert parallel group is already initialized") + group_ranks = all_ranks.transpose(1, 2).reshape( + -1, data_parallel_size * tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _EP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="ep") + logger.info( "rank %s in world size %s is assigned as " - "DP rank %s, PP rank %s, TP rank %s", rank, world_size, - _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) + "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size, + _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, + _EP.rank_in_group) def ensure_model_parallel_initialized( @@ -1035,6 +1070,23 @@ def ensure_model_parallel_initialized( f"{pipeline_model_parallel_size=}") +def prepare_communication_buffer_for_model(model: torch.nn.Module): + """Prepare the communication buffer for the model. + Traditional communication libraries like NCCL are almost + model agnostic. However, emerging new communication libraries like + MoE all2all (DeepEP) usually allocate the communication buffer + based on the model shape for optimal performance. 
+ """ + if _TP is not None: + _TP.prepare_communication_buffer_for_model(model) + if _PP is not None: + _PP.prepare_communication_buffer_for_model(model) + if _DP is not None: + _DP.prepare_communication_buffer_for_model(model) + if _EP is not None: + _EP.prepare_communication_buffer_for_model(model) + + def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" return (_TP is not None and _PP is not None) @@ -1095,6 +1147,11 @@ def destroy_model_parallel(): _DP.destroy() _DP = None + global _EP + if _EP: + _EP.destroy() + _EP = None + def destroy_distributed_environment(): global _WORLD diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index a8f292c6e31f..7dd8389c94ff 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -362,6 +362,11 @@ def stateless_destroy_torch_distributed_process_group( stateless_init_torch_distributed_process_group(). """ # Lazy import for non-CUDA backends. - from torch.distributed.distributed_c10d import _shutdown_backend - _shutdown_backend(pg) + try: + # pytorch <= 2.6 + from torch.distributed.distributed_c10d import _shutdown_backend + _shutdown_backend(pg) + except ImportError: + # pytorch >= 2.7 + pg.shutdown() _unregister_process_group(pg.group_name) diff --git a/vllm/envs.py b/vllm/envs.py index 0c742bf05623..9d585bf3578e 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -115,6 +115,7 @@ if TYPE_CHECKING: VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 + VLLM_ALL2ALL_BACKEND: str = "naive" def get_default_cache_root(): @@ -764,6 +765,10 @@ environment_variables: dict[str, Callable[[], Any]] = { # Port used for NIXL handshake between remote agents. "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), + + # all2all backend for vllm's expert parallel communication + "VLLM_ALL2ALL_BACKEND": + lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index eae029b33e80..d745a15e3d22 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -10,7 +10,8 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs from vllm.config import get_current_vllm_config -from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.forward_context import ForwardContext, get_forward_context @@ -832,24 +833,6 @@ class FusedMoE(torch.nn.Module): return topk_weights, topk_ids - def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): - assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) - - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - buffer[start:end, :].copy_(x) - for idx in range(get_dp_group().world_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - get_dp_group().broadcast(buffer[start:end, :], idx) - - return buffer - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): if self.use_direct_call: @@ -863,14 +846,8 @@ 
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index a8f292c6e31f..7dd8389c94ff 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -362,6 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
     stateless_init_torch_distributed_process_group().
     """
     # Lazy import for non-CUDA backends.
-    from torch.distributed.distributed_c10d import _shutdown_backend
-    _shutdown_backend(pg)
+    try:
+        # pytorch <= 2.6
+        from torch.distributed.distributed_c10d import _shutdown_backend
+        _shutdown_backend(pg)
+    except ImportError:
+        # pytorch >= 2.7
+        pg.shutdown()
     _unregister_process_group(pg.group_name)
diff --git a/vllm/envs.py b/vllm/envs.py
index 0c742bf05623..9d585bf3578e 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
+    VLLM_ALL2ALL_BACKEND: str = "naive"
 
 
 def get_default_cache_root():
@@ -764,6 +765,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Port used for NIXL handshake between remote agents.
     "VLLM_NIXL_SIDE_CHANNEL_PORT":
     lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
+
+    # all2all backend for vllm's expert parallel communication
+    "VLLM_ALL2ALL_BACKEND":
+    lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
 }
 
 # end-env-vars-definition
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index eae029b33e80..d745a15e3d22 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -10,7 +10,8 @@ from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
 from vllm.config import get_current_vllm_config
-from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -832,24 +833,6 @@ class FusedMoE(torch.nn.Module):
 
         return topk_weights, topk_ids
 
-    def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
-        assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-                             device=x.device,
-                             dtype=x.dtype)
-
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
-        buffer[start:end, :].copy_(x)
-        for idx in range(get_dp_group().world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            get_dp_group().broadcast(buffer[start:end, :], idx)
-
-        return buffer
-
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         if self.use_direct_call:
@@ -863,14 +846,8 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
 
         if self.dp_size > 1:
-            cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
-
-            hidden_states = self.naive_multicast(hidden_states,
-                                                 cu_tokens_across_dp_cpu)
-            router_logits = self.naive_multicast(router_logits,
-                                                 cu_tokens_across_dp_cpu)
-
+            hidden_states, router_logits = get_ep_group().dispatch(
+                hidden_states, router_logits)
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -891,12 +868,7 @@ class FusedMoE(torch.nn.Module):
         )
 
         if self.dp_size > 1:
-            start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-                self.dp_rank - 1]
-            end = cu_tokens_across_dp_cpu[self.dp_rank]
-
-            all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
-            final_hidden_states = all_hidden_states[start:end, :]
+            final_hidden_states = get_ep_group().combine(final_hidden_states)
 
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c2c8533c88f4..1b16f273a6de 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -19,7 +19,8 @@ from vllm.config import (CompilationLevel, VllmConfig,
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
-from vllm.distributed.parallel_state import get_pp_group, graph_capture
+from vllm.distributed.parallel_state import (
+    get_pp_group, graph_capture, prepare_communication_buffer_for_model)
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -1457,6 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         logger.info("Model loading took %.4f GiB and %.6f seconds",
                     self.model_memory_usage / GiB_bytes,
                     time_after_load - time_before_load)
+        prepare_communication_buffer_for_model(self.model)
 
     def _get_prompt_logprobs_dict(
         self,
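Taken together, the `FusedMoE.forward` path now delegates the data-parallel gather/scatter to the EP group: dispatch concatenates every DP rank's tokens, the expert computation runs on the concatenated batch, and combine returns each rank's own slice. A single-process toy of that flow (the token counts and the times-two "expert" are made-up stand-ins, not vLLM code):

```python
import torch

num_tokens = [3, 2]  # tokens held by DP rank 0 and rank 1
hidden_size = 4
per_rank = [torch.randn(n, hidden_size) for n in num_tokens]
cu_tokens = torch.tensor(num_tokens).cumsum(dim=0)

# dispatch: every rank sees the full concatenated batch.
dispatched = torch.cat(per_rank, dim=0)

# Expert computation on the concatenated batch
# (stand-in for self.quant_method.apply).
expert_out = dispatched * 2

# combine: each rank keeps only its own rows.
for rank in range(len(num_tokens)):
    start = 0 if rank == 0 else cu_tokens[rank - 1]
    end = cu_tokens[rank]
    assert torch.allclose(expert_out[start:end], per_rank[rank] * 2)
```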