Fix Flashinfer Allreduce+Norm enable/disable calculation based on fi_allreduce_fusion_max_token_num (#21325)

Signed-off-by: XIn Li <xinli@nvidia.com>
Author: Xin Li · 2025-07-22 15:42:31 -04:00 (committed by GitHub)
Parent: 35366ae57c
Commit: ae268b6326


@@ -159,6 +159,9 @@ if flashinfer_comm is not None:
         6: MiB // 2, # 512KB
         8: MiB // 2, # 512KB
     }
+    # opt for a more conservative default value
+    # when world size is not in _FI_MAX_SIZES
+    _DEFAULT_FI_MAX_SIZE = MiB // 2
 
     def call_trtllm_fused_allreduce_norm(
         allreduce_in: torch.Tensor,
@@ -173,12 +176,16 @@ if flashinfer_comm is not None:
         max_token_num: int,
         norm_out: Optional[torch.Tensor] = None,
     ) -> None:
-        use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
-            1] * allreduce_in.element_size() <= min(
-                _FI_MAX_SIZES[world_size],
-                max_token_num * allreduce_in.shape[0] *
-                allreduce_in.element_size(),
-            )
+        num_tokens, hidden_size = allreduce_in.shape
+        element_size = allreduce_in.element_size()
+        current_tensor_size = num_tokens * hidden_size * element_size
+        max_fusion_size = max_token_num * hidden_size * element_size
+        use_flashinfer = current_tensor_size <= min(
+            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
+            max_fusion_size,
+        )
         if use_flashinfer:
             assert (_FI_WORKSPACE_TENSOR is not None
                     ), "Flashinfer must be enabled when using flashinfer"