From 2ea50e977aac00c63e78990a7477bb91295df183 Mon Sep 17 00:00:00 2001
From: Shu Wang
Date: Thu, 18 Sep 2025 10:52:58 -0500
Subject: [PATCH] Enable Allgather/ReduceScatter backend for NaiveAllToAll
 (#23964)

Signed-off-by: Shu Wang.
Signed-off-by: Tyler Michael Smith
Signed-off-by: Shu Wang
Co-authored-by: Tyler Michael Smith
Co-authored-by: Tyler Michael Smith
Co-authored-by: Michael Goin
---
 .../device_communicators/all2all.py           | 39 +++++++++++++++++++
 .../device_communicators/cuda_communicator.py |  4 ++
 vllm/envs.py                                  | 17 +++++---
 3 files changed, 55 insertions(+), 5 deletions(-)

diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 427fd040fcb71..149df73d8667b 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -5,6 +5,7 @@ from typing import Any
 import torch
 import torch.distributed as dist
 
+from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import has_deep_ep, has_pplx
@@ -69,6 +70,44 @@ class NaiveAll2AllManager(All2AllManagerBase):
         pass
 
 
+class AgRsAll2AllManager(All2AllManagerBase):
+    """
+    An implementation of all2all communication based on
+    all-gather (dispatch) and reduce-scatter (combine).
+    """
+
+    def __init__(self, cpu_group):
+        super().__init__(cpu_group)
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        """
+        Gather hidden_states and router_logits from all dp ranks.
+        """
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states, router_logits = get_dp_group().all_gatherv(
+            [hidden_states, router_logits],
+            dim=0,
+            sizes=sizes,
+        )
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Reduce-scatter hidden_states across all dp ranks.
+        """
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
+                                                       dim=0,
+                                                       sizes=sizes)
+        return hidden_states
+
+    def destroy(self):
+        pass
+
+
 class PPLXAll2AllManager(All2AllManagerBase):
     """
     All2All communication based on PPLX kernels.
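For readers skimming the diff, the round trip implemented by AgRsAll2AllManager can be simulated without any distributed setup. The sketch below only illustrates the data movement: plain concatenation stands in for get_dp_group().all_gatherv and a sum followed by a split stands in for reduce_scatterv. The rank count, hidden size, per-rank token counts, and the scaling used as a stand-in for the MoE computation are arbitrary assumptions of this sketch; the real manager takes the per-rank sizes from dp_metadata.get_chunk_sizes_across_dp_rank().

import torch

dp_size, hidden = 4, 8
sizes = [3, 1, 5, 2]                      # tokens held by each DP rank (assumed)
per_rank = [torch.randn(n, hidden) for n in sizes]

# dispatch: after the all-gather, every rank holds the concatenation of all
# ranks' hidden_states (sum(sizes) rows).
gathered = torch.cat(per_rank, dim=0)

# Each rank's local experts produce a partial output for *all* tokens;
# scaling by (rank + 1) is just a placeholder for the MoE computation.
partials = [gathered * (rank + 1) for rank in range(dp_size)]

# combine: reduce (sum the partial outputs), then scatter so each rank keeps
# only the rows that correspond to its own tokens.
reduced = torch.stack(partials).sum(dim=0)
chunks = torch.split(reduced, sizes, dim=0)
for rank, chunk in enumerate(chunks):
    assert chunk.shape == (sizes[rank], hidden)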
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 78c90b006ffc8..b2bf3bc3cc2ed 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
                 from .all2all import NaiveAll2AllManager
                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
                 logger.info("Using naive all2all manager.")
+            elif all2all_backend == "allgather_reducescatter":
+                from .all2all import AgRsAll2AllManager
+                self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
+                logger.info("Using AllGather-ReduceScatter all2all manager.")
             elif all2all_backend == "pplx":
                 from .all2all import PPLXAll2AllManager
                 self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
diff --git a/vllm/envs.py b/vllm/envs.py
index 72e1d5b0ede81..19e2f8635275d 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -149,8 +149,11 @@ if TYPE_CHECKING:
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
-    VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", "deepep_high_throughput",
-                                  "deepep_low_latency"] = "naive"
+    VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx",
+                                  "deepep_high_throughput",
+                                  "deepep_low_latency",
+                                  "allgather_reducescatter"] = \
+        "allgather_reducescatter"
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
@@ -1124,14 +1127,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # all2all backend for vllm's expert parallel communication
     # Available options:
-    # - "naive": naive all2all implementation using all-reduce
+    # - "naive": naive all2all implementation using broadcasts
+    # - "allgather_reducescatter": all2all implementation based on allgather and
+    #   reducescatter
     # - "pplx": use pplx kernels
    # - "deepep_high_throughput", use deepep high-throughput kernels
     # - "deepep_low_latency", use deepep low-latency kernels
     "VLLM_ALL2ALL_BACKEND":
-    env_with_choices("VLLM_ALL2ALL_BACKEND", "naive",
+    env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter",
                      ["naive", "pplx",
-                      "deepep_high_throughput", "deepep_low_latency"]),
+                      "deepep_high_throughput",
+                      "deepep_low_latency",
+                      "allgather_reducescatter"]),
 
     # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
     # Both require compute capability 10.0 or above.
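With this patch the default backend becomes "allgather_reducescatter", and the previous behavior remains available by setting VLLM_ALL2ALL_BACKEND=naive before launching. A minimal sketch of checking which backend is resolved, assuming a vLLM installation that includes this patch; vllm.envs resolves environment variables lazily on attribute access, and the comments below about env_with_choices validation are an assumption based on the helper's name and usage in the diff.

import os

# Pin the backend explicitly; with this patch, leaving the variable unset is
# expected to resolve to "allgather_reducescatter" as well.
os.environ["VLLM_ALL2ALL_BACKEND"] = "allgather_reducescatter"

import vllm.envs as envs

# env_with_choices restricts the value to the listed choices, so an unlisted
# value such as "allgather" should be rejected rather than silently accepted
# (assumption, not verified here).
print(envs.VLLM_ALL2ALL_BACKEND)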