Remove all2all backend envvar (#30363)

Signed-off-by: Elizabeth Thomas <email2eliza@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Elizabeth Thomas 2025-12-18 13:46:28 -06:00 committed by GitHub
parent 97000a2be7
commit 41b6f9200f
12 changed files with 40 additions and 43 deletions
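The change replaces the `VLLM_ALL2ALL_BACKEND` environment variable with an explicit `--all2all-backend` engine argument (backed by `ParallelConfig.all2all_backend` / `EngineArgs.all2all_backend`). A rough before/after sketch of the migration, with the model and parallel sizes borrowed from the CI commands below purely for illustration:

```bash
# Before this change: the all2all backend was selected via an environment variable.
VLLM_ALL2ALL_BACKEND=deepep_high_throughput vllm serve Qwen/Qwen1.5-MoE-A2.7B \
    --enable-expert-parallel \
    --data-parallel-size 2

# After this change: the backend is passed as an engine argument; setting the
# env var only emits a deprecation warning (see the ParallelConfig hunk below).
vllm serve Qwen/Qwen1.5-MoE-A2.7B \
    --enable-expert-parallel \
    --data-parallel-size 2 \
    --all2all-backend deepep_high_throughput
```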

View File

@@ -44,10 +44,10 @@ trap cleanup EXIT
for BACK in "${BACKENDS[@]}"; do
VLLM_DEEP_GEMM_WARMUP=skip \
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
--enable-eplb \
--all2all-backend $BACK \
--eplb-config '{"window_size":10, "step_interval":100, "num_redundant_experts":0, "log_balancedness":true}' \
--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
--data-parallel-size ${DATA_PARALLEL_SIZE} \

View File

@@ -43,12 +43,12 @@ trap cleanup EXIT
for BACK in "${BACKENDS[@]}"; do
VLLM_DEEP_GEMM_WARMUP=skip \
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
--tensor-parallel-size 4 \
--enable-expert-parallel \
--enable-eplb \
--all2all-backend $BACK \
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \
--trust-remote-code \

View File

@@ -1497,7 +1497,7 @@ steps:
- "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- HIP_VISIBLE_DEVICES=0,1 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- HIP_VISIBLE_DEVICES=0,1 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 --all2all-backend deepep_high_throughput
- pytest -v -s tests/v1/distributed/test_dbo.py
##### B200 test #####

View File

@@ -1331,7 +1331,7 @@ steps:
- "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- CUDA_VISIBLE_DEVICES=1,2 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 --all2all-backend deepep_high_throughput
- pytest -v -s tests/v1/distributed/test_dbo.py
##### B200 test #####

View File

@@ -145,7 +145,7 @@ steps:
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- CUDA_VISIBLE_DEVICES=1,2 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 --all2all-backend deepep_high_throughput
- pytest -v -s tests/v1/distributed/test_dbo.py
- label: Distributed Tests (2 GPUs)(B200)

View File

@@ -16,7 +16,7 @@ Async backends support the use of DBO (Dual Batch Overlap) and shared expert ove
Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
Unless otherwise specified, backends are controlled via `VLLM_ALL2ALL_BACKEND`. All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
<style>
td {
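To tie the documentation change above to a concrete invocation: every backend except the flashinfer one requires expert parallelism (EP+DP or EP+TP), so a non-flashinfer backend is selected together with `--enable-expert-parallel`. A minimal sketch (model name is illustrative only):

```bash
# pplx, like most all2all backends, needs expert parallelism (EP+DP or EP+TP),
# here combined with data parallelism.
vllm serve Qwen/Qwen1.5-MoE-A2.7B \
    --enable-expert-parallel \
    --data-parallel-size 2 \
    --all2all-backend pplx
```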

View File

@@ -55,7 +55,6 @@ done
echo "Starting vLLM server for $MODEL_NAME with data parallel size: $DATA_PARALLEL_SIZE and redundant experts: $REDUNDANT_EXPERTS"
export RAY_DEDUP_LOGS=0
export VLLM_ALL2ALL_BACKEND="pplx"
export VLLM_USE_DEEP_GEMM=1
vllm serve $MODEL_NAME \
@@ -65,6 +64,7 @@ vllm serve $MODEL_NAME \
--enforce-eager \
--enable-expert-parallel \
--enable-eplb \
--all2all-backend pplx \
--num-redundant-experts $REDUNDANT_EXPERTS \
--trust-remote-code \
--host $HOST \

View File

@@ -49,7 +49,10 @@ def _create_vllm_config(
mock_config.lora_config = None
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1()
compilation_config.set_splitting_ops_for_v1(
all2all_backend=mock_config.parallel_config.all2all_backend,
data_parallel_size=mock_config.parallel_config.data_parallel_size,
)
# mimic VllmConfig.__post_init__
if compilation_config.cudagraph_capture_sizes:

View File

@@ -899,7 +899,7 @@ class CompilationConfig:
self.compute_bs_to_padded_graph_size()
def set_splitting_ops_for_v1(
self, all2all_backend: str | None = None, data_parallel_size: int | None = None
self, all2all_backend: str, data_parallel_size: int = 1
):
# To compatible with OOT hardware plugin platform (for example vllm-ascend)
# which currently only supports sequence parallelism in eager mode.
@@ -956,11 +956,9 @@ class CompilationConfig:
self.splitting_ops = []
# Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
dp_size = data_parallel_size if data_parallel_size is not None else 1
if (
backend == "deepep_high_throughput"
and dp_size > 1
all2all_backend == "deepep_high_throughput"
and data_parallel_size > 1
and self.cudagraph_mode != CUDAGraphMode.NONE
):
# TODO: Piecewise Cuda graph might be enabled
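For context on the guard above: the backend and DP size now come straight from `ParallelConfig` instead of the environment, and per the comment in this hunk CUDA graphs are disabled when `deepep_high_throughput` runs with `data_parallel_size > 1`. A launch that would hit this path (sketch; model name illustrative):

```bash
# deepep_high_throughput with DP > 1 triggers the guard above, which (per the
# comment) turns CUDA graphs off for this combination.
vllm serve Qwen/Qwen1.5-MoE-A2.7B \
    --enable-expert-parallel \
    --data-parallel-size 2 \
    --all2all-backend deepep_high_throughput
```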

View File

@@ -36,6 +36,14 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"]
All2AllBackend = Literal[
"naive",
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter",
"flashinfer_all2allv",
]
@config
@@ -126,24 +134,14 @@ class ParallelConfig:
with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
will have experts [1, 3]. This strategy can help improve load balancing
for grouped expert models with no redundant experts."""
all2all_backend: (
Literal[
"naive",
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter",
"flashinfer_all2allv",
]
| None
) = None
"""All2All backend for MoE expert parallel communication. If not set, uses
the value from VLLM_ALL2ALL_BACKEND environment variable. Available options:
- "naive": Naive all2all implementation using broadcasts
- "allgather_reducescatter": All2all based on allgather and reducescatter
- "pplx": Use pplx kernels
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
all2all_backend: All2AllBackend = "allgather_reducescatter"
"""All2All backend for MoE expert parallel communication. Available options:
- "naive": Naive all2all implementation using broadcasts\n
- "allgather_reducescatter": All2all based on allgather and reducescatter\n
- "pplx": Use pplx kernels\n
- "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
max_parallel_loading_workers: int | None = None
@@ -495,20 +493,17 @@ class ParallelConfig:
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
# Explicitly include backend affecting env factor as before
factors["VLLM_ALL2ALL_BACKEND"] = str(envs.VLLM_ALL2ALL_BACKEND)
return hash_factors(factors)
def __post_init__(self) -> None:
# Set all2all_backend from env var if not specified, with deprecation warning
if self.all2all_backend is None:
if envs.is_set("VLLM_ALL2ALL_BACKEND"):
logger.warning_once(
"VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
"will be removed in v0.15.0. Please use the "
"--all2all-backend command-line argument instead."
)
self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if envs.is_set("VLLM_ALL2ALL_BACKEND"):
logger.warning_once(
"VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
"will be removed in a future release. Please use the "
"--all2all-backend command-line argument instead."
)
# Continue with the rest of the initialization
self.world_size = (
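Based on the hunks above, `__post_init__` no longer copies the environment variable into `all2all_backend`; setting `VLLM_ALL2ALL_BACKEND` only emits the deprecation warning, and the field keeps its default of `allgather_reducescatter` unless the flag is passed. A quick sketch of the two invocations:

```bash
# Only warns: per the new __post_init__, the deprecated env var is not read into
# ParallelConfig, so this run falls back to the default backend
# (allgather_reducescatter).
VLLM_ALL2ALL_BACKEND=pplx vllm serve Qwen/Qwen1.5-MoE-A2.7B --enable-expert-parallel

# Selects pplx explicitly via the new engine argument.
vllm serve Qwen/Qwen1.5-MoE-A2.7B --enable-expert-parallel --all2all-backend pplx
```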

View File

@@ -408,7 +408,7 @@ class EngineArgs:
data_parallel_external_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
all2all_backend: str | None = ParallelConfig.all2all_backend
all2all_backend: str = ParallelConfig.all2all_backend
enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
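Since `all2all_backend` is now an ordinary `EngineArgs` field rather than an environment lookup, the option surfaces through the regular CLI machinery. A quick local check (sketch, assuming an installation that includes this change):

```bash
# The new engine argument should appear in the serve help text.
vllm serve --help | grep -A 2 "all2all-backend"
```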

View File

@@ -1263,7 +1263,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int(
os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998")
),
# all2all backend for vllm's expert parallel communication
# [DEPRECATED - will be removed in v0.15.0] all2all backend for vllm's
# expert parallel communication. Use --all2all-backend CLI argument instead.
# Available options:
# - "naive": naive all2all implementation using broadcasts
# - "allgather_reducescatter": all2all implementation based on allgather and
@@ -1274,7 +1275,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
"VLLM_ALL2ALL_BACKEND": env_with_choices(
"VLLM_ALL2ALL_BACKEND",
"allgather_reducescatter",
None,
[
"naive",
"pplx",