diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 2cdd62c72d581..e87172bc3d79c 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -57,9 +57,11 @@ class DPMetadata:
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group)
-        # If there's an incorrect ordering of ARs across DP ranks, this tensor
-        # can end up containing the number of padded tokens for a DP rank
-        assert torch.all(should_ubatch_tensor <= 1)
+        # This function uses the same ProcessGroup for all reduce as
+        # num_tokens_across_dp. If there's an incorrect ordering of ARs
+        # across DP ranks, this tensor can end up containing the number
+        # of padded tokens for a DP rank.
+        assert torch.all((should_ubatch_tensor == 0) | (should_ubatch_tensor == 1))
         result: bool = bool(torch.all(should_ubatch_tensor == 1).item())
         return result
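
For reference, here is a minimal, single-process sketch of the consensus scheme the tightened assertion protects. It is an illustration, not vllm's implementation: the per-rank-slot tensor layout is inferred from the assertion (a SUM all_reduce over one-hot vote tensors leaves every entry at 0 or 1), and `should_ubatch_across_dp_sim` is a hypothetical helper that stands in for the real `dist.all_reduce` over the DP cpu_group.

```python
# Single-process simulation; names are illustrative, not vllm's API.
# Assumption: each DP rank writes its 0/1 vote into its own slot of a
# zero tensor, so a SUM all_reduce scatters every rank's vote to all
# ranks and each entry of the reduced tensor must be exactly 0 or 1.
import torch

def should_ubatch_across_dp_sim(votes: list[bool]) -> bool:
    dp_size = len(votes)
    # Each "rank" contributes a tensor with its vote in its own slot.
    per_rank = [torch.zeros(dp_size, dtype=torch.int32) for _ in range(dp_size)]
    for rank, vote in enumerate(votes):
        per_rank[rank][rank] = 1 if vote else 0
    # A SUM all_reduce is elementwise summation of every rank's tensor.
    reduced = torch.stack(per_rank).sum(dim=0)
    # Mirrors the tightened assertion from the diff: every entry must be
    # a 0/1 vote, never a stray token count from a mis-ordered reduce.
    assert torch.all((reduced == 0) | (reduced == 1))
    # Ubatching is enabled only if every rank voted yes.
    return bool(torch.all(reduced == 1).item())

print(should_ubatch_across_dp_sim([True, True, True]))   # True
print(should_ubatch_across_dp_sim([True, False, True]))  # False
```

If a different all_reduce on the same ProcessGroup (e.g. the one for num_tokens_across_dp) were interleaved out of order, the reduced entries would hold token counts instead of 0/1 votes, which is exactly the corruption the new assertion is meant to catch.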