From f5f51e5931ffd99afe69696b60765b88d3eb13f2 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 16 Dec 2025 14:18:17 -0800 Subject: [PATCH] [Core][MM] Optimize encoder cache manager by operating with embeddings only (#30475) Signed-off-by: Roger Wang Co-authored-by: Sun Kim --- .../multimodal/processing/test_mllama4.py | 4 +- tests/multimodal/test_utils.py | 92 +++++++++++++++++++ tests/v1/core/test_encoder_cache_manager.py | 79 +++++++++++++++- .../unit/test_ec_example_connector.py | 2 +- .../ec_connector/example_connector.py | 2 +- vllm/model_executor/models/qwen3_vl.py | 8 +- vllm/multimodal/inputs.py | 39 +++++++- vllm/multimodal/profiling.py | 32 ++----- vllm/multimodal/registry.py | 2 +- vllm/v1/core/encoder_cache_manager.py | 80 ++++++++-------- vllm/v1/core/sched/scheduler.py | 35 +++++-- vllm/v1/request.py | 6 +- vllm/v1/worker/gpu_model_runner.py | 49 +++------- vllm/v1/worker/utils.py | 6 ++ 14 files changed, 306 insertions(+), 130 deletions(-) diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py index e5ff2d1391b62..325159965c803 100644 --- a/tests/models/multimodal/processing/test_mllama4.py +++ b/tests/models/multimodal/processing/test_mllama4.py @@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int): total_num_patches.item() + num_tiles.item() + 3 ) # image start, image, image end - profiled_tokens = profiler.get_mm_max_contiguous_tokens( + profiled_tokens = profiler.get_mm_max_tokens( max_model_len, mm_counts=mm_counts, ) - assert total_tokens == profiled_tokens["image"] + assert total_num_patches == profiled_tokens["image"] assert total_tokens == sum( placeholder.length for placeholder in decoder_dummy_data.multi_modal_placeholders["image"] diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 636cd0ffd445e..02bb1f769baad 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory import numpy as np import pytest +import torch from PIL import Image, ImageChops from vllm.multimodal.image import convert_image_mode @@ -410,6 +411,97 @@ def test_argsort_mm_positions(case): assert modality_idxs == expected_modality_idxs +@pytest.mark.parametrize( + "is_embed,expected", + [ + (None, 5), + (torch.tensor([True, True, True, True, True]), 5), + (torch.tensor([False, False, False, False, False]), 0), + (torch.tensor([True, False, True, False, True]), 3), + (torch.tensor([True]), 1), + ], +) +def test_placeholder_range_get_num_embeds(is_embed, expected): + length = len(is_embed) if is_embed is not None else 5 + pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed) + assert pr.get_num_embeds == expected + + +@pytest.mark.parametrize( + "is_embed,expected", + [ + (None, None), + ( + torch.tensor([False, True, False, True, True]), + torch.tensor([0, 1, 1, 2, 3]), + ), + (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])), + ], +) +def test_placeholder_range_embeds_cumsum(is_embed, expected): + length = len(is_embed) if is_embed is not None else 5 + pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed) + + if expected is None: + assert pr.embeds_cumsum is None + return + + assert torch.equal(pr.embeds_cumsum, expected) + # cached_property should return the same object on repeated access + assert pr.embeds_cumsum is pr.embeds_cumsum + + +@pytest.mark.parametrize( + "is_embed,start_idx,end_idx,expected", + [ + (None, 2, 4, (2, 4)), + ( + torch.tensor([False, True, False, True, True]), + 3, + 5, + (1, 3), + ), + ( + torch.tensor([False, True, False, True, True]), + 0, + 2, + (0, 1), + ), + ( + torch.tensor([True, False, True, False]), + 2, + 2, + (1, 1), + ), + ], +) +def test_placeholder_range_get_embeds_indices_in_range( + is_embed, start_idx, end_idx, expected +): + length = len(is_embed) if is_embed is not None else 5 + pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed) + assert pr.get_embeds_indices_in_range(start_idx, end_idx) == expected + + +@pytest.mark.parametrize( + "offset,is_embed,expected", + [ + (0, None, [(0, 4)]), + ( + 2, + torch.tensor([False, True, False, True, True]), + [(3, 3), (5, 6)], + ), + (0, torch.tensor([True, True, True, True]), [(0, 3)]), + (0, torch.tensor([False, False, False, False]), []), + ], +) +def test_placeholder_range_extract_embeds_range(offset, is_embed, expected): + length = len(is_embed) if is_embed is not None else 5 + pr = PlaceholderRange(offset=offset, length=length, is_embed=is_embed) + assert pr.extract_embeds_range() == expected + + @pytest.mark.asyncio @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) @pytest.mark.parametrize("num_frames", [-1, 32, 1800]) diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py index 8a52b5bd78977..511ff48c401ca 100644 --- a/tests/v1/core/test_encoder_cache_manager.py +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange from vllm.v1.core.encoder_cache_manager import EncoderCacheManager @@ -23,7 +24,7 @@ class MockRequest: ) self.mm_features.append(feature) - def get_num_encoder_tokens(self, input_id: int) -> int: + def get_num_encoder_embeds(self, input_id: int) -> int: return self._token_counts[input_id] @@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit(): num_tokens_to_schedule = 0 assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) - num_tokens_to_schedule += req.get_num_encoder_tokens(0) - compute_budget -= req.get_num_encoder_tokens(0) + num_tokens_to_schedule += req.get_num_encoder_embeds(0) + compute_budget -= req.get_num_encoder_embeds(0) assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) @@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit(): compute_budget = 10 num_tokens_to_schedule = 0 assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) - num_tokens_to_schedule += req.get_num_encoder_tokens(0) - compute_budget -= req.get_num_encoder_tokens(0) + num_tokens_to_schedule += req.get_num_encoder_embeds(0) + compute_budget -= req.get_num_encoder_embeds(0) assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) + + +def test_encoder_cache_with_is_embed_mask(): + class MockRequestWithMask(MockRequest): + def get_num_encoder_embeds(self, input_id: int) -> int: + return self.mm_features[input_id].mm_position.get_num_embeds + + is_embed = torch.zeros(100, dtype=torch.bool) + is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True + + request = MockRequestWithMask("r1", ["img1"], [100]) + request.mm_features[0] = MultiModalFeatureSpec( + data=None, + modality="image", + identifier="img1", + mm_position=PlaceholderRange(offset=0, length=100, is_embed=is_embed), + ) + + manager = EncoderCacheManager(cache_size=100) + manager.allocate(request, 0) + + assert manager.num_free_slots == 92 + assert "img1" in manager.cached + + old_size = 100 + new_size = request.mm_features[0].mm_position.get_num_embeds + assert new_size == 8 + savings_ratio = old_size / new_size + assert savings_ratio == 12.5 + + +def test_encoder_cache_mask_based_retrieval(): + class MockRequestWithMask(MockRequest): + def get_num_encoder_embeds(self, input_id: int) -> int: + return self.mm_features[input_id].mm_position.get_num_embeds + + is_embed = torch.tensor( + [False, False, True, True, False, True, True, True, False, False] + ) + + request = MockRequestWithMask("r1", ["img1"], [10]) + request.mm_features[0] = MultiModalFeatureSpec( + data=None, + modality="image", + identifier="img1", + mm_position=PlaceholderRange(offset=0, length=10, is_embed=is_embed), + ) + + manager = EncoderCacheManager(cache_size=50) + manager.allocate(request, 0) + + assert request.mm_features[0].mm_position.get_num_embeds == 5 + + start_idx = 2 + end_idx = 8 + num_embeds_before = is_embed[:start_idx].sum().item() + num_embeds_in_range = is_embed[start_idx:end_idx].sum().item() + + assert num_embeds_before == 0 + assert num_embeds_in_range == 5 + + start_idx = 0 + end_idx = 5 + num_embeds_before = is_embed[:start_idx].sum().item() if start_idx > 0 else 0 + num_embeds_in_range = is_embed[start_idx:end_idx].sum().item() + + assert num_embeds_before == 0 + assert num_embeds_in_range == 2 diff --git a/tests/v1/ec_connector/unit/test_ec_example_connector.py b/tests/v1/ec_connector/unit/test_ec_example_connector.py index 7e9eb21310031..9ed82e1cef823 100644 --- a/tests/v1/ec_connector/unit/test_ec_example_connector.py +++ b/tests/v1/ec_connector/unit/test_ec_example_connector.py @@ -38,7 +38,7 @@ class MockRequest: ) self.mm_features.append(feature) - def get_num_encoder_tokens(self, input_id: int) -> int: + def get_num_encoder_embeds(self, input_id: int) -> int: assert input_id < len(self._token_counts) return self._token_counts[input_id] diff --git a/vllm/distributed/ec_transfer/ec_connector/example_connector.py b/vllm/distributed/ec_transfer/ec_connector/example_connector.py index 5f2eff5a8e6a8..c9aad9e9fc8f3 100644 --- a/vllm/distributed/ec_transfer/ec_connector/example_connector.py +++ b/vllm/distributed/ec_transfer/ec_connector/example_connector.py @@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase): Update ECConnector state after encoder cache allocation. """ mm_hash = request.mm_features[index].identifier - num_encoder_token = request.get_num_encoder_tokens(index) + num_encoder_token = request.get_num_encoder_embeds(index) # Insert mm_hash only if this block has not been recorded yet. self._mm_datas_need_loads[mm_hash] = num_encoder_token diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index c0589986d1fe8..4838f68e06f70 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): mm_counts: Mapping[str, int], ) -> int: target_width, target_height = self.get_image_size_with_most_features() - video_soft_tokens = self.get_num_video_tokens( + num_video_soft_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) - - # NOTE: By default in Qwen3-VL, one video token is converted to - # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501 - formatted_video_soft_tokens = video_soft_tokens * 12.5 - return int(formatted_video_soft_tokens) + return num_video_soft_tokens def _calculate_timestamps( self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 6b1cbbe24e2e7..fa69818a7b1f8 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import UserDict, defaultdict from collections.abc import Mapping, Sequence from dataclasses import dataclass -from functools import partial +from functools import cached_property, partial from itertools import accumulate from typing import ( TYPE_CHECKING, @@ -169,11 +169,42 @@ class PlaceholderRange: between `offset` and `offset + length` to assign embeddings to. """ - def get_num_embeds(self) -> int: + @cached_property + def embeds_cumsum(self) -> torch.Tensor | None: if self.is_embed is None: + return None + + return self.is_embed.cumsum(dim=0) + + @cached_property + def get_num_embeds(self) -> int: + if self.embeds_cumsum is None: return self.length - return int(self.is_embed.sum().item()) + return int(self.embeds_cumsum[-1]) + + def get_embeds_indices_in_range( + self, start_idx: int, end_idx: int + ) -> tuple[int, int]: + """ + Returns the starting and ending indices of the embeddings of encoder outputs + in the range of [start_idx, end_idx) in the placeholders. + + For example, given: + PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True]) + + If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get + the second and the third embeddings from the encoder output. + """ + if self.embeds_cumsum is None: + return start_idx, end_idx + + embeds_start_idx = ( + int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0 + ) + embeds_end_idx = int(self.embeds_cumsum[end_idx - 1]) + + return embeds_start_idx, embeds_end_idx def extract_embeds_range(self) -> list[tuple[int, int]]: """Extract the start and end indices of the embedded region in prompt. @@ -188,7 +219,7 @@ class PlaceholderRange: Returns full placeholder range if `is_embed` is `None`. """ if self.is_embed is None: - return [(self.offset, self.offset + self.length)] + return [(self.offset, self.offset + self.length - 1)] mask_i = self.is_embed.int() starts = torch.nonzero( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index cb70041e9744f..a690948f759e9 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]): def _get_mm_num_tokens( self, mm_inputs: MultiModalInputs, - mm_embeddings_only: bool = True, ) -> Mapping[str, int]: placeholders_by_modality = mm_inputs["mm_placeholders"] return { - modality: sum( - item.get_num_embeds() if mm_embeddings_only else item.length - for item in placeholders - ) + modality: sum(item.get_num_embeds for item in placeholders) for modality, placeholders in placeholders_by_modality.items() } @@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]): multi_modal_placeholders=mm_inputs["mm_placeholders"], ) - def _get_mm_max_tokens( + def get_mm_max_tokens( self, seq_len: int, mm_counts: Mapping[str, int] | None = None, - mm_embeddings_only: bool = True, ) -> Mapping[str, int]: + """ + Returns the maximum number of embeddings per item of each modality, excluding + any break/text tokens in-between multimodal embeddings/encoder outputs. + """ if mm_counts is None: mm_counts = self.get_mm_limits() @@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]): } mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) - return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only) - - def get_mm_max_contiguous_tokens( - self, - seq_len: int, - mm_counts: Mapping[str, int] | None = None, - ) -> Mapping[str, int]: - """ - Returns the maximum length of the multimodal (image placeholders+text) - tokens, including any break/text tokens in-between image embeddings. - - ` [IMG] [IMG] [IMG] [IMG] [IMG] [IMG] ` - Returns 9, even when the number of image embeddings is 6. - - This is important to take into account when profiling and - initializing the encoder cache size. - """ - return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False) + return self._get_mm_num_tokens(mm_inputs) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 00a84f9dec4f7..1e7fe8648ab71 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -164,7 +164,7 @@ class MultiModalRegistry: profiler.get_mm_limits() if profiler_limits is None else profiler_limits ) - return profiler.get_mm_max_contiguous_tokens( + return profiler.get_mm_max_tokens( seq_len, {modality: 1 for modality, limit in profiler_limits.items() if limit > 0}, ) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 50f738713590b..d73c05d2cf80b 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -39,20 +39,26 @@ class EncoderCacheManager: space for new embeddings. Oldest cached embeddings with no request referenced will be first evicted. + NOTE: The EncoderCacheManager operates on the level of multimodal embeddings + instead of encoder tokens (i.e. all tokens that represent the multimodal data + in the input sequence). This means all break/text tokens in-between multimodal + embeddings are not considered with respect to the cache size and the number + of free slots. + Args: cache_size: Limit the size of the cache, measured by the number of - tokens from the input sequence. + encoder embeddings from the input sequence. Attributes: - cache_size: Total cache capacity in encoder tokens. - num_free_slots: Current available cache capacity in encoder tokens. + cache_size: Total cache capacity in encoder embeddings. + num_free_slots: Current available cache capacity in encoder embeddings. num_freeable_slots: Capacity that can be immediately reclaimed by - evicting entries with zero references (in encoder tokens). + evicting entries with zero references (in encoder embeddings). cached: Mapping from mm_hash to a set of request IDs that currently reference the cached entry. If the set is empty, the entry exists but is not referenced by any request and is eligible for reclamation. - freeable: List of tuples (mm_hash, num_tokens) representing entries + freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries whose no current running request is needed and that can be freed to make space when needed. freed: List of mm_hash strings that were actually evicted since the @@ -67,7 +73,7 @@ class EncoderCacheManager: # mm_hash of mm_data => ids of requests that reference the mm_data self.cached: dict[str, set[str]] = {} - # mm_hash of mm_data => num_encoder_tokens of the mm_data + # mm_hash of mm_data => num_encoder_embeds of the mm_data self.freeable: OrderedDict[str, int] = OrderedDict() self.freed: list[str] = [] @@ -93,8 +99,8 @@ class EncoderCacheManager: # Cached but currently not referenced by any request if not self.cached[mm_hash]: - num_tokens = self.freeable.pop(mm_hash) - self.num_freeable_slots -= num_tokens + num_encoder_embeds = self.freeable.pop(mm_hash) + self.num_freeable_slots -= num_encoder_embeds self.cached[mm_hash].add(request.request_id) return True @@ -104,7 +110,7 @@ class EncoderCacheManager: request: Request, input_id: int, encoder_compute_budget: int, - num_tokens_to_schedule: int, + num_embeds_to_schedule: int, ) -> bool: """Check if there's sufficient cache space for a multimodal input. If there is, return True and update EncoderCacheManager state. @@ -121,9 +127,9 @@ class EncoderCacheManager: Args: request: The request containing the multimodal input. input_id: Index of the multimodal input within the request. - encoder_compute_budget: Number of encoder tokens allowed to be + encoder_compute_budget: Number of encoder embeddings allowed to be computed when this method is invoked. - num_tokens_to_schedule: Number of tokens already scheduled to be + num_embeds_to_schedule: Number of encoder embeddings already scheduled to be allocated with cache space when this method is invoked. Returns: @@ -134,30 +140,30 @@ class EncoderCacheManager: Note: This method does not allocate physical memory for the encoder output but only the state of EncoderCacheManager. """ - num_tokens = request.get_num_encoder_tokens(input_id) + num_embeds = request.get_num_encoder_embeds(input_id) # Not enough compute budget - if num_tokens > encoder_compute_budget: + if num_embeds > encoder_compute_budget: return False - num_tokens += num_tokens_to_schedule + num_embeds += num_embeds_to_schedule # Enough free slots - if num_tokens <= self.num_free_slots: + if num_embeds <= self.num_free_slots: return True # Not enough reclaimable slots - if num_tokens > self.num_freeable_slots: + if num_embeds > self.num_freeable_slots: return False # Not enough free slots but enough reclaimable slots # NOTE: Eviction takes place here, but physical memory is not freed # until model runner is notified by the scheduler output. - while num_tokens > self.num_free_slots: - mm_hash, num_free_token = self.freeable.popitem(last=False) + while num_embeds > self.num_free_slots: + mm_hash, num_free_embeds = self.freeable.popitem(last=False) del self.cached[mm_hash] self.freed.append(mm_hash) - self.num_free_slots += num_free_token + self.num_free_slots += num_free_embeds return True def allocate(self, request: Request, input_id: int) -> None: @@ -176,16 +182,16 @@ class EncoderCacheManager: if mm_hash not in self.cached: self.cached[mm_hash] = set() - num_encoder_tokens = request.get_num_encoder_tokens(input_id) + num_encoder_embeds = request.get_num_encoder_embeds(input_id) # NOTE: Encoder cache should always have enough space for encoder inputs # that are scheduled since eviction takes place at can_allocate(). - assert self.num_free_slots >= num_encoder_tokens - assert self.num_freeable_slots >= num_encoder_tokens + assert self.num_free_slots >= num_encoder_embeds + assert self.num_freeable_slots >= num_encoder_embeds self.cached[mm_hash].add(request_id) - self.num_free_slots -= num_encoder_tokens - self.num_freeable_slots -= num_encoder_tokens + self.num_free_slots -= num_encoder_embeds + self.num_freeable_slots -= num_encoder_embeds def get_cached_input_ids(self, request: Request) -> set[int]: """Get all cached multimodal input IDs for a request. @@ -206,7 +212,7 @@ class EncoderCacheManager: When the reference set for the corresponding `mm_hash` becomes empty, the entry is appended to `freeable` and `num_freeable_slots` is - increased by the number of encoder tokens for that input. + increased by the number of encoder embeddings for that input. The entry is NOT physically freed until capacity is needed (e.g., by `can_allocate`). @@ -218,9 +224,9 @@ class EncoderCacheManager: return self.cached[mm_hash].discard(req_id) if not self.cached[mm_hash]: - num_tokens = request.get_num_encoder_tokens(input_id) - self.freeable[mm_hash] = num_tokens - self.num_freeable_slots += num_tokens + num_encoder_embeds = request.get_num_encoder_embeds(input_id) + self.freeable[mm_hash] = num_encoder_embeds + self.num_freeable_slots += num_encoder_embeds def free(self, request: Request) -> None: """Free all encoder input cache reference held by *request*. @@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager): request: Request, input_id: int, encoder_compute_budget: int, - num_tokens_to_schedule: int, + num_embeds_to_schedule: int, ) -> bool: - num_tokens = request.get_num_encoder_tokens(input_id) + num_encoder_embeds = request.get_num_encoder_embeds(input_id) # Not enough compute budget - if num_tokens > encoder_compute_budget: + if num_encoder_embeds > encoder_compute_budget: return False - num_tokens += num_tokens_to_schedule + num_encoder_embeds += num_embeds_to_schedule # Enough free slots - return num_tokens <= self.num_free_slots + return num_encoder_embeds <= self.num_free_slots def allocate(self, request: Request, input_id: int) -> None: - num_encoder_tokens = request.get_num_encoder_tokens(input_id) - self.num_free_slots -= num_encoder_tokens + num_encoder_embeds = request.get_num_encoder_embeds(input_id) + self.num_free_slots -= num_encoder_embeds mm_hash = request.mm_features[input_id].identifier self.freed.append(mm_hash) @@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager): return freed def free_encoder_input(self, request: Request, input_id: int) -> None: - num_tokens = request.get_num_encoder_tokens(input_id) - self.num_free_slots += num_tokens + num_encoder_embeds = request.get_num_encoder_embeds(input_id) + self.num_free_slots += num_encoder_embeds diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 754e0b9d08316..8e835ad096405 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -355,11 +355,11 @@ class Scheduler(SchedulerInterface): if preempted_encoder_inputs: # Restore encoder compute budget if the preempted # request had encoder inputs scheduled in this step. - num_tokens_to_restore = sum( - preempted_req.get_num_encoder_tokens(i) + num_embeds_to_restore = sum( + preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs ) - encoder_compute_budget += num_tokens_to_restore + encoder_compute_budget += num_embeds_to_restore req_index -= 1 else: preempted_req = self.running.pop() @@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface): # multiple encoder inputs per request), we need to create temporary # trackers for accounting at the encoder input level. mm_hashes_to_schedule = set() - num_tokens_to_schedule = 0 + num_embeds_to_schedule = 0 for i, mm_feature in enumerate(mm_features): start_pos = mm_feature.mm_position.offset num_encoder_tokens = mm_feature.mm_position.length + num_encoder_embeds = mm_feature.mm_position.get_num_embeds # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, num_computed_tokens + num_new_tokens) and @@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface): ): num_new_tokens = start_pos - num_computed_tokens break - if not self.encoder_cache_manager.can_allocate( - request, i, encoder_compute_budget, num_tokens_to_schedule + request, i, encoder_compute_budget, num_embeds_to_schedule ): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should @@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface): num_new_tokens = 0 break + # Calculate the number of embeddings to schedule in the current range + # of scheduled encoder placholder tokens. + start_idx_rel = max(0, num_computed_tokens - start_pos) + end_idx_rel = min( + num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos + ) + curr_embeds_start, curr_embeds_end = ( + mm_feature.mm_position.get_embeds_indices_in_range( + start_idx_rel, + end_idx_rel, + ) + ) + # There's no embeddings in the current range of encoder placeholder tokens + # so we can skip the encoder input. + if curr_embeds_end - curr_embeds_start == 0: + continue + if self.ec_connector is not None and remote_cache_has_item[i]: mm_hashes_to_schedule.add(request.mm_features[i].identifier) external_load_encoder_input.append(i) - num_tokens_to_schedule += num_encoder_tokens + num_embeds_to_schedule += num_encoder_embeds continue - num_tokens_to_schedule += num_encoder_tokens - encoder_compute_budget -= num_encoder_tokens + num_embeds_to_schedule += num_encoder_embeds + encoder_compute_budget -= num_encoder_embeds mm_hashes_to_schedule.add(request.mm_features[i].identifier) encoder_inputs_to_schedule.append(i) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index a775e840e841c..f33059b80b894 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -209,10 +209,10 @@ class Request: def get_finished_reason(self) -> FinishReason | None: return RequestStatus.get_finished_reason(self.status) - def get_num_encoder_tokens(self, input_id: int) -> int: + def get_num_encoder_embeds(self, input_id: int) -> int: assert input_id < len(self.mm_features) - num_tokens = self.mm_features[input_id].mm_position.length - return num_tokens + num_embeds = self.mm_features[input_id].mm_position.get_num_embeds + return num_embeds def record_event( self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 179f713c4d86a..1db5bc99fff6c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -169,9 +169,7 @@ from .utils import ( MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, - gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders, ) if TYPE_CHECKING: @@ -2209,10 +2207,7 @@ class GPUModelRunner( # Cache the encoder outputs by mm_hash for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): - self.encoder_cache[mm_hash] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) + self.encoder_cache[mm_hash] = output logger.debug("Finish execute for mm hash %s", mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) @@ -2263,6 +2258,13 @@ class GPUModelRunner( num_encoder_tokens, ) assert start_idx < end_idx + curr_embeds_start, curr_embeds_end = ( + pos_info.get_embeds_indices_in_range(start_idx, end_idx) + ) + # If there are no embeddings in the current range, we skip + # gathering the embeddings. + if curr_embeds_start == curr_embeds_end: + continue mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) @@ -2270,16 +2272,14 @@ class GPUModelRunner( if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] + mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end] + else: + mm_embeds_item = encoder_output[start_idx:end_idx] req_start_pos = req_start_idx + start_pos - num_computed_tokens is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( True if is_embed is None else is_embed ) - - mm_embeds_item = gather_mm_placeholders( - encoder_output[start_idx:end_idx], - is_embed=is_embed, - ) mm_embeds_req.append(mm_embeds_item) if self.is_multimodal_pruning_enabled and self.uses_mrope: @@ -4508,31 +4508,8 @@ class GPUModelRunner( dummy_encoder_outputs, expected_num_items=max_mm_items_per_batch, ) - - # NOTE: This happens when encoder cache needs to store - # the embeddings that encoder outputs are scattered onto. - # In this case we create dummy embeddings of size - # (max_tokens_for_modality, hidden_size) and scatter - # encoder output into it. - encoder_output_shape = dummy_encoder_outputs[0].shape - max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[ - dummy_modality - ] - if encoder_output_shape[0] < max_mm_tokens_per_item: - encoder_hidden_size = encoder_output_shape[-1] - expanded_outputs = [] - for output in dummy_encoder_outputs: - expanded = output.new_zeros( - (max_mm_tokens_per_item, encoder_hidden_size) - ) - num_tokens = output.shape[0] - expanded[:num_tokens].copy_(output) - expanded_outputs.append(expanded) - - dummy_encoder_outputs = expanded_outputs - - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + for i, output in enumerate(dummy_encoder_outputs): + self.encoder_cache[f"tmp_{i}"] = output # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states = self._dummy_run( diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index e9c48223d58b9..2e8afec024ce9 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -4,10 +4,12 @@ from collections import defaultdict from dataclasses import dataclass, field import torch +from typing_extensions import deprecated from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.config import ModelConfig, SchedulerConfig, VllmConfig +from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.cache import processor_only_cache_from_config @@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec +logger = init_logger(__name__) + class MultiModalBudget: """Helper class to calculate budget information for multi-modal models.""" @@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs( ) +@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.") def scatter_mm_placeholders( embeds: torch.Tensor, is_embed: torch.Tensor | None, @@ -226,6 +231,7 @@ def scatter_mm_placeholders( return placeholders +@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.") def gather_mm_placeholders( placeholders: torch.Tensor, is_embed: torch.Tensor | None,