From 7ea22e42d5f666a26b3ce4117724dadfdb4d3887 Mon Sep 17 00:00:00 2001
From: nvjullin
Date: Tue, 26 Aug 2025 23:53:04 +0800
Subject: [PATCH] [Misc] Add override for allreduce fusion thresholds (#23639)

Signed-off-by: Julien Lin
---
 vllm/compilation/collective_fusion.py | 13 +++++++++++++
 vllm/envs.py                          | 11 +++++++++++
 2 files changed, 24 insertions(+)

diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index c44ac8e0aa7ea..0c545d8cffd24 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -10,6 +10,7 @@
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import PatternMatcherPass
 from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
@@ -401,6 +402,18 @@ if flashinfer_comm is not None:
         6: MiB // 2,  # 512KB
         8: MiB // 2,  # 512KB
     }
+
+    try:
+        _FI_MAX_SIZES.update({
+            int(k): int(float(v) * MiB)
+            for k, v in
+            envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
+        })
+    except Exception as e:
+        raise ValueError(
+            "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: "
+            + str(e)) from e
+
     # opt for a more conservative default value
     # when world size is not in _FI_MAX_SIZES
     _DEFAULT_FI_MAX_SIZE = MiB // 2
diff --git a/vllm/envs.py b/vllm/envs.py
index 1c9c4cdde8001..66c7c2c7f2c4d 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import hashlib
+import json
 import os
 import sys
 import tempfile
@@ -1046,6 +1047,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
     lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
 
+    # Specifies the thresholds of the communicated tensor sizes under which
+    # vllm should use flashinfer fused allreduce. The variable should be a
+    # JSON dictionary with the following format:
+    # { <world size>: <max size in MiB> }
+    # Unspecified world sizes will fall back to
+    # { 2: 64, 4: 1, <other world sizes>: 0.5 }
+    "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB":
+    lambda: json.loads(os.getenv(
+        "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")),
+
     # MoE routing strategy selector.
     # See `RoutingSimulator.get_available_strategies()` # for available
     # strategies.
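
Usage note (editorial, not part of the patch): a minimal standalone sketch of how the
override is interpreted, assuming the default table shown in collective_fusion.py. Keys
are world sizes, values are thresholds in MiB under which the fused allreduce path is
used; JSON entries replace or extend the defaults. The local name fi_max_sizes and the
example override values below are illustrative only.

    import json
    import os

    MiB = 1024 * 1024

    # Built-in defaults per world size, mirroring _FI_MAX_SIZES in collective_fusion.py.
    fi_max_sizes = {2: 64 * MiB, 4: MiB, 6: MiB // 2, 8: MiB // 2}

    # Hypothetical override: raise the TP=4 threshold to 2 MiB and add a TP=16 entry.
    os.environ["VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB"] = '{"4": 2, "16": 0.25}'

    # Same parsing and merge logic as the patch: JSON keys become ints,
    # values are converted from MiB to bytes.
    overrides = json.loads(
        os.getenv("VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}"))
    fi_max_sizes.update({int(k): int(float(v) * MiB) for k, v in overrides.items()})

    print(fi_max_sizes)
    # {2: 67108864, 4: 2097152, 6: 524288, 8: 524288, 16: 262144}

Because the parsing happens at import time of collective_fusion.py, a malformed JSON
value raises the ValueError added in the patch rather than failing silently.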