mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:27:19 +08:00
[PERF] Allreduce fusion. Support torch native matching. Tuning of the thresholds (#24248)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
parent
021143561f
commit
d17ecc6b19
@ -463,8 +463,8 @@ steps:
|
||||
- pytest -v -s compile/test_multimodal_compile.py
|
||||
- pytest -v -s compile/piecewise/
|
||||
|
||||
- label: PyTorch Fullgraph Test # 22min
|
||||
timeout_in_minutes: 35
|
||||
- label: PyTorch Fullgraph Test # 27min
|
||||
timeout_in_minutes: 40
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
|
||||
1129
benchmarks/kernels/benchmark_fused_collective.py
Normal file
1129
benchmarks/kernels/benchmark_fused_collective.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -71,6 +71,13 @@ if current_platform.is_cuda():
|
||||
attention_fusions=0,
|
||||
allreduce_fusions=65,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="Qwen/Qwen3-30B-A3B",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
attention_fusions=0,
|
||||
allreduce_fusions=97,
|
||||
),
|
||||
]
|
||||
|
||||
elif current_platform.is_rocm():
|
||||
|
||||
@ -9,7 +9,6 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
@ -450,34 +449,41 @@ class AsyncTPPass(VllmPatternMatcherPass):
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer fused allreduce
|
||||
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
|
||||
90: {
|
||||
2: 64, # 64MB
|
||||
4: 2, # 2MB
|
||||
8: 0.5, # 0.5MB
|
||||
},
|
||||
100: {
|
||||
2: 64, # 64MB
|
||||
4: 32, # 32MB
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer one shot fused allreduce
|
||||
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
|
||||
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
|
||||
90: {
|
||||
2: 32, # 32MB
|
||||
4: 2, # 2MB
|
||||
8: 0.5, # 0.5MB
|
||||
},
|
||||
100: {
|
||||
2: 32, # 32MB
|
||||
4: 4, # 4MB
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if flashinfer_comm is not None:
|
||||
_FI_WORKSPACE_TENSOR = None
|
||||
|
||||
MiB = 1024 * 1024
|
||||
# Max size of the input tensor per world size
|
||||
# to use flashinfer fused allreduce
|
||||
_FI_MAX_SIZES = {
|
||||
2: 64 * MiB, # 64MB
|
||||
4: MiB, # 1MB
|
||||
6: MiB // 2, # 512KB
|
||||
8: MiB // 2, # 512KB
|
||||
}
|
||||
|
||||
try:
|
||||
_FI_MAX_SIZES.update(
|
||||
{
|
||||
int(k): int(float(v) * MiB)
|
||||
for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e)
|
||||
) from e
|
||||
|
||||
# opt for a more conservative default value
|
||||
# when world size is not in _FI_MAX_SIZES
|
||||
_DEFAULT_FI_MAX_SIZE = MiB // 2
|
||||
|
||||
def call_trtllm_fused_allreduce_norm(
|
||||
allreduce_in: torch.Tensor,
|
||||
@ -491,7 +497,6 @@ if flashinfer_comm is not None:
|
||||
fp32_acc: bool,
|
||||
max_token_num: int,
|
||||
pattern_code: int,
|
||||
fuse_rms_quant: bool,
|
||||
norm_out: torch.Tensor | None = None,
|
||||
quant_out: torch.Tensor | None = None,
|
||||
scale_out: torch.Tensor | None = None,
|
||||
@ -500,12 +505,20 @@ if flashinfer_comm is not None:
|
||||
num_tokens, hidden_size = allreduce_in.shape
|
||||
element_size = allreduce_in.element_size()
|
||||
current_tensor_size = num_tokens * hidden_size * element_size
|
||||
max_fusion_size = max_token_num * hidden_size * element_size
|
||||
use_flashinfer = current_tensor_size <= min(
|
||||
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
|
||||
max_fusion_size,
|
||||
)
|
||||
if use_flashinfer:
|
||||
|
||||
if num_tokens <= max_token_num:
|
||||
device_capability = current_platform.get_device_capability().to_int()
|
||||
# Get one shot input size limit for the current world size
|
||||
# for the current device capability
|
||||
max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
|
||||
device_capability, {}
|
||||
).get(world_size, None)
|
||||
# Use one shot if no max size for one shot is specified
|
||||
use_oneshot = (
|
||||
max_one_shot_size_mb is None
|
||||
or current_tensor_size <= max_one_shot_size_mb * MiB
|
||||
)
|
||||
|
||||
assert _FI_WORKSPACE_TENSOR is not None, (
|
||||
"Flashinfer must be enabled when using flashinfer"
|
||||
)
|
||||
@ -532,7 +545,7 @@ if flashinfer_comm is not None:
|
||||
hidden_dim=allreduce_in.shape[-1],
|
||||
workspace_ptrs=_FI_WORKSPACE_TENSOR,
|
||||
launch_with_pdl=launch_with_pdl,
|
||||
use_oneshot=True,
|
||||
use_oneshot=use_oneshot,
|
||||
trigger_completion_at_end=trigger_completion_at_end,
|
||||
fp32_acc=fp32_acc,
|
||||
pattern_code=pattern_code,
|
||||
@ -545,7 +558,7 @@ if flashinfer_comm is not None:
|
||||
)
|
||||
else:
|
||||
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
|
||||
if scale_factor is not None and scale_out is None and fuse_rms_quant:
|
||||
if scale_factor is not None and scale_out is None:
|
||||
# Do fused rms norm static fp8 quant fused op
|
||||
if norm_out is None:
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
|
||||
@ -568,15 +581,10 @@ if flashinfer_comm is not None:
|
||||
norm_out = allreduce_out
|
||||
else:
|
||||
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
|
||||
if scale_factor is not None:
|
||||
if scale_out is not None:
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
quant_out, norm_out, scale_out, scale_factor
|
||||
)
|
||||
else:
|
||||
torch.ops._C.static_scaled_fp8_quant(
|
||||
quant_out, norm_out, scale_factor
|
||||
)
|
||||
if scale_factor is not None and scale_out is not None:
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
quant_out, norm_out, scale_out, scale_factor
|
||||
)
|
||||
if scale_factor is None or norm_out is not None:
|
||||
# we need to return allreduce output
|
||||
# in cases of non quant fused AR + RMS norm
|
||||
@ -595,7 +603,6 @@ if flashinfer_comm is not None:
|
||||
fp32_acc: bool,
|
||||
max_token_num: int,
|
||||
pattern_code: int,
|
||||
fuse_rms_quant: bool,
|
||||
norm_out: torch.Tensor | None = None,
|
||||
quant_out: torch.Tensor | None = None,
|
||||
scale_out: torch.Tensor | None = None,
|
||||
@ -629,7 +636,6 @@ class FlashInferFusedAllReduceParams:
|
||||
world_size: int,
|
||||
use_fp32_lamport: bool = False,
|
||||
max_token_num: int = 1024,
|
||||
fuse_rms_quant: bool = False,
|
||||
):
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
@ -637,9 +643,7 @@ class FlashInferFusedAllReduceParams:
|
||||
self.trigger_completion_at_end = True
|
||||
self.launch_with_pdl = True
|
||||
self.fp32_acc = True
|
||||
self.use_oneshot = False
|
||||
self.max_token_num = max_token_num
|
||||
self.fuse_rms_quant = fuse_rms_quant
|
||||
|
||||
def get_trtllm_fused_allreduce_kwargs(self):
|
||||
return {
|
||||
@ -649,7 +653,6 @@ class FlashInferFusedAllReduceParams:
|
||||
"trigger_completion_at_end": self.trigger_completion_at_end,
|
||||
"fp32_acc": self.fp32_acc,
|
||||
"max_token_num": self.max_token_num,
|
||||
"fuse_rms_quant": self.fuse_rms_quant,
|
||||
}
|
||||
|
||||
|
||||
@ -1119,23 +1122,35 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
"skipping allreduce fusion pass"
|
||||
)
|
||||
return
|
||||
# Check if the world size is supported
|
||||
if self.tp_size not in _FI_MAX_SIZES:
|
||||
max_size = config.compilation_config.pass_config.flashinfer_max_size(
|
||||
self.tp_size
|
||||
)
|
||||
if max_size is None:
|
||||
# Flashinfer doesn't support current world size
|
||||
logger.warning(
|
||||
"Flashinfer allreduce fusion is not supported for world size %s",
|
||||
self.tp_size,
|
||||
)
|
||||
return
|
||||
max_num_token = min(
|
||||
_FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE)
|
||||
// (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
|
||||
config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num,
|
||||
element_size = 4 if use_fp32_lamport else 2
|
||||
self.max_token_num = max_size // (self.hidden_dim * element_size)
|
||||
# take the min to save workspace size and we'll never use more
|
||||
# than max_num_batched_tokens anyways
|
||||
self.max_token_num = min(
|
||||
self.max_token_num, config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
logger.debug_once(
|
||||
f"Flashinfer max size: {max_size // (1024 * 1024)} MB,"
|
||||
"Maximal number of tokens used by "
|
||||
f"Flashinfer Allreduce Fusion: {self.max_token_num}",
|
||||
scope="global",
|
||||
)
|
||||
|
||||
self.ipc_handles, workspace_tensor = (
|
||||
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
|
||||
tp_rank=rank,
|
||||
tp_size=self.tp_size,
|
||||
max_token_num=max_num_token,
|
||||
max_token_num=self.max_token_num,
|
||||
hidden_dim=self.hidden_dim,
|
||||
group=self.group,
|
||||
use_fp32_lamport=use_fp32_lamport,
|
||||
@ -1148,10 +1163,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
rank=rank,
|
||||
world_size=self.tp_size,
|
||||
use_fp32_lamport=use_fp32_lamport,
|
||||
max_token_num=max_num_token,
|
||||
# fuse rms norm static fp8 quant fused op
|
||||
# in fallback path, when we don't use flashinfer
|
||||
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion,
|
||||
max_token_num=self.max_token_num,
|
||||
)
|
||||
|
||||
self.register_patterns()
|
||||
|
||||
@ -111,11 +111,52 @@ class PassConfig:
|
||||
"""Whether to enable async TP."""
|
||||
enable_fi_allreduce_fusion: bool = False
|
||||
"""Whether to enable flashinfer allreduce fusion."""
|
||||
fi_allreduce_fusion_max_token_num: int = 16384
|
||||
"""Max number of tokens to used in flashinfer allreduce fusion."""
|
||||
fi_allreduce_fusion_max_size_mb: float | None = None
|
||||
"""The threshold of the communicated tensor sizes under which
|
||||
vllm should use flashinfer fused allreduce. Specified as a
|
||||
float in MB.
|
||||
Unspecified will fallback to default values
|
||||
which are compute capability and world size dependent.
|
||||
FI_ALLREDUCE_FUSION_MAX_SIZE_MB = {
|
||||
90: {
|
||||
2: 64, # 64MB
|
||||
4: 2, # 2MB
|
||||
8: 1, # 1MB
|
||||
},
|
||||
100: {
|
||||
2: 64, # 64MB
|
||||
4: 32, # 32MB
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}, where key is the device capability"""
|
||||
|
||||
# TODO(luka) better pass enabling system.
|
||||
|
||||
def flashinfer_max_size(self, world_size: int) -> int | None:
|
||||
"""
|
||||
Returns the max communication size in bytes for flashinfer
|
||||
allreduce fusion for the given world size. Returns None if world size
|
||||
is not supported by configs as it's not supported by flashinfer.
|
||||
"""
|
||||
|
||||
MiB = 1024 * 1024
|
||||
max_size_mb = self.fi_allreduce_fusion_max_size_mb
|
||||
if max_size_mb is None:
|
||||
max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size)
|
||||
|
||||
return int(max_size_mb * MiB) if max_size_mb is not None else None
|
||||
|
||||
@staticmethod
|
||||
def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]:
|
||||
from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return {}
|
||||
return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get(
|
||||
current_platform.get_device_capability().to_int(), {}
|
||||
)
|
||||
|
||||
def uuid(self):
|
||||
"""
|
||||
Produces a hash unique to the pass configuration.
|
||||
@ -136,6 +177,11 @@ class PassConfig:
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Attention + quant (fp8) fusion might not work"
|
||||
)
|
||||
if self.enable_fi_allreduce_fusion:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
||||
)
|
||||
|
||||
|
||||
@config
|
||||
|
||||
@ -2356,6 +2356,16 @@ class FusedMoE(CustomOp):
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
def reduce_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if (
|
||||
not self.is_sequence_parallel
|
||||
and not self.use_dp_chunking
|
||||
and self.reduce_results
|
||||
and (self.tp_size > 1 or self.ep_size > 1)
|
||||
):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
||||
return states
|
||||
|
||||
if self.shared_experts is None:
|
||||
if current_platform.is_tpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
@ -2366,7 +2376,14 @@ class FusedMoE(CustomOp):
|
||||
fused_output = torch.ops.vllm.moe_forward(
|
||||
hidden_states, router_logits, self.layer_name
|
||||
)
|
||||
return fused_output[..., :og_hidden_states]
|
||||
if self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(fused_output, tuple)
|
||||
fused_output, zero_expert_result = fused_output
|
||||
return (reduce_output(fused_output) + zero_expert_result)[
|
||||
..., :og_hidden_states
|
||||
]
|
||||
else:
|
||||
return reduce_output(fused_output)[..., :og_hidden_states]
|
||||
else:
|
||||
if current_platform.is_tpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
@ -2379,8 +2396,8 @@ class FusedMoE(CustomOp):
|
||||
hidden_states, router_logits, self.layer_name
|
||||
)
|
||||
return (
|
||||
shared_output[..., :og_hidden_states],
|
||||
fused_output[..., :og_hidden_states],
|
||||
reduce_output(shared_output)[..., :og_hidden_states],
|
||||
reduce_output(fused_output)[..., :og_hidden_states],
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
@ -2667,31 +2684,21 @@ class FusedMoE(CustomOp):
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
|
||||
def reduce_output(
|
||||
states: torch.Tensor, do_combine: bool = True
|
||||
) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine and do_combine:
|
||||
def combine_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine:
|
||||
states = get_ep_group().combine(states, self.is_sequence_parallel)
|
||||
|
||||
if (
|
||||
not self.is_sequence_parallel
|
||||
and self.reduce_results
|
||||
and (self.tp_size > 1 or self.ep_size > 1)
|
||||
):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
||||
|
||||
return states
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0], do_combine=False),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
final_hidden_states[0],
|
||||
combine_output(final_hidden_states[1]),
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, torch.Tensor)
|
||||
return reduce_output(final_hidden_states) + zero_expert_result
|
||||
return (combine_output(final_hidden_states), zero_expert_result)
|
||||
else:
|
||||
return reduce_output(final_hidden_states)
|
||||
return combine_output(final_hidden_states)
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user