Fix Flashinfer Allreduce+Norm enable disable calculation based on fi_allreduce_fusion_max_token_num (#21325)
Signed-off-by: XIn Li <xinli@nvidia.com>
commit ae268b6326
parent 35366ae57c
@@ -159,6 +159,9 @@ if flashinfer_comm is not None:
         6: MiB // 2,  # 512KB
         8: MiB // 2,  # 512KB
     }
+    # opt for a more conservative default value
+    # when world size is not in _FI_MAX_SIZES
+    _DEFAULT_FI_MAX_SIZE = MiB // 2
 
     def call_trtllm_fused_allreduce_norm(
         allreduce_in: torch.Tensor,
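The effect of this first hunk is easiest to see in isolation: the second hunk below replaces a direct _FI_MAX_SIZES[world_size] lookup, which raises KeyError for a world size missing from the table, with a .get() that falls back to the conservative _DEFAULT_FI_MAX_SIZE added here. A minimal sketch of that lookup, assuming MiB = 1024 * 1024 (as the "512KB" comments imply) and showing only the table entries visible in this hunk:

MiB = 1024 * 1024

# Only the entries visible in this hunk; the full table in the source
# file covers more world sizes.
_FI_MAX_SIZES = {
    6: MiB // 2,  # 512KB
    8: MiB // 2,  # 512KB
}

# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2

print(_FI_MAX_SIZES.get(8, _DEFAULT_FI_MAX_SIZE))   # 524288, taken from the table
print(_FI_MAX_SIZES.get(16, _DEFAULT_FI_MAX_SIZE))  # 524288, taken from the default
# The old _FI_MAX_SIZES[16] indexing would have raised KeyError instead.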
@@ -173,12 +176,16 @@ if flashinfer_comm is not None:
         max_token_num: int,
         norm_out: Optional[torch.Tensor] = None,
     ) -> None:
-        use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
-            1] * allreduce_in.element_size() <= min(
-                _FI_MAX_SIZES[world_size],
-                max_token_num * allreduce_in.shape[0] *
-                allreduce_in.element_size(),
-            )
+        num_tokens, hidden_size = allreduce_in.shape
+        element_size = allreduce_in.element_size()
+        current_tensor_size = num_tokens * hidden_size * element_size
+        max_fusion_size = max_token_num * hidden_size * element_size
+        use_flashinfer = current_tensor_size <= min(
+            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
+            max_fusion_size,
+        )
+
         if use_flashinfer:
             assert (_FI_WORKSPACE_TENSOR is not None
                     ), "Flashinfer must be enabled when using flashinfer"
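Taken together, the rewritten check gates the fused allreduce+norm path on two limits: the per-world-size workspace cap (falling back to _DEFAULT_FI_MAX_SIZE for unlisted world sizes) and the size implied by max_token_num. As the removed lines show, the old expression multiplied max_token_num by allreduce_in.shape[0] (the current token count) rather than the hidden size, which is what this commit corrects. A self-contained sketch of the new calculation follows; the helper name and the example numbers are illustrative, not part of the commit:

MiB = 1024 * 1024
_DEFAULT_FI_MAX_SIZE = MiB // 2
# Subset of the table shown above; unlisted world sizes use the default.
_FI_MAX_SIZES = {6: MiB // 2, 8: MiB // 2}


def should_use_flashinfer_fusion(num_tokens: int, hidden_size: int,
                                 element_size: int, world_size: int,
                                 max_token_num: int) -> bool:
    # Mirrors the enable/disable check introduced by this commit.
    current_tensor_size = num_tokens * hidden_size * element_size
    max_fusion_size = max_token_num * hidden_size * element_size
    return current_tensor_size <= min(
        _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
        max_fusion_size,
    )


# 64 tokens of a 4096-wide bf16 tensor (2 bytes/element) on 8 ranks:
# 64 * 4096 * 2 = 512 KiB, exactly the cap, so the fused path stays enabled.
print(should_use_flashinfer_fusion(64, 4096, 2, world_size=8, max_token_num=1024))  # True
# 65 tokens exceed the 512 KiB cap, so the fused path is skipped.
print(should_use_flashinfer_fusion(65, 4096, 2, world_size=8, max_token_num=1024))  # False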