diff --git a/docs/design/dbo.md b/docs/design/dbo.md index d92c47c80f951..f2d98ccd063fa 100644 --- a/docs/design/dbo.md +++ b/docs/design/dbo.md @@ -34,10 +34,10 @@ To enable the DBO system pass in the `--enable-dbo` argument to your vllm serve * `--dbo-decode-token-threshold` the minimum number of tokens in a decode-only batch required to enable DBO for that batch * `--dbo-prefill-token-threshold` the minimum number of tokens in a batch containing at least one prefill required to enable DBO for that batch -Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `VLLM_ALL2ALL_BACKEND` environment variable must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests. +Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `--all2all-backend` argument must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests. Below is a command that will spin up a two DP rank server with expert parallelism and DBO enabled. -EX: `VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo` +EX: `vllm serve deepseek-ai/DeepSeek-V2-Lite --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo --all2all-backend deepep_low_latency` Note that there must be at least two GPUs visible in `CUDA_VISIBLE_DEVICES` diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index 93ed383395f27..cd6515dde75ef 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -14,13 +14,16 @@ Before using EP, you need to install the necessary dependencies. We are actively ### Backend Selection Guide -vLLM provides three communication backends for EP: +vLLM provides multiple communication backends for EP. 
Use `--all2all-backend` to select one: | Backend | Use Case | Features | Best For | |---------|----------|----------|----------| -| `pplx` | Single node | Chunked prefill support | Development, best for intra-node deployments | -| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout | High-throughput scenarios, prefill-dominated workloads | -| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout | Low-latency scenarios, decode-dominated workloads | +| `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration | +| `pplx` | Single node | Chunked prefill support, efficient intra-node communication | Single-node deployments, development | +| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios | +| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios | +| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes | +| `naive` | Testing/debugging | Simple broadcast-based implementation | Debugging, not recommended for production | ## Single Node Deployment @@ -47,11 +50,11 @@ The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parall ```bash # Single node EP deployment with pplx backend -VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 \ - vllm serve deepseek-ai/DeepSeek-V3-0324 \ - --tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU +vllm serve deepseek-ai/DeepSeek-V3-0324 \ + --tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU --data-parallel-size 8 \ # Data parallelism across 8 processes - --enable-expert-parallel # Enable expert parallelism + --enable-expert-parallel \ # Enable expert parallelism + --all2all-backend pplx # Use pplx communication backend ``` ## Multi-Node Deployment @@ -70,8 +73,8 @@ The following example deploys `DeepSeek-V3-0324` across 2 nodes using `deepep_lo ```bash # Node 1 (Primary - handles incoming requests) -VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \ - vllm serve deepseek-ai/DeepSeek-V3-0324 \ +vllm serve deepseek-ai/DeepSeek-V3-0324 \ + --all2all-backend deepep_low_latency \ --tensor-parallel-size 1 \ # TP size per node --enable-expert-parallel \ # Enable EP --data-parallel-size 16 \ # Total DP size across all nodes @@ -81,8 +84,8 @@ VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \ --api-server-count=8 # Number of API servers for load handling (scaling this out to total ranks are recommended) # Node 2 (Secondary - headless mode, no API server) -VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \ - vllm serve deepseek-ai/DeepSeek-V3-0324 \ +vllm serve deepseek-ai/DeepSeek-V3-0324 \ + --all2all-backend deepep_low_latency \ --tensor-parallel-size 1 \ # TP size per node --enable-expert-parallel \ # Enable EP --data-parallel-size 16 \ # Total DP size across all nodes @@ -169,11 +172,12 @@ Single node deployment with EPLB enabled: ```bash # Single node with EPLB load balancing -VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 vllm serve deepseek-ai/DeepSeek-V3-0324 \ - --tensor-parallel-size 1 \ # Tensor parallelism - --data-parallel-size 8 \ # Data parallelism - --enable-expert-parallel \ # Enable EP - --enable-eplb \ # Enable load balancer +vllm serve 
deepseek-ai/DeepSeek-V3-0324 \ + --tensor-parallel-size 1 \ # Tensor parallelism + --data-parallel-size 8 \ # Data parallelism + --enable-expert-parallel \ # Enable EP + --all2all-backend pplx \ # Use pplx communication backend + --enable-eplb \ # Enable load balancer --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}' ``` diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 084e458f88309..b7ef0fef68330 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -113,6 +113,25 @@ class ParallelConfig: with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1 will have experts [1, 3]. This strategy can help improve load balancing for grouped expert models with no redundant experts.""" + all2all_backend: ( + Literal[ + "naive", + "pplx", + "deepep_high_throughput", + "deepep_low_latency", + "allgather_reducescatter", + "flashinfer_all2allv", + ] + | None + ) = None + """All2All backend for MoE expert parallel communication. If not set, uses + the value from VLLM_ALL2ALL_BACKEND environment variable. Available options: + - "naive": Naive all2all implementation using broadcasts + - "allgather_reducescatter": All2all based on allgather and reducescatter + - "pplx": Use pplx kernels + - "deepep_high_throughput": Use deepep high-throughput kernels + - "deepep_low_latency": Use deepep low-latency kernels + - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl""" num_redundant_experts: int | None = None """`num_redundant_experts` is deprecated and has been replaced with `eplb_config.num_redundant_experts`. This will be removed in v0.12.0. @@ -341,7 +360,7 @@ class ParallelConfig: @property def use_sequence_parallel_moe(self) -> bool: return ( - envs.VLLM_ALL2ALL_BACKEND + self.all2all_backend in ( "allgather_reducescatter", "naive", @@ -390,7 +409,7 @@ class ParallelConfig: factors.append(self.tensor_parallel_size) factors.append(self.enable_expert_parallel) factors.append(self.data_parallel_size) - factors.append(envs.VLLM_ALL2ALL_BACKEND) + factors.append(self.all2all_backend) factors.append(self.enable_eplb) if self.enable_eplb: factors.append(self.eplb_config.log_balancedness) @@ -400,6 +419,16 @@ class ParallelConfig: return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: + # Set all2all_backend from env var if not specified, with deprecation warning + if self.all2all_backend is None: + self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if envs.is_set("VLLM_ALL2ALL_BACKEND"): + logger.warning_once( + "VLLM_ALL2ALL_BACKEND environment variable is deprecated and " + "will be removed in a future release. Please use the " + "--all2all-backend command-line argument instead." + ) + # Forward deprecated fields to their new location if self.num_redundant_experts is not None: self.eplb_config.num_redundant_experts = self.num_redundant_experts diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index c94101bf608f2..4da164c1a0a96 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -523,13 +523,13 @@ class VllmConfig: ) if self.parallel_config.enable_dbo: - a2a_backend = envs.VLLM_ALL2ALL_BACKEND + a2a_backend = self.parallel_config.all2all_backend assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( "Microbatching currently only supports the deepep_low_latency and " f"deepep_high_throughput all2all backend. {a2a_backend} is not " - "supported. 
To fix set the VLLM_ALL2ALL_BACKEND environment " - "variable to deepep_low_latency or deepep_high_throughput and " - "install the DeepEP kernels." + "supported. To fix use --all2all-backend=deepep_low_latency or " + "--all2all-backend=deepep_high_throughput and install the DeepEP" + " kernels." ) if not self.model_config.disable_cascade_attn: diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 007c65acedb9b..9566dbac7f22f 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -111,6 +111,7 @@ class DeviceCommunicatorBase: self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) use_ep = False + all2all_backend = None from vllm.config import get_current_vllm_config config = get_current_vllm_config() @@ -119,9 +120,11 @@ class DeviceCommunicatorBase: # where all data parallel ranks execute forward together), # we initialize the all2all manager used in expert parallel. use_ep = config.parallel_config.data_parallel_size > 1 + all2all_backend = config.parallel_config.all2all_backend self.is_ep_communicator = "ep" in unique_name self.use_all2all = self.is_ep_communicator and use_ep + self.all2all_backend = all2all_backend self.all2all_manager: All2AllManagerBase | None = None def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 39b02311fe873..971a87f57dbb9 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -91,33 +91,32 @@ class CudaCommunicator(DeviceCommunicatorBase): self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) if self.use_all2all: - all2all_backend = envs.VLLM_ALL2ALL_BACKEND - if all2all_backend == "naive": + if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) - elif all2all_backend == "allgather_reducescatter": + elif self.all2all_backend == "allgather_reducescatter": from .all2all import AgRsAll2AllManager self.all2all_manager = AgRsAll2AllManager(self.cpu_group) - elif all2all_backend == "pplx": + elif self.all2all_backend == "pplx": from .all2all import PPLXAll2AllManager self.all2all_manager = PPLXAll2AllManager(self.cpu_group) - elif all2all_backend == "deepep_high_throughput": + elif self.all2all_backend == "deepep_high_throughput": from .all2all import DeepEPHTAll2AllManager self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) - elif all2all_backend == "deepep_low_latency": + elif self.all2all_backend == "deepep_low_latency": from .all2all import DeepEPLLAll2AllManager self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) - elif all2all_backend == "flashinfer_all2allv": + elif self.all2all_backend == "flashinfer_all2allv": from .all2all import FlashInferAllToAllManager self.all2all_manager = FlashInferAllToAllManager(self.cpu_group) else: - raise ValueError(f"Unknown all2all backend: {all2all_backend}") + raise ValueError(f"Unknown all2all backend: {self.all2all_backend}") if is_global_first_rank(): logger.info( diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index 83e336511059b..ad61fdfb8ea52 100644 --- 
a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -6,7 +6,6 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -import vllm.envs as envs from vllm.logger import init_logger from .base_device_communicator import DeviceCommunicatorBase @@ -24,15 +23,14 @@ class XpuCommunicator(DeviceCommunicatorBase): ): super().__init__(cpu_group, device, device_group, unique_name) if self.use_all2all: - all2all_backend = envs.VLLM_ALL2ALL_BACKEND - if all2all_backend != "naive": + if self.all2all_backend != "naive": logger.warning( - "`%s` all2all manager is not supported on XPU." + "`%s` all2all manager is not supported on XPU. " "Falling back to `naive` all2all manager for XPU.", - all2all_backend, + self.all2all_backend, ) - all2all_backend = "naive" - if all2all_backend == "naive": + self.all2all_backend = "naive" + if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f0eb3c2213384..09c8b4ca02c57 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -371,6 +371,7 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + all2all_backend: str | None = ParallelConfig.all2all_backend enable_dbo: bool = ParallelConfig.enable_dbo dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold @@ -763,6 +764,9 @@ class EngineArgs: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"] ) + parallel_group.add_argument( + "--all2all-backend", **parallel_kwargs["all2all_backend"] + ) parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) parallel_group.add_argument( "--dbo-decode-token-threshold", @@ -1461,6 +1465,7 @@ class EngineArgs: data_parallel_backend=self.data_parallel_backend, data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, + all2all_backend=self.all2all_backend, enable_dbo=self.enable_dbo, dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 377116124522c..38ea6acc0fc50 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -641,6 +641,7 @@ class FusedMoEParallelConfig: ep_rank: int use_ep: bool # whether to use EP or not + all2all_backend: str # all2all backend for MoE communication @property def use_all2all_kernels(self): @@ -648,21 +649,18 @@ class FusedMoEParallelConfig: @property def use_pplx_kernels(self): - return self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "pplx" + return self.use_all2all_kernels and self.all2all_backend == "pplx" @property def use_deepep_ht_kernels(self): return ( self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + and self.all2all_backend == "deepep_high_throughput" ) @property def use_deepep_ll_kernels(self): - return ( - self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" - ) + return self.use_all2all_kernels and 
self.all2all_backend == "deepep_low_latency" @staticmethod def make( @@ -762,6 +760,7 @@ class FusedMoEParallelConfig: ep_size=1, ep_rank=0, use_ep=False, + all2all_backend=vllm_parallel_config.all2all_backend, ) # DP + EP / TP + EP / DP + TP + EP assert use_ep @@ -777,6 +776,7 @@ class FusedMoEParallelConfig: ep_size=ep_size, ep_rank=ep_rank, use_ep=True, + all2all_backend=vllm_parallel_config.all2all_backend, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 1c6b5de83b2ba..ddb74a27dc122 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -58,7 +58,7 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize( ) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 - enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv" + enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv" return create_flashinfer_prepare_finalize( use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv ) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b51421b6a32d3..0252c3acb08c1 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -192,7 +192,7 @@ class CudaPlatformBase(Platform): compilation_config = vllm_config.compilation_config if ( - envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + parallel_config.all2all_backend == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 and compilation_config.cudagraph_mode != CUDAGraphMode.NONE ): @@ -204,7 +204,7 @@ class CudaPlatformBase(Platform): "kernels are optimized for prefill and are incompatible with " "CUDA Graphs. " "In order to use CUDA Graphs for decode-optimized workloads, " - "set VLLM_ALL2ALL_BACKEND to another option, such as " + "use --all2all-backend with another option, such as " "deepep_low_latency, pplx, or allgather_reducescatter." ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index e617abf6b2c7d..159b779111c44 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -356,9 +356,10 @@ class CoreEngineActorManager: ) device_str = current_platform.ray_device_key + all2all_backend = vllm_config.parallel_config.all2all_backend if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and ( - envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" - or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" + all2all_backend == "deepep_high_throughput" + or all2all_backend == "deepep_low_latency" ): raise ValueError( "DeepEP kernels require EP ranks [0,7] (same for [8,15], ...) "
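Illustrative sketch (not part of the patch): the fallback introduced in `ParallelConfig.__post_init__` above prefers the new `--all2all-backend` value and only falls back to the deprecated `VLLM_ALL2ALL_BACKEND` environment variable, with a one-time warning. A minimal standalone version of that resolution logic, using hypothetical names (`resolve_all2all_backend`, `_VALID_BACKENDS`) and assuming `allgather_reducescatter` as the default backend (per the docs table in this patch), could look like:

```python
import os
import warnings

# Backend names mirror the Literal[...] options added to ParallelConfig.
_VALID_BACKENDS = {
    "naive",
    "allgather_reducescatter",
    "pplx",
    "deepep_high_throughput",
    "deepep_low_latency",
    "flashinfer_all2allv",
}


def resolve_all2all_backend(cli_value: str | None) -> str:
    """Hypothetical helper: resolve the all2all backend the same way the
    patch does, checking the CLI flag first, then the deprecated env var,
    then the default."""
    if cli_value is not None:
        backend = cli_value
    elif "VLLM_ALL2ALL_BACKEND" in os.environ:
        warnings.warn(
            "VLLM_ALL2ALL_BACKEND is deprecated; use --all2all-backend instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        backend = os.environ["VLLM_ALL2ALL_BACKEND"]
    else:
        backend = "allgather_reducescatter"
    if backend not in _VALID_BACKENDS:
        raise ValueError(f"Unknown all2all backend: {backend}")
    return backend
```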