diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index e87172bc3d79c..76d4801e0c1bb 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -61,6 +61,7 @@ class DPMetadata:
         # num_tokens_across_dp. If there's an incorrect ordering of ARs
         # across DP ranks, this tensor can end up containing the number
         # of padded tokens for a DP rank.
+        assert torch.all((should_ubatch_tensor == 0) | (should_ubatch_tensor == 1))
         result: bool = bool(torch.all(should_ubatch_tensor == 1).item())