Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 23:45:54 +08:00)
Fix Flashinfer Allreduce+Norm enable disable calculation based on fi_allreduce_fusion_max_token_num (#21325)
Signed-off-by: XIn Li <xinli@nvidia.com>
parent 35366ae57c
commit ae268b6326
@@ -159,6 +159,9 @@ if flashinfer_comm is not None:
         6: MiB // 2,  # 512KB
         8: MiB // 2,  # 512KB
     }
+    # opt for a more conservative default value
+    # when world size is not in _FI_MAX_SIZES
+    _DEFAULT_FI_MAX_SIZE = MiB // 2
 
     def call_trtllm_fused_allreduce_norm(
         allreduce_in: torch.Tensor,
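
The first hunk introduces _DEFAULT_FI_MAX_SIZE so the size lookup no longer assumes the current world size appears in _FI_MAX_SIZES. A minimal sketch of the intended fallback behavior (the table here is an illustrative subset of the real one, and world size 12 is a hypothetical example):

MiB = 1024 * 1024

# Illustrative subset of the per-world-size caps shown above.
_FI_MAX_SIZES = {
    6: MiB // 2,
    8: MiB // 2,
}
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2

# Before this change, _FI_MAX_SIZES[world_size] raised KeyError for a
# world size missing from the table; dict.get now falls back to the
# conservative default instead.
assert _FI_MAX_SIZES.get(12, _DEFAULT_FI_MAX_SIZE) == MiB // 2
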
@@ -173,12 +176,16 @@ if flashinfer_comm is not None:
         max_token_num: int,
         norm_out: Optional[torch.Tensor] = None,
     ) -> None:
-        use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
-            1] * allreduce_in.element_size() <= min(
-                _FI_MAX_SIZES[world_size],
-                max_token_num * allreduce_in.shape[0] *
-                allreduce_in.element_size(),
-            )
+        num_tokens, hidden_size = allreduce_in.shape
+        element_size = allreduce_in.element_size()
+        current_tensor_size = num_tokens * hidden_size * element_size
+        max_fusion_size = max_token_num * hidden_size * element_size
+        use_flashinfer = current_tensor_size <= min(
+            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
+            max_fusion_size,
+        )
 
         if use_flashinfer:
             assert (_FI_WORKSPACE_TENSOR is not None
                     ), "Flashinfer must be enabled when using flashinfer"
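
For a self-contained check of the corrected gate, here is a sketch of the new formula with hypothetical inputs (the function name should_use_flashinfer and the sample tensor are illustrative; the formula itself is the one added above). The old expression capped fusion at max_token_num * allreduce_in.shape[0] * element_size, i.e. it scaled the token budget by the current token count instead of the hidden size:

import torch

MiB = 1024 * 1024
_FI_MAX_SIZES = {6: MiB // 2, 8: MiB // 2}
_DEFAULT_FI_MAX_SIZE = MiB // 2

def should_use_flashinfer(allreduce_in: torch.Tensor, world_size: int,
                          max_token_num: int) -> bool:
    # Corrected gate from the diff: compare the tensor's byte size
    # against the smaller of the per-world-size cap and the byte size
    # of a max_token_num-row tensor with the same hidden size / dtype.
    num_tokens, hidden_size = allreduce_in.shape
    element_size = allreduce_in.element_size()
    current_tensor_size = num_tokens * hidden_size * element_size
    max_fusion_size = max_token_num * hidden_size * element_size
    return current_tensor_size <= min(
        _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
        max_fusion_size,
    )

# Hypothetical inputs: 64 tokens, hidden size 4096, fp16 (2 bytes)
# -> exactly 512 KiB, so the fused path stays enabled.
x = torch.empty(64, 4096, dtype=torch.float16)
print(should_use_flashinfer(x, world_size=8, max_token_num=128))  # True

With the old formula, the token-budget cap for this example would have been 128 * 64 * 2 = 16 KiB, so min() would have wrongly disabled the fused path for a tensor that fits comfortably within the 512 KiB workspace.
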