[Fix] uniform decode batch check (#30747)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
parent 6482e3895b
commit 6e9dbcc50e
@@ -1110,3 +1110,87 @@ def test_hybrid_cache_integration(model_runner, dist_init):
     runner._update_states(scheduler_output)
     assert _is_req_scheduled(runner, req_id)
     assert _is_req_state_block_table_match(runner, req_id)
+
+
+def test_is_uniform_decode() -> None:
+    # Normal
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=2,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=15,
+    )
+    # Spec decoding
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=5,
+        uniform_decode_query_len=5,
+        num_tokens=30,
+        num_reqs=6,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=5,
+        uniform_decode_query_len=4,
+        num_tokens=30,
+        num_reqs=6,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=5,
+        uniform_decode_query_len=5,
+        num_tokens=30,
+        num_reqs=7,
+    )
+    # Force uniform decode
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+        force_uniform_decode=True,
+    )
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=2,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+        force_uniform_decode=True,
+    )
+    assert GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=15,
+        force_uniform_decode=True,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+        force_uniform_decode=False,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=2,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=16,
+        force_uniform_decode=False,
+    )
+    assert not GPUModelRunner._is_uniform_decode(
+        max_num_scheduled_tokens=1,
+        uniform_decode_query_len=1,
+        num_tokens=16,
+        num_reqs=15,
+        force_uniform_decode=False,
+    )
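Note on the spec-decoding cases above: the uniform-decode condition reduces to num_tokens == uniform_decode_query_len * num_reqs, which is why 6 requests at query length 5 (30 tokens) pass while 7 requests, or a query length of 4, do not. A quick sanity check, using only the values already present in the test:

# Why the spec-decoding assertions above pass or fail.
uniform_decode_query_len = 5
assert 30 == uniform_decode_query_len * 6  # 6 requests of 5 tokens -> uniform decode
assert 30 != uniform_decode_query_len * 7  # 7 requests would need 35 tokens
assert 5 != 4                              # query-length mismatch -> not uniform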
@@ -2777,6 +2777,27 @@ class GPUModelRunner(
             **model_kwargs,
         )
 
+    @staticmethod
+    def _is_uniform_decode(
+        max_num_scheduled_tokens: int,
+        uniform_decode_query_len: int,
+        num_tokens: int,
+        num_reqs: int,
+        force_uniform_decode: bool | None = None,
+    ) -> bool:
+        """
+        Checks if it's a decode batch with the same number of scheduled
+        tokens across all requests.
+        """
+        return (
+            (
+                (max_num_scheduled_tokens == uniform_decode_query_len)
+                and (num_tokens == max_num_scheduled_tokens * num_reqs)
+            )
+            if force_uniform_decode is None
+            else force_uniform_decode
+        )
+
     def _determine_batch_execution_and_padding(
         self,
         num_tokens: int,
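For readers scanning the diff, the override semantics matter: when force_uniform_decode is not None, the structural check is skipped entirely and the forced value is returned as-is, matching the "Force uniform decode" test cases above. A minimal standalone sketch of that control flow, written here purely for illustration and not part of the vLLM API:

# Hypothetical standalone mirror of the new staticmethod, illustration only.
def is_uniform_decode_sketch(
    max_num_scheduled_tokens: int,
    uniform_decode_query_len: int,
    num_tokens: int,
    num_reqs: int,
    force_uniform_decode: bool | None = None,
) -> bool:
    if force_uniform_decode is not None:
        # Caller has already decided; bypass the structural check.
        return force_uniform_decode
    # Uniform decode: every request contributes exactly
    # uniform_decode_query_len scheduled tokens.
    return (
        max_num_scheduled_tokens == uniform_decode_query_len
        and num_tokens == max_num_scheduled_tokens * num_reqs
    )

assert is_uniform_decode_sketch(1, 1, 16, 16)                             # uniform
assert not is_uniform_decode_sketch(2, 1, 16, 16)                         # mixed batch
assert is_uniform_decode_sketch(2, 1, 16, 16, force_uniform_decode=True)  # override wins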
@@ -2798,14 +2819,12 @@ class GPUModelRunner(
         torch.Tensor | None,
         CUDAGraphStat | None,
     ]:
-        num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
-        uniform_decode = (
-            (
-                (max_num_scheduled_tokens == self.uniform_decode_query_len)
-                and (num_tokens_padded == max_num_scheduled_tokens * num_reqs)
-            )
-            if force_uniform_decode is None
-            else force_uniform_decode
-        )
+        uniform_decode = self._is_uniform_decode(
+            max_num_scheduled_tokens=max_num_scheduled_tokens,
+            uniform_decode_query_len=self.uniform_decode_query_len,
+            num_tokens=num_tokens,
+            num_reqs=num_reqs,
+            force_uniform_decode=force_uniform_decode,
+        )
         # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output
         # is present). Also, chunked-prefill is disabled, so batch are uniform.
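One behavioral detail worth calling out: the removed inline check compared the sequence-parallel padded count (num_tokens_padded) against max_num_scheduled_tokens * num_reqs, whereas the new helper receives the unpadded num_tokens, and the padding call itself moves below the check in the next hunk. A rough illustration of how padding could previously misclassify a genuinely uniform batch, assuming a pad-to-multiple-of-TP-size scheme and hypothetical numbers:

# Hypothetical: 15 single-token decode requests, tensor parallel size 8.
tp_size = 8
num_reqs = 15
max_num_scheduled_tokens = 1
num_tokens = num_reqs * max_num_scheduled_tokens                       # 15
num_tokens_padded = ((num_tokens + tp_size - 1) // tp_size) * tp_size  # 16

assert num_tokens == max_num_scheduled_tokens * num_reqs         # new check: uniform
assert num_tokens_padded != max_num_scheduled_tokens * num_reqs  # old check: missed it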
@@ -2819,6 +2838,7 @@ class GPUModelRunner(
             else force_has_lora
         )
 
+        num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
         dispatch_cudagraph = (
             lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
                 num_tokens=num_tokens,
@@ -2834,6 +2854,15 @@ class GPUModelRunner(
             num_tokens_padded, use_cascade_attn or has_encoder_output
         )
         num_tokens_padded = batch_descriptor.num_tokens
+        if self.compilation_config.pass_config.enable_sp:
+            assert (
+                batch_descriptor.num_tokens
+                % self.vllm_config.parallel_config.tensor_parallel_size
+                == 0
+            ), (
+                "Sequence parallelism requires num_tokens to be "
+                "a multiple of tensor parallel size"
+            )
 
         # Extra coordination when running data-parallel since we need to coordinate
         # across ranks
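The new assertion guards an invariant of sequence parallelism: the dispatched token count must divide evenly across tensor-parallel ranks. A small sketch of the same invariant with hypothetical values:

# Sketch of the divisibility invariant the new assertion enforces (hypothetical values).
tensor_parallel_size = 4

def check_sp_divisibility(num_tokens: int) -> None:
    # With sequence parallelism the token dimension is sharded across TP ranks,
    # so every rank must receive an equal, whole slice.
    assert num_tokens % tensor_parallel_size == 0, (
        "Sequence parallelism requires num_tokens to be "
        "a multiple of tensor parallel size"
    )

check_sp_divisibility(16)  # 16 tokens / 4 ranks -> 4 tokens each, passes
# check_sp_divisibility(15) would raise: 15 is not divisible by 4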