[Core][Multimodal] Track encode cache entries by mm_hash and enable embedding sharing between requests (#22711)

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Chenguang Zheng 2025-08-25 15:41:17 +08:00 committed by GitHub
parent 712d0f88d8
commit d765cf01fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 365 additions and 154 deletions

View File

@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
# ------------------ Mock Classes ------------------ #
class MockRequest:
def __init__(self, request_id, mm_hashes, token_counts):
self.request_id = request_id
self.mm_hashes = mm_hashes
self._token_counts = token_counts
def get_num_encoder_tokens(self, input_id: int) -> int:
return self._token_counts[input_id]
# ------------------ Unit Tests ------------------ #
def test_basic_allocate_and_reuse():
cache = EncoderCacheManager(cache_size=10)
req = MockRequest("r1", ["imgA"], [4])
assert not cache.check_and_update_cache(req, 0)
assert cache.try_allocate(req, 0, int(1e9))
cache.allocate(req, 0)
assert cache.check_and_update_cache(req, 0)
assert "r1" in cache.cached["imgA"]
assert cache.num_free_slots == 6
# Free twice to bring refcount to 0.
cache.free_encoder_input(req, 0)
cache.free_encoder_input(req, 0)
assert not cache.cached["imgA"]
assert "imgA" in cache.freeable
assert cache.num_freeable_slots == 10
assert cache.num_free_slots == 6
def test_freeing_decreases_refcount_and_moves_to_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req2", ["img3"], [5])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
assert len(manager.cached["img3"]) == 1
manager.free_encoder_input(req, 0)
assert not manager.cached["img3"]
assert "img3" in manager.freeable
assert manager.num_freeable_slots == 10
def test_free_request_frees_all_inputs():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req3", ["a", "b"], [2, 3])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
assert manager.try_allocate(req, 1, int(1e9))
manager.allocate(req, 1)
assert len(manager.cached["a"]) == 1
assert len(manager.cached["b"]) == 1
manager.free(req)
assert not manager.cached["a"]
assert not manager.cached["b"]
assert "a" in manager.freeable
assert "b" in manager.freeable
assert manager.num_freeable_slots == 10
def test_eviction_when_cache_is_full():
manager = EncoderCacheManager(cache_size=10)
req1 = MockRequest("req1", ["x"], [6])
req2 = MockRequest("req2", ["y"], [5])
assert manager.try_allocate(req1, 0, int(1e9))
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
assert manager.try_allocate(req2, 0, int(1e9))
manager.allocate(req2, 0)
# 'x' should have been evicted.
assert "x" not in manager.cached
assert "x" in manager.get_freed_mm_hashes()
def test_get_cached_input_ids():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
assert manager.try_allocate(req, 2, int(1e9))
manager.allocate(req, 2)
cached_ids = manager.get_cached_input_ids(req)
assert cached_ids == {0, 2}
def test_has_cache_restores_from_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqY", ["imgZ"], [4])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
manager.free_encoder_input(req, 0)
# Should restore from freeable.
assert manager.check_and_update_cache(req, 0)
assert len(manager.cached["imgZ"]) == 1
assert "imgZ" not in manager.freeable
assert manager.num_freeable_slots == 6
def test_get_freed_mm_hashes_clears_freed_list():
manager = EncoderCacheManager(cache_size=10)
req1 = MockRequest("reqA", ["a"], [5])
req2 = MockRequest("reqB", ["b"], [6])
assert manager.try_allocate(req1, 0, int(1e9))
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
# Should trigger eviction of 'a'.
assert manager.try_allocate(req2, 0, int(1e9))
manager.allocate(req2, 0)
freed = manager.get_freed_mm_hashes()
assert "a" in freed
assert manager.get_freed_mm_hashes() == []

View File

@ -338,7 +338,7 @@ def test_stop_via_update_from_output():
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None) grammar_bitmask=None)
@ -391,7 +391,7 @@ def test_stop_via_update_from_output():
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -443,7 +443,7 @@ def test_stop_via_update_from_output():
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -490,7 +490,7 @@ def test_stop_via_update_from_output():
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None) grammar_bitmask=None)

View File

@ -143,7 +143,11 @@ def create_requests(
mm_position = mm_positions[i] mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m") mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position) mm_kwargs = [mm_item] * len(mm_position)
mm_hashes = ["hash"] * len(mm_position) # Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
mm_hashes = [
"hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
]
else: else:
mm_position = None mm_position = None
mm_kwargs = None mm_kwargs = None

View File

@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )

View File

@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int):
pooling_params=None, pooling_params=None,
mm_kwargs=[], mm_kwargs=[],
mm_positions=[], mm_positions=[],
mm_hashes=[],
block_ids=([], ), block_ids=([], ),
generator=None, generator=None,
num_computed_tokens=len(output_token_ids), num_computed_tokens=len(output_token_ids),

View File

@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -239,7 +239,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )
@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_mm_hashes=[],
structured_output_request_ids={}, structured_output_request_ids={},
grammar_bitmask=None, grammar_bitmask=None,
) )

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import OrderedDict
from collections.abc import Mapping from collections.abc import Mapping
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -31,34 +33,52 @@ class EncoderCacheManager:
within requests, allowing for fine-grained memory management and enabling within requests, allowing for fine-grained memory management and enabling
chunked processing of multimodal inputs. chunked processing of multimodal inputs.
Note that no caching is shared between requests at this time. If the same Cache is enabled to share embeddings of same multimodal data
input is used across multiple requests, it will be reprocessed for each item (identified by their hash value) between different requests,
request. and eviction takes place at allocation time when there's no free
space for new embeddings.
Oldest cached embeddings with no request referenced will be first evicted.
Args: Args:
cache_size: Limit the size of the cache, measured by the number of cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence. tokens from the input sequence.
Attributes: Attributes:
cache_size: Total cache capacity in encoder tokens cache_size: Total cache capacity in encoder tokens.
num_free_slots: Current available cache capacity in encoder tokens num_free_slots: Current available cache capacity in encoder tokens.
cached: Mapping from request_id to set of cached input_ids for that num_freeable_slots: Capacity that can be immediately reclaimed by
request evicting entries with zero references (in encoder tokens).
freed: List of (request_id, input_id) pairs that were recently freed. cached: Mapping from mm_hash to a set of request IDs that currently
This is cleared after every call to get_freed_ids(). reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: List of tuples (mm_hash, num_tokens) representing entries
whose no current running request is needed and that can be freed to
make space when needed.
freed: List of mm_hash strings that were actually evicted since the
last call to get_freed_mm_hashes(). This list is cleared on return.
""" """
def __init__(self, cache_size: int): def __init__(self, cache_size: int):
self.cache_size = cache_size self.cache_size = cache_size
self.num_free_slots = cache_size self.num_free_slots = cache_size
# req_id -> cached input ids self.num_freeable_slots = cache_size
self.cached: dict[str, set[int]] = {}
# list of [req_id, input_id]
self.freed: list[tuple[str, int]] = []
def has_cache(self, request: Request, input_id: int) -> bool: # mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
"""Check if encoder output for a specific multimodal input is cached. """Check if encoder output for a specific multimodal input is cached.
If the encoder output is cached, update `cached` to add the request id
to the set of request ids that reference the cached encoder output.
If the encoder output was previously not referenced by any request,
update `freeable` and `num_freeable_slots` accordingly.
Args: Args:
request: The request containing the multimodal input request: The request containing the multimodal input
input_id: Index of the multimodal input within the request input_id: Index of the multimodal input within the request
@ -66,103 +86,151 @@ class EncoderCacheManager:
Returns: Returns:
True if the encoder output for this input is already cached True if the encoder output for this input is already cached
""" """
req_id = request.request_id mm_hash = request.mm_hashes[input_id]
return req_id in self.cached and input_id in self.cached[req_id] # Not cached at all
if mm_hash not in self.cached:
return False
def can_allocate(self, request: Request, input_id: int) -> bool: # Cached but currently not referenced by any request
if not self.cached[mm_hash]:
num_tokens = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens
self.cached[mm_hash].add(request.request_id)
return True
def try_allocate(self, request: Request, input_id: int,
encoder_budget: int) -> bool:
"""Check if there's sufficient cache space for a multimodal input. """Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
If there is not enough free space in `num_free_slots` but there is
enough reclaimable space in `num_freeable_slots`, entries will be
evicted from `freeable` (their mm_hash appended to `freed`) until
enough space is available, and then this method returns True.
Older entries are evicted first.
Returns False only if the requested number of tokens exceeds both
the free and reclaimable capacities combined.
Args: Args:
request: The request containing the multimodal input request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request input_id: Index of the multimodal input within the request.
Returns: Returns:
True if there's enough free cache space to store the encoder output True if there's enough capacity to hold the encoder output for this
for this multimodal input input (possibly after reclaiming `freeable` entries); otherwise
False.
Note: This method does not allocate physical memory for the encoder
output but only the state of EncoderCacheManager.
""" """
num_tokens = request.get_num_encoder_tokens(input_id) num_tokens = request.get_num_encoder_tokens(input_id)
return num_tokens <= self.num_free_slots
# Not enough compute budget
if num_tokens > encoder_budget:
return False
# Enough free slots
if num_tokens <= self.num_free_slots:
self.num_free_slots -= num_tokens
self.num_freeable_slots -= num_tokens
return True
# Not enough reclaimable slots
if num_tokens > self.num_freeable_slots:
return False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output.
while num_tokens > self.num_free_slots:
mm_hash, num_free_token = self.freeable.popitem(last=False)
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_token
self.num_free_slots -= num_tokens
self.num_freeable_slots -= num_tokens
return True
def allocate(self, request: Request, input_id: int) -> None: def allocate(self, request: Request, input_id: int) -> None:
"""Allocate cache space for a multimodal input's encoder output. """Allocate cache space for a multimodal input's encoder output.
This method reserves cache space for storing the encoder output of This reserves cache space for storing the encoder output of the
the specified multimodal input. The actual encoder output storage specified multimodal input. The actual encoder output storage happens in
happens in the model runner, but this method ensures the cache the model runner; this method updates the manager's bookkeeping.
manager tracks the allocation.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input within the request
Note: Note:
This method assumes can_allocate() returned True for the same This method assumes try_allocate() returned True for the same input.
request and input_id. It will reduce available cache space.
""" """
req_id = request.request_id # Encoder cache space budget should be already updated for the
if req_id not in self.cached: # multimodal input and non-negative after try_allocate() is called.
self.cached[req_id] = set() assert self.num_free_slots >= 0
self.cached[req_id].add(input_id) assert self.num_freeable_slots >= 0
self.num_free_slots -= request.get_num_encoder_tokens(input_id)
mm_hash = request.mm_hashes[input_id]
request_id = request.request_id
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
self.cached[mm_hash].add(request_id)
def get_cached_input_ids(self, request: Request) -> set[int]: def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request. """Get all cached multimodal input IDs for a request.
Args: Returns the set of input IDs whose `mm_hash` exists in the cache map.
request: The request to query This includes entries that are currently unreferenced (and thus present
in `freeable`); for such entries, freeing for this request will be a
Returns: no-op.
Set of input_ids that have cached encoder outputs for this request.
Returns empty set if no inputs are cached for this request.
""" """
return self.cached.get(request.request_id, set()) return {
input_id
for input_id in range(len(request.mm_hashes))
if request.mm_hashes[input_id] in self.cached
}
def free_encoder_input(self, request: Request, input_id: int) -> None: def free_encoder_input(self, request: Request, input_id: int) -> None:
"""Free cache space for a single multimodal input's encoder output. """Free the request's reference to the encoder input (`mm_data`)
This method is called when: When the reference set for the corresponding `mm_hash` becomes empty,
- The encoder output has been fully consumed by the decoder and is the entry is appended to `freeable` and `num_freeable_slots` is
no longer needed (e.g., in vision-language models after image increased by the number of encoder tokens for that input.
tokens are processed)
- A request is being cancelled or aborted
Args: The entry is NOT physically freed until capacity is needed (e.g., by
request: The request containing the multimodal input `can_allocate`).
input_id: Index of the multimodal input to free from cache
""" """
req_id = request.request_id req_id = request.request_id
if req_id not in self.cached: mm_hash = request.mm_hashes[input_id]
# The mm_hash not in cache or the req_id set is empty
if not self.cached.get(mm_hash, None):
return return
self.cached[mm_hash].discard(req_id)
self.cached[req_id].discard(input_id) if not self.cached[mm_hash]:
if len(self.cached[req_id]) == 0: num_tokens = request.get_num_encoder_tokens(input_id)
del self.cached[req_id] self.freeable[mm_hash] = num_tokens
self.num_free_slots += request.get_num_encoder_tokens(input_id) self.num_freeable_slots += num_tokens
self.freed.append((req_id, input_id))
def free(self, request: Request) -> None: def free(self, request: Request) -> None:
"""Free all cached encoder outputs for a request. """Free all encoder input cache reference held by *request*.
This method is typically called when a request is finished, cancelled, For each cached input ID, `free_encoder_input` is invoked.
or aborted, and all its encoder outputs should be freed from cache. The data stays in memory until eviction is triggered by a future
attempt allocation called by 'can_allocate'.
Args: Typically called when a request is finished, cancelled, or aborted.
request: The request whose encoder outputs should be freed
""" """
input_ids = self.get_cached_input_ids(request).copy() input_ids = self.get_cached_input_ids(request).copy()
for input_id in input_ids: for input_id in input_ids:
self.free_encoder_input(request, input_id) self.free_encoder_input(request, input_id)
def get_freed_ids(self) -> list[tuple[str, int]]: def get_freed_mm_hashes(self) -> list[str]:
"""Get and clear the list of recently freed encoder cache entries. """Get and clear the list of recently freed encoder cache entries.
This method returns all encoder cache entries that were freed since
the last call to this method. It's used by the scheduler to notify
workers about which encoder outputs can be removed from their caches.
Returns: Returns:
List of (request_id, input_id) tuples that were freed since the List of mm_hash strings that were actually evicted since the last
last call. The internal freed list is cleared after this call. call to be used by the scheduler to notify workers about which
encoder outputs can be removed from their caches. The internal
list is cleared after this call.
""" """
freed = self.freed freed = self.freed
self.freed = [] self.freed = []
@ -177,16 +245,11 @@ def compute_encoder_budget(
"""Compute the encoder cache budget based on the model and scheduler """Compute the encoder cache budget based on the model and scheduler
configurations. configurations.
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns: Returns:
- Compute budget for encoder execution, in unit of number of tokens - Compute budget for encoder execution, measured in number of tokens
in the input sequence. from the input sequence.
- Space budget for encoder cache size, in unit of number of tokens - Space budget for encoder cache size, measured in number of tokens
in the input sequence. from the input sequence.
""" """
if mm_registry.supports_multimodal_inputs(model_config): if mm_registry.supports_multimodal_inputs(model_config):
max_tokens_by_modality = mm_registry \ max_tokens_by_modality = mm_registry \
@ -231,10 +294,10 @@ def compute_mm_encoder_budget(
non-text modality. non-text modality.
Returns: Returns:
- Compute budget for encoder execution, in unit of number of tokens - Compute budget for encoder execution, measured in number of tokens
in the input sequence. from the input sequence.
- Space budget for encoder cache size, in unit of number of tokens - Space budget for encoder cache size, measured in number of tokens
in the input sequence. from the input sequence.
""" """
if not max_tokens_by_modality: if not max_tokens_by_modality:

View File

@ -143,9 +143,9 @@ class SchedulerOutput:
# steps. This is used to notify the workers about the finished requests # steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests. # so that they can free the cached states for those requests.
finished_req_ids: set[str] finished_req_ids: set[str]
# list of (req_id, encoder_input_index) tuples. # list of mm_hash strings associated with the encoder outputs to be
# Used to free the encoder cache. # freed from the encoder cache.
free_encoder_input_ids: list[tuple[str, int]] free_encoder_mm_hashes: list[str]
# Dict of request ids to their index within the batch # Dict of request ids to their index within the batch
# for filling the next token bitmask # for filling the next token bitmask

View File

@ -252,6 +252,7 @@ class Scheduler(SchedulerInterface):
preempted_req = self.running.pop() preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
if self.log_stats: if self.log_stats:
@ -550,7 +551,8 @@ class Scheduler(SchedulerInterface):
# It contains the request IDs that are finished in between # It contains the request IDs that are finished in between
# the previous and the current steps. # the previous and the current steps.
finished_req_ids=self.finished_req_ids, finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), free_encoder_mm_hashes=self.encoder_cache_manager.
get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids, structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask, grammar_bitmask=grammar_bitmask,
) )
@ -698,7 +700,7 @@ class Scheduler(SchedulerInterface):
# in the decoder's KV cache. # in the decoder's KV cache.
continue continue
if self.encoder_cache_manager.has_cache(request, i): if self.encoder_cache_manager.check_and_update_cache(request, i):
# The encoder input is already computed and cached. # The encoder input is already computed and cached.
continue continue
@ -712,8 +714,8 @@ class Scheduler(SchedulerInterface):
num_new_tokens = start_pos - num_computed_tokens num_new_tokens = start_pos - num_computed_tokens
break break
if (not self.encoder_cache_manager.can_allocate(request, i) if not self.encoder_cache_manager.try_allocate(
or num_encoder_tokens > encoder_budget): request, i, encoder_budget):
# The encoder cache is full or the encoder budget is exhausted. # The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should # NOTE(woosuk): We assume that the encoder input tokens should
# be processed altogether, as the encoder usually uses # be processed altogether, as the encoder usually uses

View File

@ -33,6 +33,7 @@ class CachedRequestState:
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_kwargs: list[MultiModalKwargsItem] mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange] mm_positions: list[PlaceholderRange]
mm_hashes: list[str]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator] generator: Optional[torch.Generator]

View File

@ -176,8 +176,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.attn_groups: list[list[AttentionGroup]] = [] self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig # self.kv_cache_config: KVCacheConfig
# req_id -> (input_id -> encoder_output) # mm_hash -> encoder_output
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} self.encoder_cache: dict[str, torch.Tensor] = {}
self.use_aux_hidden_state_outputs = False self.use_aux_hidden_state_outputs = False
# Set up speculative decoding. # Set up speculative decoding.
@ -436,7 +436,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states. # Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None) self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch. # Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and # NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and # scheduled_req_ids overlap. This happens when a request is aborted and
@ -447,12 +446,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_batch.remove_request(req_id) self.input_batch.remove_request(req_id)
# Free the cached encoder outputs. # Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids: for mm_hash in scheduler_output.free_encoder_mm_hashes:
encoder_outputs = self.encoder_cache.get(req_id) self.encoder_cache.pop(mm_hash, None)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the unscheduled requests from the persistent batch. # Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests # NOTE(woosuk): The unscheduled requests are either preempted requests
@ -496,6 +491,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_token_ids=new_req_data.prompt_token_ids, prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs, mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions, mm_positions=new_req_data.mm_positions,
mm_hashes=new_req_data.mm_hashes,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=pooling_params, pooling_params=pooling_params,
generator=generator, generator=generator,
@ -1161,17 +1157,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs: if not scheduled_encoder_inputs:
return return
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]() mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]() # list of tuple (mm_hash, position_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append( mm_hashes_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id])) (mm_hash, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has # Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one, # multiple modalities or a different modality than the previous one,
@ -1204,15 +1201,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for output in curr_group_outputs: for output in curr_group_outputs:
encoder_outputs.append(output) encoder_outputs.append(output)
# Cache the encoder outputs. # Cache the encoder outputs by mm_hash
for (req_id, input_id, pos_info), output in zip( for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
req_ids_pos, self.encoder_cache[mm_hash] = scatter_mm_placeholders(
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
output, output,
is_embed=pos_info.is_embed, is_embed=pos_info.is_embed,
) )
@ -1230,6 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens = \ num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions mm_positions = req_state.mm_positions
mm_hashes = req_state.mm_hashes
for i, pos_info in enumerate(mm_positions): for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset start_pos = pos_info.offset
num_encoder_tokens = pos_info.length num_encoder_tokens = pos_info.length
@ -1249,11 +1241,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_idx = max(num_computed_tokens - start_pos, 0) start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min( end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens, num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens) num_encoder_tokens,
)
assert start_idx < end_idx assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id] mm_hash = mm_hashes[i]
encoder_output = self.encoder_cache[req_id][i] encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None: if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx] is_embed = is_embed[start_idx:end_idx]

View File

@ -208,8 +208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Lazy initialization # Lazy initialization
self.model: nn.Module # Set after load_model self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = [] self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output) # mm_hash -> encoder_output
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} self.encoder_cache: dict[str, torch.Tensor] = {}
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
@ -342,7 +342,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states. # Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None) self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch. # Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and # NOTE(woosuk): There could be an edge case where finished_req_ids and
@ -357,12 +356,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
removed_req_indices.append(req_index) removed_req_indices.append(req_index)
# Free the cached encoder outputs. # Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids: for mm_hash in scheduler_output.free_encoder_mm_hashes:
encoder_outputs = self.encoder_cache.get(req_id) self.encoder_cache.pop(mm_hash, None)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the unscheduled requests from the persistent batch. # Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests # NOTE(woosuk): The unscheduled requests are either preempted requests
@ -394,6 +389,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_token_ids=new_req_data.prompt_token_ids, prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs, mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions, mm_positions=new_req_data.mm_positions,
mm_hashes=new_req_data.mm_hashes,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
generator=None, generator=None,
@ -845,14 +841,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]() mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]() # List of tuple (mm_hash, pos_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append( mm_hashes_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id])) (mm_hash, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has # Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one, # multiple modalities or a different modality than the previous one,
@ -895,15 +893,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# NOTE (NickLucche) here we diverge from logic in other runners, as we # NOTE (NickLucche) here we diverge from logic in other runners, as we
# assume to only have whole mm items to process. Hence we avoid the # assume to only have whole mm items to process. Hence we avoid the
# intrinsic dynamism that `scatter_mm_placeholders` introduces. # intrinsic dynamism that `scatter_mm_placeholders` introduces.
for (req_id, input_id, pos_info), output in zip( for (mm_hash, pos_info), output in zip(
req_ids_pos, mm_hashes_pos,
encoder_outputs, encoder_outputs,
): ):
if req_id not in self.encoder_cache: if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {} self.encoder_cache[req_id] = {}
assert pos_info.is_embed is None, "Expected all positions to be"\ assert pos_info.is_embed is None, "Expected all positions to be"\
" contiguous and embeddings." " contiguous and embeddings."
self.encoder_cache[req_id][input_id] = output self.encoder_cache[mm_hash] = output
def _gather_mm_embeddings( def _gather_mm_embeddings(
self, self,
@ -916,6 +914,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions mm_positions = req_state.mm_positions
mm_hashes = req_state.mm_hashes
# TODO unroll loop and assume/enforce --disable_chunked_mm_input # TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) here we diverge from logic in other runners, as # NOTE (NickLucche) here we diverge from logic in other runners, as
# we assume to only have whole mm items to process. Hence we avoid # we assume to only have whole mm items to process. Hence we avoid
@ -936,11 +935,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the decoder's KV cache. # in the decoder's KV cache.
continue continue
assert req_id in self.encoder_cache mm_hash = mm_hashes[i]
assert i in self.encoder_cache[req_id] encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
assert pos_info.is_embed is None, "Expected all positions to"\ assert pos_info.is_embed is None, "Expected all positions to"\
" be contiguous and embeddings." " be contiguous and embeddings."
encoder_output = self.encoder_cache[req_id][i] encoder_output = self.encoder_cache[mm_hash]
mm_embeds.append(encoder_output) mm_embeds.append(encoder_output)
return mm_embeds return mm_embeds