should_ubatch improvements

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-08 13:07:57 +00:00
parent 1a0e7110dd
commit 716b03277e

View File

@@ -57,9 +57,11 @@ class DPMetadata:
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group)
# If there's an incorrect ordering of ARs across DP ranks, this tensor
# can end up containing the number of padded tokens for a DP rank
assert torch.all(should_ubatch_tensor <= 1)
# This function uses the same ProcessGroup for all reduce as
# num_tokens_across_dp. If there's an incorrect ordering of ARs
# across DP ranks, this tensor can end up containing the number
# of padded tokens for a DP rank.
assert torch.all((should_ubatch_tensor == 0) | (should_ubatch_tensor == 1))
result: bool = bool(torch.all(should_ubatch_tensor == 1).item())
return result