mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 09:07:04 +08:00
should_ubatch improvements
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
1a0e7110dd
commit
716b03277e
@ -57,9 +57,11 @@ class DPMetadata:
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group)
|
||||
|
||||
# If there's an incorrect ordering of ARs across DP ranks, this tensor
|
||||
# can end up containing the number of padded tokens for a DP rank
|
||||
assert torch.all(should_ubatch_tensor <= 1)
|
||||
# This function uses the same ProcessGroup for all reduce as
|
||||
# num_tokens_across_dp. If there's an incorrect ordering of ARs
|
||||
# across DP ranks, this tensor can end up containing the number
|
||||
# of padded tokens for a DP rank.
|
||||
assert torch.all((should_ubatch_tensor == 0) | (should_ubatch_tensor == 1))
|
||||
|
||||
result: bool = bool(torch.all(should_ubatch_tensor == 1).item())
|
||||
return result
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user