mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 12:27:52 +08:00
[Core][Multimodal] Track encode cache entries by mm_hash and enable embedding sharing between requests (#22711)
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
712d0f88d8
commit
d765cf01fe
144
tests/v1/core/test_encoder_cache_manager.py
Normal file
144
tests/v1/core/test_encoder_cache_manager.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------ Mock Classes ------------------ #
|
||||||
|
class MockRequest:
|
||||||
|
|
||||||
|
def __init__(self, request_id, mm_hashes, token_counts):
|
||||||
|
self.request_id = request_id
|
||||||
|
self.mm_hashes = mm_hashes
|
||||||
|
self._token_counts = token_counts
|
||||||
|
|
||||||
|
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||||
|
return self._token_counts[input_id]
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------ Unit Tests ------------------ #
|
||||||
|
def test_basic_allocate_and_reuse():
|
||||||
|
cache = EncoderCacheManager(cache_size=10)
|
||||||
|
req = MockRequest("r1", ["imgA"], [4])
|
||||||
|
|
||||||
|
assert not cache.check_and_update_cache(req, 0)
|
||||||
|
assert cache.try_allocate(req, 0, int(1e9))
|
||||||
|
|
||||||
|
cache.allocate(req, 0)
|
||||||
|
|
||||||
|
assert cache.check_and_update_cache(req, 0)
|
||||||
|
assert "r1" in cache.cached["imgA"]
|
||||||
|
assert cache.num_free_slots == 6
|
||||||
|
|
||||||
|
# Free twice to bring refcount to 0.
|
||||||
|
cache.free_encoder_input(req, 0)
|
||||||
|
cache.free_encoder_input(req, 0)
|
||||||
|
|
||||||
|
assert not cache.cached["imgA"]
|
||||||
|
assert "imgA" in cache.freeable
|
||||||
|
assert cache.num_freeable_slots == 10
|
||||||
|
assert cache.num_free_slots == 6
|
||||||
|
|
||||||
|
|
||||||
|
def test_freeing_decreases_refcount_and_moves_to_freeable():
|
||||||
|
manager = EncoderCacheManager(cache_size=10)
|
||||||
|
req = MockRequest("req2", ["img3"], [5])
|
||||||
|
|
||||||
|
assert manager.try_allocate(req, 0, int(1e9))
|
||||||
|
manager.allocate(req, 0)
|
||||||
|
|
||||||
|
assert len(manager.cached["img3"]) == 1
|
||||||
|
|
||||||
|
manager.free_encoder_input(req, 0)
|
||||||
|
|
||||||
|
assert not manager.cached["img3"]
|
||||||
|
assert "img3" in manager.freeable
|
||||||
|
assert manager.num_freeable_slots == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_free_request_frees_all_inputs():
|
||||||
|
manager = EncoderCacheManager(cache_size=10)
|
||||||
|
req = MockRequest("req3", ["a", "b"], [2, 3])
|
||||||
|
|
||||||
|
assert manager.try_allocate(req, 0, int(1e9))
|
||||||
|
manager.allocate(req, 0)
|
||||||
|
|
||||||
|
assert manager.try_allocate(req, 1, int(1e9))
|
||||||
|
manager.allocate(req, 1)
|
||||||
|
|
||||||
|
assert len(manager.cached["a"]) == 1
|
||||||
|
assert len(manager.cached["b"]) == 1
|
||||||
|
|
||||||
|
manager.free(req)
|
||||||
|
|
||||||
|
assert not manager.cached["a"]
|
||||||
|
assert not manager.cached["b"]
|
||||||
|
assert "a" in manager.freeable
|
||||||
|
assert "b" in manager.freeable
|
||||||
|
assert manager.num_freeable_slots == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_eviction_when_cache_is_full():
|
||||||
|
manager = EncoderCacheManager(cache_size=10)
|
||||||
|
|
||||||
|
req1 = MockRequest("req1", ["x"], [6])
|
||||||
|
req2 = MockRequest("req2", ["y"], [5])
|
||||||
|
|
||||||
|
assert manager.try_allocate(req1, 0, int(1e9))
|
||||||
|
manager.allocate(req1, 0)
|
||||||
|
manager.free_encoder_input(req1, 0)
|
||||||
|
|
||||||
|
assert manager.try_allocate(req2, 0, int(1e9))
|
||||||
|
manager.allocate(req2, 0)
|
||||||
|
|
||||||
|
# 'x' should have been evicted.
|
||||||
|
assert "x" not in manager.cached
|
||||||
|
assert "x" in manager.get_freed_mm_hashes()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cached_input_ids():
|
||||||
|
manager = EncoderCacheManager(cache_size=10)
|
||||||
|
req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3])
|
||||||
|
|
||||||
|
assert manager.try_allocate(req, 0, int(1e9))
|
||||||
|
manager.allocate(req, 0)
|
||||||
|
|
||||||
|
assert manager.try_allocate(req, 2, int(1e9))
|
||||||
|
manager.allocate(req, 2)
|
||||||
|
|
||||||
|
cached_ids = manager.get_cached_input_ids(req)
|
||||||
|
assert cached_ids == {0, 2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_cache_restores_from_freeable():
|
||||||
|
manager = EncoderCacheManager(cache_size=10)
|
||||||
|
req = MockRequest("reqY", ["imgZ"], [4])
|
||||||
|
|
||||||
|
assert manager.try_allocate(req, 0, int(1e9))
|
||||||
|
manager.allocate(req, 0)
|
||||||
|
|
||||||
|
manager.free_encoder_input(req, 0)
|
||||||
|
|
||||||
|
# Should restore from freeable.
|
||||||
|
assert manager.check_and_update_cache(req, 0)
|
||||||
|
assert len(manager.cached["imgZ"]) == 1
|
||||||
|
assert "imgZ" not in manager.freeable
|
||||||
|
assert manager.num_freeable_slots == 6
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_freed_mm_hashes_clears_freed_list():
|
||||||
|
manager = EncoderCacheManager(cache_size=10)
|
||||||
|
req1 = MockRequest("reqA", ["a"], [5])
|
||||||
|
req2 = MockRequest("reqB", ["b"], [6])
|
||||||
|
|
||||||
|
assert manager.try_allocate(req1, 0, int(1e9))
|
||||||
|
manager.allocate(req1, 0)
|
||||||
|
manager.free_encoder_input(req1, 0)
|
||||||
|
|
||||||
|
# Should trigger eviction of 'a'.
|
||||||
|
assert manager.try_allocate(req2, 0, int(1e9))
|
||||||
|
manager.allocate(req2, 0)
|
||||||
|
|
||||||
|
freed = manager.get_freed_mm_hashes()
|
||||||
|
assert "a" in freed
|
||||||
|
assert manager.get_freed_mm_hashes() == []
|
||||||
@ -338,7 +338,7 @@ def test_stop_via_update_from_output():
|
|||||||
},
|
},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None)
|
grammar_bitmask=None)
|
||||||
|
|
||||||
@ -391,7 +391,7 @@ def test_stop_via_update_from_output():
|
|||||||
},
|
},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -443,7 +443,7 @@ def test_stop_via_update_from_output():
|
|||||||
},
|
},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -490,7 +490,7 @@ def test_stop_via_update_from_output():
|
|||||||
},
|
},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None)
|
grammar_bitmask=None)
|
||||||
|
|
||||||
|
|||||||
@ -143,7 +143,11 @@ def create_requests(
|
|||||||
mm_position = mm_positions[i]
|
mm_position = mm_positions[i]
|
||||||
mm_item = MultiModalKwargsItem.dummy("dummy_m")
|
mm_item = MultiModalKwargsItem.dummy("dummy_m")
|
||||||
mm_kwargs = [mm_item] * len(mm_position)
|
mm_kwargs = [mm_item] * len(mm_position)
|
||||||
mm_hashes = ["hash"] * len(mm_position)
|
# Dummy hash for each mm item should be unique
|
||||||
|
# since encoder cache tracks entries by hash
|
||||||
|
mm_hashes = [
|
||||||
|
"hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
mm_position = None
|
mm_position = None
|
||||||
mm_kwargs = None
|
mm_kwargs = None
|
||||||
|
|||||||
@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids={req_id},
|
finished_req_ids={req_id},
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int):
|
|||||||
pooling_params=None,
|
pooling_params=None,
|
||||||
mm_kwargs=[],
|
mm_kwargs=[],
|
||||||
mm_positions=[],
|
mm_positions=[],
|
||||||
|
mm_hashes=[],
|
||||||
block_ids=([], ),
|
block_ids=([], ),
|
||||||
generator=None,
|
generator=None,
|
||||||
num_computed_tokens=len(output_token_ids),
|
num_computed_tokens=len(output_token_ids),
|
||||||
|
|||||||
@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids={req_id},
|
finished_req_ids={req_id},
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -239,7 +239,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
|
|||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_mm_hashes=[],
|
||||||
structured_output_request_ids={},
|
structured_output_request_ids={},
|
||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@ -31,34 +33,52 @@ class EncoderCacheManager:
|
|||||||
within requests, allowing for fine-grained memory management and enabling
|
within requests, allowing for fine-grained memory management and enabling
|
||||||
chunked processing of multimodal inputs.
|
chunked processing of multimodal inputs.
|
||||||
|
|
||||||
Note that no caching is shared between requests at this time. If the same
|
Cache is enabled to share embeddings of same multimodal data
|
||||||
input is used across multiple requests, it will be reprocessed for each
|
item (identified by their hash value) between different requests,
|
||||||
request.
|
and eviction takes place at allocation time when there's no free
|
||||||
|
space for new embeddings.
|
||||||
|
Oldest cached embeddings with no request referenced will be first evicted.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache_size: Limit the size of the cache, measured by the number of
|
cache_size: Limit the size of the cache, measured by the number of
|
||||||
tokens from the input sequence.
|
tokens from the input sequence.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
cache_size: Total cache capacity in encoder tokens
|
cache_size: Total cache capacity in encoder tokens.
|
||||||
num_free_slots: Current available cache capacity in encoder tokens
|
num_free_slots: Current available cache capacity in encoder tokens.
|
||||||
cached: Mapping from request_id to set of cached input_ids for that
|
num_freeable_slots: Capacity that can be immediately reclaimed by
|
||||||
request
|
evicting entries with zero references (in encoder tokens).
|
||||||
freed: List of (request_id, input_id) pairs that were recently freed.
|
cached: Mapping from mm_hash to a set of request IDs that currently
|
||||||
This is cleared after every call to get_freed_ids().
|
reference the cached entry. If the set is empty, the entry exists
|
||||||
|
but is not referenced by any request and is eligible for
|
||||||
|
reclamation.
|
||||||
|
freeable: List of tuples (mm_hash, num_tokens) representing entries
|
||||||
|
whose no current running request is needed and that can be freed to
|
||||||
|
make space when needed.
|
||||||
|
freed: List of mm_hash strings that were actually evicted since the
|
||||||
|
last call to get_freed_mm_hashes(). This list is cleared on return.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cache_size: int):
|
def __init__(self, cache_size: int):
|
||||||
self.cache_size = cache_size
|
self.cache_size = cache_size
|
||||||
self.num_free_slots = cache_size
|
self.num_free_slots = cache_size
|
||||||
# req_id -> cached input ids
|
self.num_freeable_slots = cache_size
|
||||||
self.cached: dict[str, set[int]] = {}
|
|
||||||
# list of [req_id, input_id]
|
|
||||||
self.freed: list[tuple[str, int]] = []
|
|
||||||
|
|
||||||
def has_cache(self, request: Request, input_id: int) -> bool:
|
# mm_hash of mm_data => ids of requests that reference the mm_data
|
||||||
|
self.cached: dict[str, set[str]] = {}
|
||||||
|
|
||||||
|
# mm_hash of mm_data => num_encoder_tokens of the mm_data
|
||||||
|
self.freeable: OrderedDict[str, int] = OrderedDict()
|
||||||
|
self.freed: list[str] = []
|
||||||
|
|
||||||
|
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
|
||||||
"""Check if encoder output for a specific multimodal input is cached.
|
"""Check if encoder output for a specific multimodal input is cached.
|
||||||
|
|
||||||
|
If the encoder output is cached, update `cached` to add the request id
|
||||||
|
to the set of request ids that reference the cached encoder output.
|
||||||
|
If the encoder output was previously not referenced by any request,
|
||||||
|
update `freeable` and `num_freeable_slots` accordingly.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The request containing the multimodal input
|
request: The request containing the multimodal input
|
||||||
input_id: Index of the multimodal input within the request
|
input_id: Index of the multimodal input within the request
|
||||||
@ -66,103 +86,151 @@ class EncoderCacheManager:
|
|||||||
Returns:
|
Returns:
|
||||||
True if the encoder output for this input is already cached
|
True if the encoder output for this input is already cached
|
||||||
"""
|
"""
|
||||||
req_id = request.request_id
|
mm_hash = request.mm_hashes[input_id]
|
||||||
return req_id in self.cached and input_id in self.cached[req_id]
|
# Not cached at all
|
||||||
|
if mm_hash not in self.cached:
|
||||||
|
return False
|
||||||
|
|
||||||
def can_allocate(self, request: Request, input_id: int) -> bool:
|
# Cached but currently not referenced by any request
|
||||||
|
if not self.cached[mm_hash]:
|
||||||
|
num_tokens = self.freeable.pop(mm_hash)
|
||||||
|
self.num_freeable_slots -= num_tokens
|
||||||
|
|
||||||
|
self.cached[mm_hash].add(request.request_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def try_allocate(self, request: Request, input_id: int,
|
||||||
|
encoder_budget: int) -> bool:
|
||||||
"""Check if there's sufficient cache space for a multimodal input.
|
"""Check if there's sufficient cache space for a multimodal input.
|
||||||
|
If there is, return True and update EncoderCacheManager state.
|
||||||
|
|
||||||
|
If there is not enough free space in `num_free_slots` but there is
|
||||||
|
enough reclaimable space in `num_freeable_slots`, entries will be
|
||||||
|
evicted from `freeable` (their mm_hash appended to `freed`) until
|
||||||
|
enough space is available, and then this method returns True.
|
||||||
|
Older entries are evicted first.
|
||||||
|
|
||||||
|
Returns False only if the requested number of tokens exceeds both
|
||||||
|
the free and reclaimable capacities combined.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The request containing the multimodal input
|
request: The request containing the multimodal input.
|
||||||
input_id: Index of the multimodal input within the request
|
input_id: Index of the multimodal input within the request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if there's enough free cache space to store the encoder output
|
True if there's enough capacity to hold the encoder output for this
|
||||||
for this multimodal input
|
input (possibly after reclaiming `freeable` entries); otherwise
|
||||||
|
False.
|
||||||
|
|
||||||
|
Note: This method does not allocate physical memory for the encoder
|
||||||
|
output but only the state of EncoderCacheManager.
|
||||||
"""
|
"""
|
||||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||||
return num_tokens <= self.num_free_slots
|
|
||||||
|
# Not enough compute budget
|
||||||
|
if num_tokens > encoder_budget:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Enough free slots
|
||||||
|
if num_tokens <= self.num_free_slots:
|
||||||
|
self.num_free_slots -= num_tokens
|
||||||
|
self.num_freeable_slots -= num_tokens
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Not enough reclaimable slots
|
||||||
|
if num_tokens > self.num_freeable_slots:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Not enough free slots but enough reclaimable slots
|
||||||
|
# NOTE: Eviction takes place here, but physical memory is not freed
|
||||||
|
# until model runner is notified by the scheduler output.
|
||||||
|
while num_tokens > self.num_free_slots:
|
||||||
|
mm_hash, num_free_token = self.freeable.popitem(last=False)
|
||||||
|
del self.cached[mm_hash]
|
||||||
|
self.freed.append(mm_hash)
|
||||||
|
self.num_free_slots += num_free_token
|
||||||
|
self.num_free_slots -= num_tokens
|
||||||
|
self.num_freeable_slots -= num_tokens
|
||||||
|
return True
|
||||||
|
|
||||||
def allocate(self, request: Request, input_id: int) -> None:
|
def allocate(self, request: Request, input_id: int) -> None:
|
||||||
"""Allocate cache space for a multimodal input's encoder output.
|
"""Allocate cache space for a multimodal input's encoder output.
|
||||||
|
|
||||||
This method reserves cache space for storing the encoder output of
|
This reserves cache space for storing the encoder output of the
|
||||||
the specified multimodal input. The actual encoder output storage
|
specified multimodal input. The actual encoder output storage happens in
|
||||||
happens in the model runner, but this method ensures the cache
|
the model runner; this method updates the manager's bookkeeping.
|
||||||
manager tracks the allocation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The request containing the multimodal input
|
|
||||||
input_id: Index of the multimodal input within the request
|
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
This method assumes can_allocate() returned True for the same
|
This method assumes try_allocate() returned True for the same input.
|
||||||
request and input_id. It will reduce available cache space.
|
|
||||||
"""
|
"""
|
||||||
req_id = request.request_id
|
# Encoder cache space budget should be already updated for the
|
||||||
if req_id not in self.cached:
|
# multimodal input and non-negative after try_allocate() is called.
|
||||||
self.cached[req_id] = set()
|
assert self.num_free_slots >= 0
|
||||||
self.cached[req_id].add(input_id)
|
assert self.num_freeable_slots >= 0
|
||||||
self.num_free_slots -= request.get_num_encoder_tokens(input_id)
|
|
||||||
|
mm_hash = request.mm_hashes[input_id]
|
||||||
|
request_id = request.request_id
|
||||||
|
if mm_hash not in self.cached:
|
||||||
|
self.cached[mm_hash] = set()
|
||||||
|
|
||||||
|
self.cached[mm_hash].add(request_id)
|
||||||
|
|
||||||
def get_cached_input_ids(self, request: Request) -> set[int]:
|
def get_cached_input_ids(self, request: Request) -> set[int]:
|
||||||
"""Get all cached multimodal input IDs for a request.
|
"""Get all cached multimodal input IDs for a request.
|
||||||
|
|
||||||
Args:
|
Returns the set of input IDs whose `mm_hash` exists in the cache map.
|
||||||
request: The request to query
|
This includes entries that are currently unreferenced (and thus present
|
||||||
|
in `freeable`); for such entries, freeing for this request will be a
|
||||||
Returns:
|
no-op.
|
||||||
Set of input_ids that have cached encoder outputs for this request.
|
|
||||||
Returns empty set if no inputs are cached for this request.
|
|
||||||
"""
|
"""
|
||||||
return self.cached.get(request.request_id, set())
|
return {
|
||||||
|
input_id
|
||||||
|
for input_id in range(len(request.mm_hashes))
|
||||||
|
if request.mm_hashes[input_id] in self.cached
|
||||||
|
}
|
||||||
|
|
||||||
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||||
"""Free cache space for a single multimodal input's encoder output.
|
"""Free the request's reference to the encoder input (`mm_data`)
|
||||||
|
|
||||||
This method is called when:
|
When the reference set for the corresponding `mm_hash` becomes empty,
|
||||||
- The encoder output has been fully consumed by the decoder and is
|
the entry is appended to `freeable` and `num_freeable_slots` is
|
||||||
no longer needed (e.g., in vision-language models after image
|
increased by the number of encoder tokens for that input.
|
||||||
tokens are processed)
|
|
||||||
- A request is being cancelled or aborted
|
|
||||||
|
|
||||||
Args:
|
The entry is NOT physically freed until capacity is needed (e.g., by
|
||||||
request: The request containing the multimodal input
|
`can_allocate`).
|
||||||
input_id: Index of the multimodal input to free from cache
|
|
||||||
"""
|
"""
|
||||||
req_id = request.request_id
|
req_id = request.request_id
|
||||||
if req_id not in self.cached:
|
mm_hash = request.mm_hashes[input_id]
|
||||||
|
# The mm_hash not in cache or the req_id set is empty
|
||||||
|
if not self.cached.get(mm_hash, None):
|
||||||
return
|
return
|
||||||
|
self.cached[mm_hash].discard(req_id)
|
||||||
self.cached[req_id].discard(input_id)
|
if not self.cached[mm_hash]:
|
||||||
if len(self.cached[req_id]) == 0:
|
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||||
del self.cached[req_id]
|
self.freeable[mm_hash] = num_tokens
|
||||||
self.num_free_slots += request.get_num_encoder_tokens(input_id)
|
self.num_freeable_slots += num_tokens
|
||||||
self.freed.append((req_id, input_id))
|
|
||||||
|
|
||||||
def free(self, request: Request) -> None:
|
def free(self, request: Request) -> None:
|
||||||
"""Free all cached encoder outputs for a request.
|
"""Free all encoder input cache reference held by *request*.
|
||||||
|
|
||||||
This method is typically called when a request is finished, cancelled,
|
For each cached input ID, `free_encoder_input` is invoked.
|
||||||
or aborted, and all its encoder outputs should be freed from cache.
|
The data stays in memory until eviction is triggered by a future
|
||||||
|
attempt allocation called by 'can_allocate'.
|
||||||
|
|
||||||
Args:
|
Typically called when a request is finished, cancelled, or aborted.
|
||||||
request: The request whose encoder outputs should be freed
|
|
||||||
"""
|
"""
|
||||||
input_ids = self.get_cached_input_ids(request).copy()
|
input_ids = self.get_cached_input_ids(request).copy()
|
||||||
for input_id in input_ids:
|
for input_id in input_ids:
|
||||||
self.free_encoder_input(request, input_id)
|
self.free_encoder_input(request, input_id)
|
||||||
|
|
||||||
def get_freed_ids(self) -> list[tuple[str, int]]:
|
def get_freed_mm_hashes(self) -> list[str]:
|
||||||
"""Get and clear the list of recently freed encoder cache entries.
|
"""Get and clear the list of recently freed encoder cache entries.
|
||||||
|
|
||||||
This method returns all encoder cache entries that were freed since
|
|
||||||
the last call to this method. It's used by the scheduler to notify
|
|
||||||
workers about which encoder outputs can be removed from their caches.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (request_id, input_id) tuples that were freed since the
|
List of mm_hash strings that were actually evicted since the last
|
||||||
last call. The internal freed list is cleared after this call.
|
call to be used by the scheduler to notify workers about which
|
||||||
|
encoder outputs can be removed from their caches. The internal
|
||||||
|
list is cleared after this call.
|
||||||
"""
|
"""
|
||||||
freed = self.freed
|
freed = self.freed
|
||||||
self.freed = []
|
self.freed = []
|
||||||
@ -177,16 +245,11 @@ def compute_encoder_budget(
|
|||||||
"""Compute the encoder cache budget based on the model and scheduler
|
"""Compute the encoder cache budget based on the model and scheduler
|
||||||
configurations.
|
configurations.
|
||||||
|
|
||||||
Args:
|
|
||||||
model_config: Model configuration.
|
|
||||||
scheduler_config: Scheduler configuration.
|
|
||||||
mm_registry: Provides information about the token cost.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- Compute budget for encoder execution, in unit of number of tokens
|
- Compute budget for encoder execution, measured in number of tokens
|
||||||
in the input sequence.
|
from the input sequence.
|
||||||
- Space budget for encoder cache size, in unit of number of tokens
|
- Space budget for encoder cache size, measured in number of tokens
|
||||||
in the input sequence.
|
from the input sequence.
|
||||||
"""
|
"""
|
||||||
if mm_registry.supports_multimodal_inputs(model_config):
|
if mm_registry.supports_multimodal_inputs(model_config):
|
||||||
max_tokens_by_modality = mm_registry \
|
max_tokens_by_modality = mm_registry \
|
||||||
@ -231,10 +294,10 @@ def compute_mm_encoder_budget(
|
|||||||
non-text modality.
|
non-text modality.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- Compute budget for encoder execution, in unit of number of tokens
|
- Compute budget for encoder execution, measured in number of tokens
|
||||||
in the input sequence.
|
from the input sequence.
|
||||||
- Space budget for encoder cache size, in unit of number of tokens
|
- Space budget for encoder cache size, measured in number of tokens
|
||||||
in the input sequence.
|
from the input sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not max_tokens_by_modality:
|
if not max_tokens_by_modality:
|
||||||
|
|||||||
@ -143,9 +143,9 @@ class SchedulerOutput:
|
|||||||
# steps. This is used to notify the workers about the finished requests
|
# steps. This is used to notify the workers about the finished requests
|
||||||
# so that they can free the cached states for those requests.
|
# so that they can free the cached states for those requests.
|
||||||
finished_req_ids: set[str]
|
finished_req_ids: set[str]
|
||||||
# list of (req_id, encoder_input_index) tuples.
|
# list of mm_hash strings associated with the encoder outputs to be
|
||||||
# Used to free the encoder cache.
|
# freed from the encoder cache.
|
||||||
free_encoder_input_ids: list[tuple[str, int]]
|
free_encoder_mm_hashes: list[str]
|
||||||
|
|
||||||
# Dict of request ids to their index within the batch
|
# Dict of request ids to their index within the batch
|
||||||
# for filling the next token bitmask
|
# for filling the next token bitmask
|
||||||
|
|||||||
@ -252,6 +252,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
preempted_req = self.running.pop()
|
preempted_req = self.running.pop()
|
||||||
|
|
||||||
self.kv_cache_manager.free(preempted_req)
|
self.kv_cache_manager.free(preempted_req)
|
||||||
|
self.encoder_cache_manager.free(preempted_req)
|
||||||
preempted_req.status = RequestStatus.PREEMPTED
|
preempted_req.status = RequestStatus.PREEMPTED
|
||||||
preempted_req.num_computed_tokens = 0
|
preempted_req.num_computed_tokens = 0
|
||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
@ -550,7 +551,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
# It contains the request IDs that are finished in between
|
# It contains the request IDs that are finished in between
|
||||||
# the previous and the current steps.
|
# the previous and the current steps.
|
||||||
finished_req_ids=self.finished_req_ids,
|
finished_req_ids=self.finished_req_ids,
|
||||||
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
|
free_encoder_mm_hashes=self.encoder_cache_manager.
|
||||||
|
get_freed_mm_hashes(),
|
||||||
structured_output_request_ids=structured_output_request_ids,
|
structured_output_request_ids=structured_output_request_ids,
|
||||||
grammar_bitmask=grammar_bitmask,
|
grammar_bitmask=grammar_bitmask,
|
||||||
)
|
)
|
||||||
@ -698,7 +700,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
# in the decoder's KV cache.
|
# in the decoder's KV cache.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self.encoder_cache_manager.has_cache(request, i):
|
if self.encoder_cache_manager.check_and_update_cache(request, i):
|
||||||
# The encoder input is already computed and cached.
|
# The encoder input is already computed and cached.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -712,8 +714,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
num_new_tokens = start_pos - num_computed_tokens
|
num_new_tokens = start_pos - num_computed_tokens
|
||||||
break
|
break
|
||||||
|
|
||||||
if (not self.encoder_cache_manager.can_allocate(request, i)
|
if not self.encoder_cache_manager.try_allocate(
|
||||||
or num_encoder_tokens > encoder_budget):
|
request, i, encoder_budget):
|
||||||
# The encoder cache is full or the encoder budget is exhausted.
|
# The encoder cache is full or the encoder budget is exhausted.
|
||||||
# NOTE(woosuk): We assume that the encoder input tokens should
|
# NOTE(woosuk): We assume that the encoder input tokens should
|
||||||
# be processed altogether, as the encoder usually uses
|
# be processed altogether, as the encoder usually uses
|
||||||
|
|||||||
@ -33,6 +33,7 @@ class CachedRequestState:
|
|||||||
prompt_token_ids: list[int]
|
prompt_token_ids: list[int]
|
||||||
mm_kwargs: list[MultiModalKwargsItem]
|
mm_kwargs: list[MultiModalKwargsItem]
|
||||||
mm_positions: list[PlaceholderRange]
|
mm_positions: list[PlaceholderRange]
|
||||||
|
mm_hashes: list[str]
|
||||||
sampling_params: Optional[SamplingParams]
|
sampling_params: Optional[SamplingParams]
|
||||||
pooling_params: Optional[PoolingParams]
|
pooling_params: Optional[PoolingParams]
|
||||||
generator: Optional[torch.Generator]
|
generator: Optional[torch.Generator]
|
||||||
|
|||||||
@ -176,8 +176,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.attn_groups: list[list[AttentionGroup]] = []
|
self.attn_groups: list[list[AttentionGroup]] = []
|
||||||
# self.kv_cache_config: KVCacheConfig
|
# self.kv_cache_config: KVCacheConfig
|
||||||
|
|
||||||
# req_id -> (input_id -> encoder_output)
|
# mm_hash -> encoder_output
|
||||||
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
self.use_aux_hidden_state_outputs = False
|
self.use_aux_hidden_state_outputs = False
|
||||||
# Set up speculative decoding.
|
# Set up speculative decoding.
|
||||||
@ -436,7 +436,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Remove finished requests from the cached states.
|
# Remove finished requests from the cached states.
|
||||||
for req_id in scheduler_output.finished_req_ids:
|
for req_id in scheduler_output.finished_req_ids:
|
||||||
self.requests.pop(req_id, None)
|
self.requests.pop(req_id, None)
|
||||||
self.encoder_cache.pop(req_id, None)
|
|
||||||
# Remove the finished requests from the persistent batch.
|
# Remove the finished requests from the persistent batch.
|
||||||
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
||||||
# scheduled_req_ids overlap. This happens when a request is aborted and
|
# scheduled_req_ids overlap. This happens when a request is aborted and
|
||||||
@ -447,12 +446,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.input_batch.remove_request(req_id)
|
self.input_batch.remove_request(req_id)
|
||||||
|
|
||||||
# Free the cached encoder outputs.
|
# Free the cached encoder outputs.
|
||||||
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
||||||
encoder_outputs = self.encoder_cache.get(req_id)
|
self.encoder_cache.pop(mm_hash, None)
|
||||||
if encoder_outputs is not None:
|
|
||||||
encoder_outputs.pop(input_id, None)
|
|
||||||
if not encoder_outputs:
|
|
||||||
self.encoder_cache.pop(req_id, None)
|
|
||||||
|
|
||||||
# Remove the unscheduled requests from the persistent batch.
|
# Remove the unscheduled requests from the persistent batch.
|
||||||
# NOTE(woosuk): The unscheduled requests are either preempted requests
|
# NOTE(woosuk): The unscheduled requests are either preempted requests
|
||||||
@ -496,6 +491,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||||
mm_kwargs=new_req_data.mm_kwargs,
|
mm_kwargs=new_req_data.mm_kwargs,
|
||||||
mm_positions=new_req_data.mm_positions,
|
mm_positions=new_req_data.mm_positions,
|
||||||
|
mm_hashes=new_req_data.mm_hashes,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
pooling_params=pooling_params,
|
pooling_params=pooling_params,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
@ -1161,17 +1157,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||||
if not scheduled_encoder_inputs:
|
if not scheduled_encoder_inputs:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Batch the multi-modal inputs.
|
# Batch the multi-modal inputs.
|
||||||
mm_kwargs = list[MultiModalKwargsItem]()
|
mm_kwargs = list[MultiModalKwargsItem]()
|
||||||
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
# list of tuple (mm_hash, position_info)
|
||||||
|
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
|
||||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
for mm_input_id in encoder_input_ids:
|
for mm_input_id in encoder_input_ids:
|
||||||
|
mm_hash = req_state.mm_hashes[mm_input_id]
|
||||||
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
|
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
|
||||||
req_ids_pos.append(
|
mm_hashes_pos.append(
|
||||||
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
|
(mm_hash, req_state.mm_positions[mm_input_id]))
|
||||||
|
|
||||||
# Batch mm inputs as much as we can: if a request in the batch has
|
# Batch mm inputs as much as we can: if a request in the batch has
|
||||||
# multiple modalities or a different modality than the previous one,
|
# multiple modalities or a different modality than the previous one,
|
||||||
@ -1204,15 +1201,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
for output in curr_group_outputs:
|
for output in curr_group_outputs:
|
||||||
encoder_outputs.append(output)
|
encoder_outputs.append(output)
|
||||||
|
|
||||||
# Cache the encoder outputs.
|
# Cache the encoder outputs by mm_hash
|
||||||
for (req_id, input_id, pos_info), output in zip(
|
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
||||||
req_ids_pos,
|
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
|
||||||
encoder_outputs,
|
|
||||||
):
|
|
||||||
if req_id not in self.encoder_cache:
|
|
||||||
self.encoder_cache[req_id] = {}
|
|
||||||
|
|
||||||
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
|
|
||||||
output,
|
output,
|
||||||
is_embed=pos_info.is_embed,
|
is_embed=pos_info.is_embed,
|
||||||
)
|
)
|
||||||
@ -1230,6 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
num_computed_tokens = \
|
num_computed_tokens = \
|
||||||
req_state.num_computed_tokens + shift_computed_tokens
|
req_state.num_computed_tokens + shift_computed_tokens
|
||||||
mm_positions = req_state.mm_positions
|
mm_positions = req_state.mm_positions
|
||||||
|
mm_hashes = req_state.mm_hashes
|
||||||
for i, pos_info in enumerate(mm_positions):
|
for i, pos_info in enumerate(mm_positions):
|
||||||
start_pos = pos_info.offset
|
start_pos = pos_info.offset
|
||||||
num_encoder_tokens = pos_info.length
|
num_encoder_tokens = pos_info.length
|
||||||
@ -1249,11 +1241,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
start_idx = max(num_computed_tokens - start_pos, 0)
|
start_idx = max(num_computed_tokens - start_pos, 0)
|
||||||
end_idx = min(
|
end_idx = min(
|
||||||
num_computed_tokens - start_pos + num_scheduled_tokens,
|
num_computed_tokens - start_pos + num_scheduled_tokens,
|
||||||
num_encoder_tokens)
|
num_encoder_tokens,
|
||||||
|
)
|
||||||
assert start_idx < end_idx
|
assert start_idx < end_idx
|
||||||
assert req_id in self.encoder_cache
|
|
||||||
assert i in self.encoder_cache[req_id]
|
mm_hash = mm_hashes[i]
|
||||||
encoder_output = self.encoder_cache[req_id][i]
|
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||||
|
assert encoder_output is not None,\
|
||||||
|
f"Encoder cache miss for {mm_hash}."
|
||||||
|
|
||||||
if (is_embed := pos_info.is_embed) is not None:
|
if (is_embed := pos_info.is_embed) is not None:
|
||||||
is_embed = is_embed[start_idx:end_idx]
|
is_embed = is_embed[start_idx:end_idx]
|
||||||
|
|||||||
@ -208,8 +208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Lazy initialization
|
# Lazy initialization
|
||||||
self.model: nn.Module # Set after load_model
|
self.model: nn.Module # Set after load_model
|
||||||
self.kv_caches: list[torch.Tensor] = []
|
self.kv_caches: list[torch.Tensor] = []
|
||||||
# req_id -> (input_id -> encoder_output)
|
# mm_hash -> encoder_output
|
||||||
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
# Request states.
|
# Request states.
|
||||||
self.requests: dict[str, CachedRequestState] = {}
|
self.requests: dict[str, CachedRequestState] = {}
|
||||||
@ -342,7 +342,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Remove finished requests from the cached states.
|
# Remove finished requests from the cached states.
|
||||||
for req_id in scheduler_output.finished_req_ids:
|
for req_id in scheduler_output.finished_req_ids:
|
||||||
self.requests.pop(req_id, None)
|
self.requests.pop(req_id, None)
|
||||||
self.encoder_cache.pop(req_id, None)
|
|
||||||
|
|
||||||
# Remove the finished requests from the persistent batch.
|
# Remove the finished requests from the persistent batch.
|
||||||
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
||||||
@ -357,12 +356,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
removed_req_indices.append(req_index)
|
removed_req_indices.append(req_index)
|
||||||
|
|
||||||
# Free the cached encoder outputs.
|
# Free the cached encoder outputs.
|
||||||
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
||||||
encoder_outputs = self.encoder_cache.get(req_id)
|
self.encoder_cache.pop(mm_hash, None)
|
||||||
if encoder_outputs is not None:
|
|
||||||
encoder_outputs.pop(input_id, None)
|
|
||||||
if not encoder_outputs:
|
|
||||||
self.encoder_cache.pop(req_id, None)
|
|
||||||
|
|
||||||
# Remove the unscheduled requests from the persistent batch.
|
# Remove the unscheduled requests from the persistent batch.
|
||||||
# NOTE(woosuk): The unscheduled requests are either preempted requests
|
# NOTE(woosuk): The unscheduled requests are either preempted requests
|
||||||
@ -394,6 +389,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||||
mm_kwargs=new_req_data.mm_kwargs,
|
mm_kwargs=new_req_data.mm_kwargs,
|
||||||
mm_positions=new_req_data.mm_positions,
|
mm_positions=new_req_data.mm_positions,
|
||||||
|
mm_hashes=new_req_data.mm_hashes,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
pooling_params=None,
|
pooling_params=None,
|
||||||
generator=None,
|
generator=None,
|
||||||
@ -845,14 +841,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
# Batch the multi-modal inputs.
|
# Batch the multi-modal inputs.
|
||||||
mm_kwargs = list[MultiModalKwargsItem]()
|
mm_kwargs = list[MultiModalKwargsItem]()
|
||||||
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
# List of tuple (mm_hash, pos_info)
|
||||||
|
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
|
||||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
for mm_input_id in encoder_input_ids:
|
for mm_input_id in encoder_input_ids:
|
||||||
|
mm_hash = req_state.mm_hashes[mm_input_id]
|
||||||
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
|
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
|
||||||
req_ids_pos.append(
|
mm_hashes_pos.append(
|
||||||
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
|
(mm_hash, req_state.mm_positions[mm_input_id]))
|
||||||
|
|
||||||
# Batch mm inputs as much as we can: if a request in the batch has
|
# Batch mm inputs as much as we can: if a request in the batch has
|
||||||
# multiple modalities or a different modality than the previous one,
|
# multiple modalities or a different modality than the previous one,
|
||||||
@ -895,15 +893,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# NOTE (NickLucche) here we diverge from logic in other runners, as we
|
# NOTE (NickLucche) here we diverge from logic in other runners, as we
|
||||||
# assume to only have whole mm items to process. Hence we avoid the
|
# assume to only have whole mm items to process. Hence we avoid the
|
||||||
# intrinsic dynamism that `scatter_mm_placeholders` introduces.
|
# intrinsic dynamism that `scatter_mm_placeholders` introduces.
|
||||||
for (req_id, input_id, pos_info), output in zip(
|
for (mm_hash, pos_info), output in zip(
|
||||||
req_ids_pos,
|
mm_hashes_pos,
|
||||||
encoder_outputs,
|
encoder_outputs,
|
||||||
):
|
):
|
||||||
if req_id not in self.encoder_cache:
|
if req_id not in self.encoder_cache:
|
||||||
self.encoder_cache[req_id] = {}
|
self.encoder_cache[req_id] = {}
|
||||||
assert pos_info.is_embed is None, "Expected all positions to be"\
|
assert pos_info.is_embed is None, "Expected all positions to be"\
|
||||||
" contiguous and embeddings."
|
" contiguous and embeddings."
|
||||||
self.encoder_cache[req_id][input_id] = output
|
self.encoder_cache[mm_hash] = output
|
||||||
|
|
||||||
def _gather_mm_embeddings(
|
def _gather_mm_embeddings(
|
||||||
self,
|
self,
|
||||||
@ -916,6 +914,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
num_computed_tokens = req_state.num_computed_tokens
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
mm_positions = req_state.mm_positions
|
mm_positions = req_state.mm_positions
|
||||||
|
mm_hashes = req_state.mm_hashes
|
||||||
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
|
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
|
||||||
# NOTE (NickLucche) here we diverge from logic in other runners, as
|
# NOTE (NickLucche) here we diverge from logic in other runners, as
|
||||||
# we assume to only have whole mm items to process. Hence we avoid
|
# we assume to only have whole mm items to process. Hence we avoid
|
||||||
@ -936,11 +935,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# in the decoder's KV cache.
|
# in the decoder's KV cache.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assert req_id in self.encoder_cache
|
mm_hash = mm_hashes[i]
|
||||||
assert i in self.encoder_cache[req_id]
|
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||||
|
assert encoder_output is not None,\
|
||||||
|
f"Encoder cache miss for {mm_hash}."
|
||||||
assert pos_info.is_embed is None, "Expected all positions to"\
|
assert pos_info.is_embed is None, "Expected all positions to"\
|
||||||
" be contiguous and embeddings."
|
" be contiguous and embeddings."
|
||||||
encoder_output = self.encoder_cache[req_id][i]
|
encoder_output = self.encoder_cache[mm_hash]
|
||||||
mm_embeds.append(encoder_output)
|
mm_embeds.append(encoder_output)
|
||||||
return mm_embeds
|
return mm_embeds
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user