[Fix] uniform decode batch check (#30747)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Jialin Ouyang authored 2025-12-17 03:58:43 -08:00, committed by GitHub
parent 6482e3895b
commit 6e9dbcc50e
2 changed files with 121 additions and 8 deletions


@@ -1110,3 +1110,87 @@ def test_hybrid_cache_integration(model_runner, dist_init):
     runner._update_states(scheduler_output)
     assert _is_req_scheduled(runner, req_id)
     assert _is_req_state_block_table_match(runner, req_id)
+
+
+def test_is_uniform_decode() -> None:
+    # Normal
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=2,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=15,
+    )
+    # Spec decoding
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=5,
+        uniform_decode_query_len=5,
+        num_tokens=30,
+        num_reqs=6,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=5,
+        uniform_decode_query_len=4,
+        num_tokens=30,
+        num_reqs=6,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=5,
+        uniform_decode_query_len=5,
+        num_tokens=30,
+        num_reqs=7,
+    )
+    # Force uniform decode
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+        force_uniform_decode=True,
+    )
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=2,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+        force_uniform_decode=True,
+    )
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=15,
+        force_uniform_decode=True,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+        force_uniform_decode=False,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=2,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+        force_uniform_decode=False,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=15,
+        force_uniform_decode=False,
+    )
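
As a reading aid, the predicate exercised above reduces to a single arithmetic condition. Below is a minimal standalone sketch of the same check (written as a hypothetical free function purely for illustration; the actual helper is the GPUModelRunner._is_uniform_decode staticmethod added in the next file), followed by the spec-decoding case from the tests:

# Minimal sketch of the uniform-decode predicate; the free-function form is
# for illustration only, not the real vLLM API.
def is_uniform_decode(
    max_num_scheduled_tokens: int,
    uniform_decode_query_len: int,
    num_tokens: int,
    num_reqs: int,
    force_uniform_decode: bool | None = None,
) -> bool:
    # An explicit override bypasses the arithmetic check entirely.
    if force_uniform_decode is not None:
        return force_uniform_decode
    # Uniform decode: every request was scheduled exactly
    # uniform_decode_query_len tokens, so the batch total factors evenly.
    return (
        max_num_scheduled_tokens == uniform_decode_query_len
        and num_tokens == max_num_scheduled_tokens * num_reqs
    )

# Spec decoding: 6 requests x 5 query tokens each = 30 tokens -> uniform.
assert is_uniform_decode(5, 5, 30, 6)
# 7 requests cannot account for 30 tokens at 5 each (7 * 5 = 35 != 30).
assert not is_uniform_decode(5, 5, 30, 7)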


@@ -2777,6 +2777,27 @@ class GPUModelRunner(
             **model_kwargs,
         )
 
+    @staticmethod
+    def _is_uniform_decode(
+        max_num_scheduled_tokens: int,
+        uniform_decode_query_len: int,
+        num_tokens: int,
+        num_reqs: int,
+        force_uniform_decode: bool | None = None,
+    ) -> bool:
+        """
+        Checks if this is a decode batch with the same number of scheduled
+        tokens across all requests.
+        """
+        return (
+            (
+                (max_num_scheduled_tokens == uniform_decode_query_len)
+                and (num_tokens == max_num_scheduled_tokens * num_reqs)
+            )
+            if force_uniform_decode is None
+            else force_uniform_decode
+        )
+
     def _determine_batch_execution_and_padding(
         self,
         num_tokens: int,
@@ -2798,14 +2819,12 @@ class GPUModelRunner(
         torch.Tensor | None,
         CUDAGraphStat | None,
     ]:
-        num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
-        uniform_decode = (
-            (
-                (max_num_scheduled_tokens == self.uniform_decode_query_len)
-                and (num_tokens_padded == max_num_scheduled_tokens * num_reqs)
-            )
-            if force_uniform_decode is None
-            else force_uniform_decode
+        uniform_decode = self._is_uniform_decode(
+            max_num_scheduled_tokens=max_num_scheduled_tokens,
+            uniform_decode_query_len=self.uniform_decode_query_len,
+            num_tokens=num_tokens,
+            num_reqs=num_reqs,
+            force_uniform_decode=force_uniform_decode,
         )
         # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output
         # is present). Also, chunked-prefill is disabled, so batches are uniform.
@@ -2819,6 +2838,7 @@ class GPUModelRunner(
             else force_has_lora
         )
 
+        num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
         dispatch_cudagraph = (
            lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
                 num_tokens=num_tokens,
@@ -2834,6 +2854,15 @@ class GPUModelRunner(
             num_tokens_padded, use_cascade_attn or has_encoder_output
         )
         num_tokens_padded = batch_descriptor.num_tokens
+        if self.compilation_config.pass_config.enable_sp:
+            assert (
+                batch_descriptor.num_tokens
+                % self.vllm_config.parallel_config.tensor_parallel_size
+                == 0
+            ), (
+                "Sequence parallelism requires num_tokens to be "
+                "a multiple of tensor parallel size"
+            )
 
         # Extra coordination when running data-parallel since we need to coordinate
         # across ranks
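
Note on the change above: uniform_decode is now computed from the raw num_tokens, the sequence-parallelism padding via _pad_for_sequence_parallelism happens afterwards, and the new assert only verifies that the dispatched batch size is divisible by the tensor-parallel size. A rough sketch of that kind of round-up padding, assuming it simply rounds num_tokens up to the next multiple of tp_size (an assumption for illustration, not the actual vLLM implementation), looks like this:

# Hedged sketch of round-up padding for sequence parallelism; an assumption
# about the behavior of _pad_for_sequence_parallelism, not its real body.
def pad_for_sequence_parallelism(num_tokens: int, tp_size: int, enable_sp: bool) -> int:
    if not enable_sp or tp_size <= 1:
        return num_tokens
    # Round up to the next multiple of the tensor-parallel size.
    return ((num_tokens + tp_size - 1) // tp_size) * tp_size

# With padding applied, the divisibility assert above can never fire.
assert pad_for_sequence_parallelism(30, tp_size=4, enable_sp=True) == 32
assert pad_for_sequence_parallelism(30, tp_size=4, enable_sp=False) == 30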