[PERF] Allreduce fusion. Support torch native matching. Tuning of the thresholds (#24248)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Ilya Markov 2025-11-11 00:33:11 +01:00 committed by GitHub
parent 021143561f
commit d17ecc6b19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 1284 additions and 83 deletions

View File

@ -463,8 +463,8 @@ steps:
- pytest -v -s compile/test_multimodal_compile.py
- pytest -v -s compile/piecewise/
- label: PyTorch Fullgraph Test # 22min
timeout_in_minutes: 35
- label: PyTorch Fullgraph Test # 27min
timeout_in_minutes: 40
mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies:

File diff suppressed because it is too large Load Diff

View File

@ -71,6 +71,13 @@ if current_platform.is_cuda():
attention_fusions=0,
allreduce_fusions=65,
),
ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=97,
),
]
elif current_platform.is_rocm():

View File

@ -9,7 +9,6 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
@ -450,34 +449,41 @@ class AsyncTPPass(VllmPatternMatcherPass):
logger.debug("Replaced %s patterns", self.matched_count)
# Max size of the input tensor per world size per device capability
# to use flashinfer fused allreduce
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
90: {
2: 64, # 64MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 64, # 64MB
4: 32, # 32MB
8: 1, # 1MB
},
}
# Max size of the input tensor per world size per device capability
# to use flashinfer one shot fused allreduce
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
90: {
2: 32, # 32MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 32, # 32MB
4: 4, # 4MB
8: 1, # 1MB
},
}
if flashinfer_comm is not None:
_FI_WORKSPACE_TENSOR = None
MiB = 1024 * 1024
# Max size of the input tensor per world size
# to use flashinfer fused allreduce
_FI_MAX_SIZES = {
2: 64 * MiB, # 64MB
4: MiB, # 1MB
6: MiB // 2, # 512KB
8: MiB // 2, # 512KB
}
try:
_FI_MAX_SIZES.update(
{
int(k): int(float(v) * MiB)
for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
}
)
except Exception as e:
raise ValueError(
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e)
) from e
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2
def call_trtllm_fused_allreduce_norm(
allreduce_in: torch.Tensor,
@ -491,7 +497,6 @@ if flashinfer_comm is not None:
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
fuse_rms_quant: bool,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
@ -500,12 +505,20 @@ if flashinfer_comm is not None:
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
max_fusion_size = max_token_num * hidden_size * element_size
use_flashinfer = current_tensor_size <= min(
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
max_fusion_size,
)
if use_flashinfer:
if num_tokens <= max_token_num:
device_capability = current_platform.get_device_capability().to_int()
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, {}
).get(world_size, None)
# Use one shot if no max size for one shot is specified
use_oneshot = (
max_one_shot_size_mb is None
or current_tensor_size <= max_one_shot_size_mb * MiB
)
assert _FI_WORKSPACE_TENSOR is not None, (
"Flashinfer must be enabled when using flashinfer"
)
@ -532,7 +545,7 @@ if flashinfer_comm is not None:
hidden_dim=allreduce_in.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
launch_with_pdl=launch_with_pdl,
use_oneshot=True,
use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=pattern_code,
@ -545,7 +558,7 @@ if flashinfer_comm is not None:
)
else:
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
if scale_factor is not None and scale_out is None and fuse_rms_quant:
if scale_factor is not None and scale_out is None:
# Do fused rms norm static fp8 quant fused op
if norm_out is None:
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
@ -568,15 +581,10 @@ if flashinfer_comm is not None:
norm_out = allreduce_out
else:
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
if scale_factor is not None:
if scale_out is not None:
torch.ops._C.scaled_fp4_quant(
quant_out, norm_out, scale_out, scale_factor
)
else:
torch.ops._C.static_scaled_fp8_quant(
quant_out, norm_out, scale_factor
)
if scale_factor is not None and scale_out is not None:
torch.ops._C.scaled_fp4_quant(
quant_out, norm_out, scale_out, scale_factor
)
if scale_factor is None or norm_out is not None:
# we need to return allreduce output
# in cases of non quant fused AR + RMS norm
@ -595,7 +603,6 @@ if flashinfer_comm is not None:
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
fuse_rms_quant: bool,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
@ -629,7 +636,6 @@ class FlashInferFusedAllReduceParams:
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
fuse_rms_quant: bool = False,
):
self.rank = rank
self.world_size = world_size
@ -637,9 +643,7 @@ class FlashInferFusedAllReduceParams:
self.trigger_completion_at_end = True
self.launch_with_pdl = True
self.fp32_acc = True
self.use_oneshot = False
self.max_token_num = max_token_num
self.fuse_rms_quant = fuse_rms_quant
def get_trtllm_fused_allreduce_kwargs(self):
return {
@ -649,7 +653,6 @@ class FlashInferFusedAllReduceParams:
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc,
"max_token_num": self.max_token_num,
"fuse_rms_quant": self.fuse_rms_quant,
}
@ -1119,23 +1122,35 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
"skipping allreduce fusion pass"
)
return
# Check if the world size is supported
if self.tp_size not in _FI_MAX_SIZES:
max_size = config.compilation_config.pass_config.flashinfer_max_size(
self.tp_size
)
if max_size is None:
# Flashinfer doesn't support current world size
logger.warning(
"Flashinfer allreduce fusion is not supported for world size %s",
self.tp_size,
)
return
max_num_token = min(
_FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE)
// (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num,
element_size = 4 if use_fp32_lamport else 2
self.max_token_num = max_size // (self.hidden_dim * element_size)
# take the min to save workspace size and we'll never use more
# than max_num_batched_tokens anyways
self.max_token_num = min(
self.max_token_num, config.scheduler_config.max_num_batched_tokens
)
logger.debug_once(
f"Flashinfer max size: {max_size // (1024 * 1024)} MB,"
"Maximal number of tokens used by "
f"Flashinfer Allreduce Fusion: {self.max_token_num}",
scope="global",
)
self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=self.tp_size,
max_token_num=max_num_token,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
group=self.group,
use_fp32_lamport=use_fp32_lamport,
@ -1148,10 +1163,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
rank=rank,
world_size=self.tp_size,
use_fp32_lamport=use_fp32_lamport,
max_token_num=max_num_token,
# fuse rms norm static fp8 quant fused op
# in fallback path, when we don't use flashinfer
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion,
max_token_num=self.max_token_num,
)
self.register_patterns()

View File

@ -111,11 +111,52 @@ class PassConfig:
"""Whether to enable async TP."""
enable_fi_allreduce_fusion: bool = False
"""Whether to enable flashinfer allreduce fusion."""
fi_allreduce_fusion_max_token_num: int = 16384
"""Max number of tokens to use in flashinfer allreduce fusion."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
float in MB.
If unspecified, falls back to default values,
which are compute capability and world size dependent.
FI_ALLREDUCE_FUSION_MAX_SIZE_MB = {
90: {
2: 64, # 64MB
4: 2, # 2MB
8: 1, # 1MB
},
100: {
2: 64, # 64MB
4: 32, # 32MB
8: 1, # 1MB
},
}, where key is the device capability"""
# TODO(luka) better pass enabling system.
def flashinfer_max_size(self, world_size: int) -> int | None:
"""
Returns the max communication size in bytes for flashinfer
allreduce fusion for the given world size. Returns None if world size
is not supported by configs as it's not supported by flashinfer.
"""
MiB = 1024 * 1024
max_size_mb = self.fi_allreduce_fusion_max_size_mb
if max_size_mb is None:
max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size)
return int(max_size_mb * MiB) if max_size_mb is not None else None
@staticmethod
def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]:
    """Default per-world-size flashinfer fusion thresholds (in MiB).

    Looks up the compile-time default table by the current device
    capability. Non-CUDA platforms get an empty mapping, as does a CUDA
    capability with no entry in the table.
    """
    # Imported lazily to avoid a module-level import cycle with the
    # compilation pass and platform modules.
    from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return {}
    capability = current_platform.get_device_capability().to_int()
    return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get(capability, {})
def uuid(self):
"""
Produces a hash unique to the pass configuration.
@ -136,6 +177,11 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. "
"Attention + quant (fp8) fusion might not work"
)
if self.enable_fi_allreduce_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
@config

View File

@ -2356,6 +2356,16 @@ class FusedMoE(CustomOp):
value=0.0,
)
def reduce_output(states: torch.Tensor) -> torch.Tensor:
    # All-reduce across TP/EP ranks, unless another mechanism already
    # owns the reduction (sequence parallelism or DP chunking) or no
    # reduction is requested/needed.
    if self.is_sequence_parallel or self.use_dp_chunking:
        return states
    if not self.reduce_results:
        return states
    if self.tp_size <= 1 and self.ep_size <= 1:
        return states
    return self.maybe_all_reduce_tensor_model_parallel(states)
if self.shared_experts is None:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
@ -2366,7 +2376,14 @@ class FusedMoE(CustomOp):
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, self.layer_name
)
return fused_output[..., :og_hidden_states]
if self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(fused_output, tuple)
fused_output, zero_expert_result = fused_output
return (reduce_output(fused_output) + zero_expert_result)[
..., :og_hidden_states
]
else:
return reduce_output(fused_output)[..., :og_hidden_states]
else:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
@ -2379,8 +2396,8 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, self.layer_name
)
return (
shared_output[..., :og_hidden_states],
fused_output[..., :og_hidden_states],
reduce_output(shared_output)[..., :og_hidden_states],
reduce_output(fused_output)[..., :og_hidden_states],
)
def forward_cuda(
@ -2667,31 +2684,21 @@ class FusedMoE(CustomOp):
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
def reduce_output(
states: torch.Tensor, do_combine: bool = True
) -> torch.Tensor:
if do_naive_dispatch_combine and do_combine:
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states, self.is_sequence_parallel)
if (
not self.is_sequence_parallel
and self.reduce_results
and (self.tp_size > 1 or self.ep_size > 1)
):
states = self.maybe_all_reduce_tensor_model_parallel(states)
return states
if self.shared_experts is not None:
return (
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, torch.Tensor)
return reduce_output(final_hidden_states) + zero_expert_result
return (combine_output(final_hidden_states), zero_expert_result)
else:
return reduce_output(final_hidden_states)
return combine_output(final_hidden_states)
@classmethod
def make_expert_params_mapping(