mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 15:55:01 +08:00
[V1][BugFix] Fix edge case in VLM scheduling (#12065)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
9ddac56311
commit
b7ee940a82
@@ -373,18 +373,22 @@ class Scheduler:
|
||||
if self.encoder_cache_manager.has_cache(request, i):
|
||||
# The encoder input is already computed and cached.
|
||||
continue
|
||||
if not self.encoder_cache_manager.can_allocate(request, i):
|
||||
# The encoder cache is full. We can only schedule the decoder
|
||||
# tokens just before the encoder input.
|
||||
num_new_tokens = start_pos - num_computed_tokens
|
||||
break
|
||||
if num_encoder_tokens > encoder_budget:
|
||||
# The encoder budget is exhausted. We can only schedule the
|
||||
# decoder tokens up until the encoder input.
|
||||
# NOTE(woosuk): We assume that the encoder tokens should be
|
||||
# processed altogether, as the encoder usually uses
|
||||
if (not self.encoder_cache_manager.can_allocate(request, i)
|
||||
or num_encoder_tokens > encoder_budget):
|
||||
# The encoder cache is full or the encoder budget is exhausted.
|
||||
# NOTE(woosuk): We assume that the encoder input tokens should
|
||||
# be processed altogether, as the encoder usually uses
|
||||
# bidirectional attention.
|
||||
num_new_tokens = start_pos - num_computed_tokens
|
||||
if num_computed_tokens < start_pos:
|
||||
# We only schedule the decoder tokens just before the
|
||||
# encoder input.
|
||||
num_new_tokens = start_pos - num_computed_tokens
|
||||
else:
|
||||
# num_computed_tokens is at or past start_pos (e.g., because of
|
||||
# prefix caching) even though its encoder input is not
|
||||
# available. In this case, we can't schedule any token for
|
||||
# the request in this step.
|
||||
num_new_tokens = 0
|
||||
break
|
||||
|
||||
encoder_budget -= num_encoder_tokens
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user