Remove all2all backend envvar (#30363)
Signed-off-by: Elizabeth Thomas <email2eliza@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 97000a2be7
commit 41b6f9200f
@@ -44,10 +44,10 @@ trap cleanup EXIT
 for BACK in "${BACKENDS[@]}"; do
   VLLM_DEEP_GEMM_WARMUP=skip \
-  VLLM_ALL2ALL_BACKEND=$BACK \
   vllm serve "$MODEL" \
     --enforce-eager \
     --enable-eplb \
+    --all2all-backend $BACK \
     --eplb-config '{"window_size":10, "step_interval":100, "num_redundant_experts":0, "log_balancedness":true}' \
     --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
     --data-parallel-size ${DATA_PARALLEL_SIZE} \
@@ -43,12 +43,12 @@ trap cleanup EXIT
 for BACK in "${BACKENDS[@]}"; do
   VLLM_DEEP_GEMM_WARMUP=skip \
-  VLLM_ALL2ALL_BACKEND=$BACK \
   vllm serve "$MODEL" \
     --enforce-eager \
     --tensor-parallel-size 4 \
     --enable-expert-parallel \
     --enable-eplb \
+    --all2all-backend $BACK \
     --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
     --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \
     --trust-remote-code \
@@ -1497,7 +1497,7 @@ steps:
   - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
   - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
   - pytest -v -s tests/distributed/test_context_parallel.py
-  - HIP_VISIBLE_DEVICES=0,1 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
+  - HIP_VISIBLE_DEVICES=0,1 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 --all2all-backend deepep_high_throughput
   - pytest -v -s tests/v1/distributed/test_dbo.py
 
 ##### B200 test #####
@@ -1331,7 +1331,7 @@ steps:
   - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
   - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
   - pytest -v -s tests/distributed/test_context_parallel.py
-  - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
+  - CUDA_VISIBLE_DEVICES=1,2 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 --all2all-backend deepep_high_throughput
   - pytest -v -s tests/v1/distributed/test_dbo.py
 
 ##### B200 test #####
@@ -145,7 +145,7 @@ steps:
   - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'
   - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
   - pytest -v -s tests/distributed/test_context_parallel.py
-  - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
+  - CUDA_VISIBLE_DEVICES=1,2 VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 --all2all-backend deepep_high_throughput
   - pytest -v -s tests/v1/distributed/test_dbo.py
 
 - label: Distributed Tests (2 GPUs)(B200)
@@ -16,7 +16,7 @@ Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap
 
 Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
 
-Unless otherwise specified, backends are controlled via `VLLM_ALL2ALL_BACKEND`. All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
+Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
 
 <style>
 td {
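For readers updating scripts like the ones above, here is a minimal offline-inference sketch of the new selection path. It assumes the `LLM` entrypoint forwards extra engine keyword arguments (such as `all2all_backend`) on to `EngineArgs`, and it reuses the MoE model that appears in the CI commands in this diff; the parallelism values are only illustrative.

```python
# Sketch only, not part of the diff: select the backend through engine
# arguments instead of exporting VLLM_ALL2ALL_BACKEND.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen1.5-MoE-A2.7B",             # MoE model used elsewhere in this diff
    tensor_parallel_size=2,                     # backends other than flashinfer need EP+TP or EP+DP
    enable_expert_parallel=True,
    all2all_backend="allgather_reducescatter",  # previously VLLM_ALL2ALL_BACKEND=allgather_reducescatter
)
print(llm.generate(["Hello"])[0].outputs[0].text)
```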
@@ -55,7 +55,6 @@ done
 echo "Starting vLLM server for $MODEL_NAME with data parallel size: $DATA_PARALLEL_SIZE and redundant experts: $REDUNDANT_EXPERTS"
 
 export RAY_DEDUP_LOGS=0
-export VLLM_ALL2ALL_BACKEND="pplx"
 export VLLM_USE_DEEP_GEMM=1
 
 vllm serve $MODEL_NAME \
@@ -65,6 +64,7 @@ vllm serve $MODEL_NAME \
     --enforce-eager \
     --enable-expert-parallel \
     --enable-eplb \
+    --all2all-backend pplx \
     --num-redundant-experts $REDUNDANT_EXPERTS \
     --trust-remote-code \
     --host $HOST \
@@ -49,7 +49,10 @@ def _create_vllm_config(
     mock_config.lora_config = None
     # Mimic the behavior of VllmConfig.__post_init__()
     if compilation_config.mode == CompilationMode.VLLM_COMPILE:
-        compilation_config.set_splitting_ops_for_v1()
+        compilation_config.set_splitting_ops_for_v1(
+            all2all_backend=mock_config.parallel_config.all2all_backend,
+            data_parallel_size=mock_config.parallel_config.data_parallel_size,
+        )
 
     # mimic VllmConfig.__post_init__
     if compilation_config.cudagraph_capture_sizes:
@@ -899,7 +899,7 @@ class CompilationConfig:
         self.compute_bs_to_padded_graph_size()
 
     def set_splitting_ops_for_v1(
-        self, all2all_backend: str | None = None, data_parallel_size: int | None = None
+        self, all2all_backend: str, data_parallel_size: int = 1
     ):
         # To compatible with OOT hardware plugin platform (for example vllm-ascend)
         # which currently only supports sequence parallelism in eager mode.
@@ -956,11 +956,9 @@ class CompilationConfig:
             self.splitting_ops = []
 
         # Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
-        backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
-        dp_size = data_parallel_size if data_parallel_size is not None else 1
         if (
-            backend == "deepep_high_throughput"
-            and dp_size > 1
+            all2all_backend == "deepep_high_throughput"
+            and data_parallel_size > 1
             and self.cudagraph_mode != CUDAGraphMode.NONE
         ):
             # TODO: Piecewise Cuda graph might be enabled
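A minimal sketch of the updated call shape, mirroring the test change earlier in this diff. The import paths are assumed to be re-exported from `vllm.config`, and the argument values are illustrative; in vLLM itself this call is made from `VllmConfig.__post_init__()` rather than by users.

```python
# Sketch only, not part of the diff: the backend and DP size are now passed in
# explicitly instead of being read from envs.VLLM_ALL2ALL_BACKEND inside the method.
from vllm.config import CompilationConfig, CompilationMode

compilation_config = CompilationConfig()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
    compilation_config.set_splitting_ops_for_v1(
        all2all_backend="deepep_high_throughput",
        data_parallel_size=2,  # with DP > 1 and CUDA graphs enabled, this combination disables CUDA graphs
    )
```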
@@ -36,6 +36,14 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
 DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
 DataParallelBackend = Literal["ray", "mp"]
 EPLBPolicyOption = Literal["default"]
+All2AllBackend = Literal[
+    "naive",
+    "pplx",
+    "deepep_high_throughput",
+    "deepep_low_latency",
+    "allgather_reducescatter",
+    "flashinfer_all2allv",
+]
 
 
 @config
@@ -126,24 +134,14 @@ class ParallelConfig:
     with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
     will have experts [1, 3]. This strategy can help improve load balancing
     for grouped expert models with no redundant experts."""
-    all2all_backend: (
-        Literal[
-            "naive",
-            "pplx",
-            "deepep_high_throughput",
-            "deepep_low_latency",
-            "allgather_reducescatter",
-            "flashinfer_all2allv",
-        ]
-        | None
-    ) = None
-    """All2All backend for MoE expert parallel communication. If not set, uses
-    the value from VLLM_ALL2ALL_BACKEND environment variable. Available options:
-    - "naive": Naive all2all implementation using broadcasts
-    - "allgather_reducescatter": All2all based on allgather and reducescatter
-    - "pplx": Use pplx kernels
-    - "deepep_high_throughput": Use deepep high-throughput kernels
-    - "deepep_low_latency": Use deepep low-latency kernels
+    all2all_backend: All2AllBackend = "allgather_reducescatter"
+    """All2All backend for MoE expert parallel communication. Available options:
+
+    - "naive": Naive all2all implementation using broadcasts\n
+    - "allgather_reducescatter": All2all based on allgather and reducescatter\n
+    - "pplx": Use pplx kernels\n
+    - "deepep_high_throughput": Use deepep high-throughput kernels\n
+    - "deepep_low_latency": Use deepep low-latency kernels\n
+    - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
 
     max_parallel_loading_workers: int | None = None
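A small usage sketch of the retyped field; `ParallelConfig` is assumed to be importable from `vllm.config` as elsewhere in the code base. The field is now a plain `All2AllBackend` literal with a concrete default, so an unset value no longer means "defer to the env var".

```python
# Sketch only, not part of the diff.
from vllm.config import ParallelConfig

cfg = ParallelConfig(all2all_backend="deepep_low_latency")
print(cfg.all2all_backend)               # "deepep_low_latency"
print(ParallelConfig().all2all_backend)  # new default: "allgather_reducescatter"
```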
@@ -495,20 +493,17 @@ class ParallelConfig:
         from vllm.config.utils import get_hash_factors, hash_factors
 
         factors = get_hash_factors(self, ignored_factors)
-        # Explicitly include backend affecting env factor as before
-        factors["VLLM_ALL2ALL_BACKEND"] = str(envs.VLLM_ALL2ALL_BACKEND)
         return hash_factors(factors)
 
     def __post_init__(self) -> None:
-        # Set all2all_backend from env var if not specified, with deprecation warning
-        if self.all2all_backend is None:
-            if envs.is_set("VLLM_ALL2ALL_BACKEND"):
-                logger.warning_once(
-                    "VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
-                    "will be removed in v0.15.0. Please use the "
-                    "--all2all-backend command-line argument instead."
-                )
-                self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND
+        if envs.is_set("VLLM_ALL2ALL_BACKEND"):
+            logger.warning_once(
+                "VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
+                "will be removed in a future release. Please use the "
+                "--all2all-backend command-line argument instead."
+            )
 
         # Continue with the rest of the initialization
         self.world_size = (
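Behaviourally, the hunk above turns the env var into a warn-only path: it no longer feeds the config. A minimal sketch of what that means for callers, assuming default construction of `ParallelConfig` succeeds in the target environment:

```python
# Sketch only, not part of the diff: exporting the deprecated env var now only
# triggers logger.warning_once() in ParallelConfig.__post_init__(); the value
# itself is ignored, so the backend must be configured explicitly.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "pplx"  # deprecated path

from vllm.config import ParallelConfig

cfg = ParallelConfig()  # emits the deprecation warning once
assert cfg.all2all_backend == "allgather_reducescatter"  # env value is not applied
```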
@@ -408,7 +408,7 @@ class EngineArgs:
     data_parallel_external_lb: bool = False
     data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
-    all2all_backend: str | None = ParallelConfig.all2all_backend
+    all2all_backend: str = ParallelConfig.all2all_backend
    enable_dbo: bool = ParallelConfig.enable_dbo
     ubatch_size: int = ParallelConfig.ubatch_size
     dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
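Since `EngineArgs.all2all_backend` now defaults to `ParallelConfig.all2all_backend`, the CLI layer and the config dataclass stay in sync. A short sketch; the model name is only an example taken from this diff:

```python
# Sketch only, not part of the diff.
from vllm.config import ParallelConfig
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="Qwen/Qwen1.5-MoE-A2.7B", all2all_backend="pplx")
print(args.all2all_backend)            # "pplx"
print(ParallelConfig.all2all_backend)  # shared default: "allgather_reducescatter"
```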
@@ -1263,7 +1263,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int(
         os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998")
     ),
-    # all2all backend for vllm's expert parallel communication
+    # [DEPRECATED - will be removed in v0.15.0] all2all backend for vllm's
+    # expert parallel communication. Use --all2all-backend CLI argument instead.
     # Available options:
     # - "naive": naive all2all implementation using broadcasts
     # - "allgather_reducescatter": all2all implementation based on allgather and
@@ -1274,7 +1275,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
     "VLLM_ALL2ALL_BACKEND": env_with_choices(
         "VLLM_ALL2ALL_BACKEND",
-        "allgather_reducescatter",
+        None,
         [
             "naive",
             "pplx",
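With the env var's default changed from "allgather_reducescatter" to None, an unset VLLM_ALL2ALL_BACKEND now resolves to None, which is what lets `ParallelConfig.__post_init__()` warn only when the user actually exported it. A hypothetical helper approximating that resolution follows; the function name and validation logic are illustrative, not vLLM API:

```python
# Sketch only, not part of the diff; approximates
# env_with_choices("VLLM_ALL2ALL_BACKEND", None, [...]) from the hunk above.
import os

_ALLOWED = [
    "naive",
    "pplx",
    "deepep_high_throughput",
    "deepep_low_latency",
    "allgather_reducescatter",
    "flashinfer_all2allv",
]


def read_deprecated_all2all_env() -> str | None:
    """Hypothetical helper: None means the deprecated env var was never set."""
    value = os.getenv("VLLM_ALL2ALL_BACKEND")
    if value is not None and value not in _ALLOWED:
        raise ValueError(f"invalid VLLM_ALL2ALL_BACKEND: {value!r}")
    return value
```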