[Bugfix] Fix scheduling when repeated images in one request (#23544)

Signed-off-by: Roger Wang <hey@rogerw.me>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.me>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
Roger Wang 2025-08-26 02:46:28 -07:00 committed by GitHub
parent 9b5f64238f
commit b5d34af328
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 96 additions and 39 deletions

View File

@ -22,7 +22,7 @@ def test_basic_allocate_and_reuse():
req = MockRequest("r1", ["imgA"], [4])
assert not cache.check_and_update_cache(req, 0)
assert cache.try_allocate(req, 0, int(1e9))
assert cache.can_allocate(req, 0, int(1e9), 0)
cache.allocate(req, 0)
@ -44,7 +44,7 @@ def test_freeing_decreases_refcount_and_moves_to_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req2", ["img3"], [5])
assert manager.try_allocate(req, 0, int(1e9))
assert manager.can_allocate(req, 0, int(1e9), 0)
manager.allocate(req, 0)
assert len(manager.cached["img3"]) == 1
@ -60,10 +60,10 @@ def test_free_request_frees_all_inputs():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req3", ["a", "b"], [2, 3])
assert manager.try_allocate(req, 0, int(1e9))
assert manager.can_allocate(req, 0, int(1e9), 0)
manager.allocate(req, 0)
assert manager.try_allocate(req, 1, int(1e9))
assert manager.can_allocate(req, 1, int(1e9), 0)
manager.allocate(req, 1)
assert len(manager.cached["a"]) == 1
@ -84,11 +84,11 @@ def test_eviction_when_cache_is_full():
req1 = MockRequest("req1", ["x"], [6])
req2 = MockRequest("req2", ["y"], [5])
assert manager.try_allocate(req1, 0, int(1e9))
assert manager.can_allocate(req1, 0, int(1e9), 0)
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
assert manager.try_allocate(req2, 0, int(1e9))
assert manager.can_allocate(req2, 0, int(1e9), 0)
manager.allocate(req2, 0)
# 'x' should have been evicted.
@ -100,10 +100,10 @@ def test_get_cached_input_ids():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3])
assert manager.try_allocate(req, 0, int(1e9))
assert manager.can_allocate(req, 0, int(1e9), 0)
manager.allocate(req, 0)
assert manager.try_allocate(req, 2, int(1e9))
assert manager.can_allocate(req, 2, int(1e9), 0)
manager.allocate(req, 2)
cached_ids = manager.get_cached_input_ids(req)
@ -114,7 +114,7 @@ def test_has_cache_restores_from_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqY", ["imgZ"], [4])
assert manager.try_allocate(req, 0, int(1e9))
assert manager.can_allocate(req, 0, int(1e9), 0)
manager.allocate(req, 0)
manager.free_encoder_input(req, 0)
@ -131,14 +131,41 @@ def test_get_freed_mm_hashes_clears_freed_list():
req1 = MockRequest("reqA", ["a"], [5])
req2 = MockRequest("reqB", ["b"], [6])
assert manager.try_allocate(req1, 0, int(1e9))
assert manager.can_allocate(req1, 0, int(1e9), 0)
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
# Should trigger eviction of 'a'.
assert manager.try_allocate(req2, 0, int(1e9))
assert manager.can_allocate(req2, 0, int(1e9), 0)
manager.allocate(req2, 0)
freed = manager.get_freed_mm_hashes()
assert "a" in freed
assert manager.get_freed_mm_hashes() == []
def test_schedule_request_multi_images_respect_space_limit():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqA", ["a", "b"], [5, 6])
compute_budget = 100
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
assert not manager.can_allocate(req, 1, compute_budget,
num_tokens_to_schedule)
def test_schedule_request_multi_images_respect_compute_limit():
manager = EncoderCacheManager(cache_size=100)
req = MockRequest("reqA", ["a", "b"], [5, 6])
compute_budget = 10
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
assert not manager.can_allocate(req, 1, compute_budget,
num_tokens_to_schedule)

View File

@ -99,8 +99,9 @@ class EncoderCacheManager:
self.cached[mm_hash].add(request.request_id)
return True
def try_allocate(self, request: Request, input_id: int,
encoder_budget: int) -> bool:
def can_allocate(self, request: Request, input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
@ -116,6 +117,10 @@ class EncoderCacheManager:
Args:
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_compute_budget: Number of encoder tokens allowed to be
computed when this method is invoked.
num_tokens_to_schedule: Number of tokens already scheduled to be
allocated with cache space when this method is invoked.
Returns:
True if there's enough capacity to hold the encoder output for this
@ -128,13 +133,13 @@ class EncoderCacheManager:
num_tokens = request.get_num_encoder_tokens(input_id)
# Not enough compute budget
if num_tokens > encoder_budget:
if num_tokens > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
# Enough free slots
if num_tokens <= self.num_free_slots:
self.num_free_slots -= num_tokens
self.num_freeable_slots -= num_tokens
return True
# Not enough reclaimable slots
@ -149,8 +154,6 @@ class EncoderCacheManager:
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_token
self.num_free_slots -= num_tokens
self.num_freeable_slots -= num_tokens
return True
def allocate(self, request: Request, input_id: int) -> None:
@ -161,19 +164,24 @@ class EncoderCacheManager:
the model runner; this method updates the manager's bookkeeping.
Note:
This method assumes try_allocate() returned True for the same input.
This method assumes can_allocate() returned True for the same input.
"""
# Encoder cache space budget should be already updated for the
# multimodal input and non-negative after try_allocate() is called.
assert self.num_free_slots >= 0
assert self.num_freeable_slots >= 0
mm_hash = request.mm_hashes[input_id]
request_id = request.request_id
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
# NOTE: Encoder cache should always have enough space for encoder inputs
# that are scheduled since eviction takes place at can_allocate().
assert self.num_free_slots >= num_encoder_tokens
assert self.num_freeable_slots >= num_encoder_tokens
self.cached[mm_hash].add(request_id)
self.num_free_slots -= num_encoder_tokens
self.num_freeable_slots -= num_encoder_tokens
def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request.

View File

@ -182,7 +182,7 @@ class Scheduler(SchedulerInterface):
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
encoder_compute_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
@ -211,12 +211,13 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_budget)
encoder_compute_budget)
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
@ -298,7 +299,7 @@ class Scheduler(SchedulerInterface):
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
encoder_compute_budget = new_encoder_compute_budget
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
@ -382,7 +383,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_encoder_compute_budget = encoder_compute_budget
# KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async:
@ -413,10 +414,10 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
encoder_compute_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
@ -495,7 +496,7 @@ class Scheduler(SchedulerInterface):
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
encoder_compute_budget = new_encoder_compute_budget
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
@ -658,7 +659,7 @@ class Scheduler(SchedulerInterface):
request: Request,
num_computed_tokens: int,
num_new_tokens: int,
encoder_budget: int,
encoder_compute_budget: int,
) -> tuple[list[int], int, int]:
"""
Determine which encoder inputs need to be scheduled in the current step,
@ -680,11 +681,17 @@ class Scheduler(SchedulerInterface):
blocks and externally cached blocks (via KVConnector).
"""
if num_new_tokens == 0 or not request.has_encoder_inputs:
return [], num_new_tokens, encoder_budget
return [], num_new_tokens, encoder_compute_budget
encoder_inputs_to_schedule: list[int] = []
mm_positions = request.mm_positions
assert mm_positions is not None
assert len(mm_positions) > 0
# NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
mm_hashes_to_schedule = set()
num_tokens_to_schedule = 0
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
@ -695,13 +702,20 @@ class Scheduler(SchedulerInterface):
if start_pos >= num_computed_tokens + num_new_tokens:
# The encoder input is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder input is already computed and stored
# in the decoder's KV cache.
continue
# The same encoder input has already been scheduled in the current
# step.
if request.mm_hashes[i] in mm_hashes_to_schedule:
continue
if self.encoder_cache_manager.check_and_update_cache(request, i):
# The encoder input is already computed and cached.
# The encoder input is already computed and cached from a
# previous step.
continue
# If no encoder input chunking is allowed, we do not want to
@ -714,8 +728,9 @@ class Scheduler(SchedulerInterface):
num_new_tokens = start_pos - num_computed_tokens
break
if not self.encoder_cache_manager.try_allocate(
request, i, encoder_budget):
if not self.encoder_cache_manager.can_allocate(
request, i, encoder_compute_budget,
num_tokens_to_schedule):
# The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should
# be processed altogether, as the encoder usually uses
@ -732,9 +747,16 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0
break
encoder_budget -= num_encoder_tokens
num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens
mm_hashes_to_schedule.add(request.mm_hashes[i])
encoder_inputs_to_schedule.append(i)
return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
return (
encoder_inputs_to_schedule,
num_new_tokens,
encoder_compute_budget,
)
def get_grammar_bitmask(
self,