diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py
index ae5b751f45a4b..4e3cace86be6a 100644
--- a/tests/v1/core/test_encoder_cache_manager.py
+++ b/tests/v1/core/test_encoder_cache_manager.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
 from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 
 
@@ -9,8 +10,17 @@ class MockRequest:
 
     def __init__(self, request_id, mm_hashes, token_counts):
         self.request_id = request_id
-        self.mm_hashes = mm_hashes
         self._token_counts = token_counts
+        self.mm_features = []
+        for i, mm_hash in enumerate(mm_hashes):
+            feature = MultiModalFeatureSpec(
+                data=None,
+                modality="image",
+                identifier=mm_hash,
+                mm_position=PlaceholderRange(offset=0,
+                                             length=self._token_counts[i]),
+            )
+            self.mm_features.append(feature)
 
     def get_num_encoder_tokens(self, input_id: int) -> int:
         return self._token_counts[input_id]
diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index 941aa0a77692c..c719e44acc9c2 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -64,9 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             NewRequestData(
                 req_id=req_id,
                 prompt_token_ids=[1, 2, 3],
-                mm_kwargs=[],
-                mm_hashes=[],
-                mm_positions=[],
+                mm_features=[],
                 sampling_params=SamplingParams(),
                 pooling_params=PoolingParams(),
                 block_ids=([0], ),  # block_ids should be tuple[list[int]]
diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index 7031859078264..38f543c784866 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -203,9 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         prompt_token_ids=prompt_token_ids,
         sampling_params=_create_sampling_params(),
         pooling_params=None,
-        mm_kwargs=[],
-        mm_positions=[],
-        mm_hashes=[],
+        mm_features=[],
         block_ids=([], ),
         generator=None,
         num_computed_tokens=len(output_token_ids),
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 8e27c14b7405b..5ebc00d573030 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -118,9 +118,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             NewRequestData(
                 req_id=req_id,
                 prompt_token_ids=[1, 2, 3],
-                mm_kwargs=[],
-                mm_hashes=[],
-                mm_positions=[],
+                mm_features=[],
                 sampling_params=SamplingParams(),
                 pooling_params=None,
                 block_ids=([0], ),
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index 9f4da613f8b62..25460ed295d30 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -300,11 +300,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
         total_need_load = 0
         for new_req in scheduler_output.scheduled_new_reqs:
             if new_req.req_id in self._requests_need_load:
-                meta.add_request(token_ids=new_req.prompt_token_ids,
-                                 block_ids=new_req.block_ids[0],
-                                 block_size=self._block_size,
-                                 is_store=False,
-                                 mm_hashes=new_req.mm_hashes)
+                meta.add_request(
+                    token_ids=new_req.prompt_token_ids,
+                    block_ids=new_req.block_ids[0],
+                    block_size=self._block_size,
+                    is_store=False,
+                    mm_hashes=[f.identifier for f in new_req.mm_features])
                 total_need_load += 1
             else:
                 # NOTE: here, we set the store and load being exclusive,
@@ -312,11 +313,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 # NOTE(rob): for this debug implementation, we only cache
                 # the original prompt tokens.
                 if not self._found_match_for_request(new_req):
-                    meta.add_request(token_ids=new_req.prompt_token_ids,
-                                     block_ids=new_req.block_ids[0],
-                                     block_size=self._block_size,
-                                     is_store=True,
-                                     mm_hashes=new_req.mm_hashes)
+                    meta.add_request(
+                        token_ids=new_req.prompt_token_ids,
+                        block_ids=new_req.block_ids[0],
+                        block_size=self._block_size,
+                        is_store=True,
+                        mm_hashes=[f.identifier for f in new_req.mm_features])
 
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
@@ -341,11 +343,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 # of the block_ids for the request.
                 block_ids = new_block_ids[0]
 
-                meta.add_request(token_ids=token_ids,
-                                 block_ids=block_ids,
-                                 block_size=self._block_size,
-                                 is_store=False,
-                                 mm_hashes=request.mm_hashes)
+                meta.add_request(
+                    token_ids=token_ids,
+                    block_ids=block_ids,
+                    block_size=self._block_size,
+                    is_store=False,
+                    mm_hashes=[f.identifier for f in request.mm_features])
                 total_need_load += 1
 
         assert total_need_load == len(self._requests_need_load)
@@ -364,10 +367,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
         """
        num_tokens_to_check = align_to_block_size(
            len(request.prompt_token_ids) - 1, self._block_size)
-        foldername = self._generate_foldername_debug(torch.tensor(
-            request.prompt_token_ids)[:num_tokens_to_check],
-                                                     request.mm_hashes,
-                                                     create_folder=False)
+        foldername = self._generate_foldername_debug(
+            torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
+            [f.identifier for f in request.mm_features],
+            create_folder=False)
         return os.path.exists(foldername)
 
     def _generate_foldername_debug(
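Side note on the pattern introduced above (an illustration, not part of the diff): a single `MultiModalFeatureSpec` now bundles what used to live in three parallel lists, and plain hash lists are derived on demand, as `SharedStorageConnector` does with `[f.identifier for f in ...]`. A minimal sketch using only the constructor keywords already shown in the updated `MockRequest`; the hash strings, offsets, and token count are made up for illustration:

```python
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange

# One spec per multimodal input. `data` may be None when only the metadata
# (hash + placeholder position) matters, exactly as in MockRequest above.
mm_features = [
    MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier=mm_hash,
        mm_position=PlaceholderRange(offset=offset, length=576),
    )
    for offset, mm_hash in [(0, "hash-a"), (600, "hash-b")]
]

# Derived view used by the connector instead of the old request.mm_hashes.
assert [f.identifier for f in mm_features] == ["hash-a", "hash-b"]
```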
diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py
index bd2ec036834b2..eadea15a2e5e3 100644
--- a/vllm/v1/core/encoder_cache_manager.py
+++ b/vllm/v1/core/encoder_cache_manager.py
@@ -86,7 +86,7 @@ class EncoderCacheManager:
         Returns:
             True if the encoder output for this input is already cached
         """
-        mm_hash = request.mm_hashes[input_id]
+        mm_hash = request.mm_features[input_id].identifier
         # Not cached at all
         if mm_hash not in self.cached:
             return False
@@ -167,7 +167,7 @@ class EncoderCacheManager:
 
         This method assumes can_allocate() returned True for the same input.
         """
-        mm_hash = request.mm_hashes[input_id]
+        mm_hash = request.mm_features[input_id].identifier
         request_id = request.request_id
         if mm_hash not in self.cached:
             self.cached[mm_hash] = set()
@@ -193,8 +193,8 @@ class EncoderCacheManager:
         """
         return {
             input_id
-            for input_id in range(len(request.mm_hashes))
-            if request.mm_hashes[input_id] in self.cached
+            for input_id in range(len(request.mm_features))
+            if request.mm_features[input_id].identifier in self.cached
         }
 
     def free_encoder_input(self, request: Request, input_id: int) -> None:
@@ -208,7 +208,7 @@
             `can_allocate`).
         """
         req_id = request.request_id
-        mm_hash = request.mm_hashes[input_id]
+        mm_hash = request.mm_features[input_id].identifier
         # The mm_hash not in cache or the req_id set is empty
         if not self.cached.get(mm_hash, None):
             return
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 2c0eac3ddd79d..f939da8c5b5c3 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -418,9 +418,9 @@ def need_extra_keys(request: Request) -> bool:
     # Multimodal requests need to include the MM hash.
     # LoRA requests need to include the LoRA ID.
     # Request with provided cache salt need to include the salt.
-    return bool(request.mm_hashes) or (request.lora_request
-                                       is not None) or (request.cache_salt
-                                                        is not None)
+    return bool(request.mm_features) or (request.lora_request
+                                         is not None) or (request.cache_salt
+                                                          is not None)
 
 
 def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
@@ -442,32 +442,28 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
     """
     extra_keys: list[Any] = []
 
-    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
-    if not mm_positions:
+    mm_features = request.mm_features
+    if not mm_features:
         return extra_keys, start_mm_idx
 
-    if mm_positions and len(mm_positions) != len(mm_hashes):
-        raise ValueError(
-            "The number of multi-modal positions and hashes must match. This "
-            "is likely because you did not enable MM hashing. "
-            "Please set `mm_processor_cache_gb > 0`.")
-
-    # Note that we assume mm_positions is sorted by offset.
+    # Note that we assume mm_features are sorted by mm_position.offset.
     # We do not need to check all mm inputs if the start token index is out of
     # range. This usually happens in the late prefill phase and decoding phase.
-    if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
+    last_pos = mm_features[-1].mm_position
+    if last_pos.offset + last_pos.length < start_token_idx:
         return extra_keys, start_mm_idx
 
     # Support start_mm_idx == -1 to indicate the last mm input.
     if start_mm_idx < 0:
-        assert -start_mm_idx <= len(mm_positions)
-        start_mm_idx = len(mm_positions) + start_mm_idx
+        assert -start_mm_idx <= len(mm_features)
+        start_mm_idx = len(mm_features) + start_mm_idx
 
     curr_mm_idx = start_mm_idx
-    while mm_positions and curr_mm_idx < len(mm_positions):
-        assert mm_hashes[curr_mm_idx] is not None
-        offset = mm_positions[curr_mm_idx].offset
-        length = mm_positions[curr_mm_idx].length
+    while mm_features and curr_mm_idx < len(mm_features):
+        mm_feature = mm_features[curr_mm_idx]
+        assert mm_feature.identifier is not None
+        offset = mm_feature.mm_position.offset
+        length = mm_feature.mm_position.length
         if end_token_idx > offset:
             if start_token_idx > offset + length:
                 # This block has passed the current mm input.
@@ -475,7 +471,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
                 continue
 
             # The block contains the current mm input.
-            extra_keys.append(mm_hashes[curr_mm_idx])
+            extra_keys.append(mm_feature.identifier)
 
             if end_token_idx >= offset + length:
                 # If this block contains the end of the current mm input,
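For readers tracing the `_gen_mm_extra_hash_keys` rewrite above: the walk only needs each feature's `identifier` plus its placeholder `(offset, length)`, and it relies on `mm_features` being sorted by `mm_position.offset`. A simplified, self-contained sketch of the overlap logic, with plain tuples standing in for `MultiModalFeatureSpec` (this is not the vLLM implementation, just the shape of the loop):

```python
def extra_keys_for_block(features, start_token_idx, end_token_idx, start_mm_idx):
    """features: list of (identifier, offset, length) tuples sorted by offset."""
    extra_keys = []
    curr = start_mm_idx
    while curr < len(features):
        identifier, offset, length = features[curr]
        if end_token_idx <= offset:
            break  # Block ends before this mm input begins.
        if start_token_idx > offset + length:
            curr += 1  # Block lies entirely past this mm input.
            continue
        extra_keys.append(identifier)  # Block overlaps this mm input.
        if end_token_idx >= offset + length:
            curr += 1  # Block also covers the end of this mm input.
        else:
            break  # The mm input continues into the next block.
    return extra_keys, curr


# A block spanning tokens [8, 24) against one 16-token image at offset 10:
print(extra_keys_for_block([("img-hash", 10, 16)], 8, 24, 0))
```

The printed result is `(['img-hash'], 0)`: the block overlaps the image, and because the image continues into the next block, the index is not advanced.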
""" req_id = request.request_id - mm_hash = request.mm_hashes[input_id] + mm_hash = request.mm_features[input_id].identifier # The mm_hash not in cache or the req_id set is empty if not self.cached.get(mm_hash, None): return diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2c0eac3ddd79d..f939da8c5b5c3 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -418,9 +418,9 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. # Request with provided cache salt need to include the salt. - return bool(request.mm_hashes) or (request.lora_request - is not None) or (request.cache_salt - is not None) + return bool(request.mm_features) or (request.lora_request + is not None) or (request.cache_salt + is not None) def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, @@ -442,32 +442,28 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, """ extra_keys: list[Any] = [] - mm_positions, mm_hashes = request.mm_positions, request.mm_hashes - if not mm_positions: + mm_features = request.mm_features + if not mm_features: return extra_keys, start_mm_idx - if mm_positions and len(mm_positions) != len(mm_hashes): - raise ValueError( - "The number of multi-modal positions and hashes must match. This " - "is likely because you did not enable MM hashing. " - "Please set `mm_processor_cache_gb > 0`.") - - # Note that we assume mm_positions is sorted by offset. + # Note that we assume mm_features are sorted by mm_position.offset. # We do not need to check all mm inputs if the start token index is out of # range. This usually happens in the late prefill phase and decoding phase. - if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx: + last_pos = mm_features[-1].mm_position + if last_pos.offset + last_pos.length < start_token_idx: return extra_keys, start_mm_idx # Support start_mm_idx == -1 to indicate the last mm input. if start_mm_idx < 0: - assert -start_mm_idx <= len(mm_positions) - start_mm_idx = len(mm_positions) + start_mm_idx + assert -start_mm_idx <= len(mm_features) + start_mm_idx = len(mm_features) + start_mm_idx curr_mm_idx = start_mm_idx - while mm_positions and curr_mm_idx < len(mm_positions): - assert mm_hashes[curr_mm_idx] is not None - offset = mm_positions[curr_mm_idx].offset - length = mm_positions[curr_mm_idx].length + while mm_features and curr_mm_idx < len(mm_features): + mm_feature = mm_features[curr_mm_idx] + assert mm_feature.identifier is not None + offset = mm_feature.mm_position.offset + length = mm_feature.mm_position.length if end_token_idx > offset: if start_token_idx > offset + length: # This block has passed the current mm input. @@ -475,7 +471,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, continue # The block contains the current mm input. 
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index aa45f6669207d..c1e59423e9a18 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -736,18 +736,18 @@ class Scheduler(SchedulerInterface):
         if num_new_tokens == 0 or not request.has_encoder_inputs:
             return [], num_new_tokens, encoder_compute_budget
         encoder_inputs_to_schedule: list[int] = []
-        mm_positions = request.mm_positions
-        assert mm_positions is not None
-        assert len(mm_positions) > 0
+        mm_features = request.mm_features
+        assert mm_features is not None
+        assert len(mm_features) > 0
 
         # NOTE: since scheduler operates on the request level (possibly with
         # multiple encoder inputs per request), we need to create temporary
         # trackers for accounting at the encoder input level.
         mm_hashes_to_schedule = set()
         num_tokens_to_schedule = 0
-        for i, pos_info in enumerate(mm_positions):
-            start_pos = pos_info.offset
-            num_encoder_tokens = pos_info.length
+        for i, mm_feature in enumerate(mm_features):
+            start_pos = mm_feature.mm_position.offset
+            num_encoder_tokens = mm_feature.mm_position.length
 
             # The encoder output is needed if the two ranges overlap:
             # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -778,7 +778,7 @@
             if not self.is_encoder_decoder:
                 # We are not using the encoder cache for encoder-decoder models,
                 # yet.
-                if request.mm_hashes[i] in mm_hashes_to_schedule:
+                if request.mm_features[i].identifier in mm_hashes_to_schedule:
                     # The same encoder input has already been scheduled in the
                     # current step.
                     continue
@@ -820,7 +820,7 @@
 
             num_tokens_to_schedule += num_encoder_tokens
             encoder_compute_budget -= num_encoder_tokens
-            mm_hashes_to_schedule.add(request.mm_hashes[i])
+            mm_hashes_to_schedule.add(request.mm_features[i].identifier)
             encoder_inputs_to_schedule.append(i)
 
         return (
@@ -1048,9 +1048,9 @@
         # Here, we use list(set) to avoid modifying the set while iterating
         # over it.
         for input_id in list(cached_encoder_input_ids):
-            mm_positions = request.mm_positions[input_id]
-            start_pos = mm_positions.offset
-            num_tokens = mm_positions.length
+            mm_feature = request.mm_features[input_id]
+            start_pos = mm_feature.mm_position.offset
+            num_tokens = mm_feature.mm_position.length
             if self.is_encoder_decoder and request.num_computed_tokens > 0:
                 # With Whisper, as soon as we've generated a single token,
                 # we know we're done with the encoder input. Cross Attention
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 64cce3e9efc51..4e3e581235cce 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -91,11 +91,6 @@ class Request:
         self.mm_features = mm_features or []
         self.num_encoder_inputs = len(self.mm_features)
         self.has_encoder_inputs = self.num_encoder_inputs > 0
-        # TODO(sfeng33): Remove these legacy fields after clearing out all
-        # references in scheduler and model runner
-        self.mm_positions = [f.mm_position for f in self.mm_features]
-        self.mm_kwargs = [f.data for f in self.mm_features]
-        self.mm_hashes = [f.identifier for f in self.mm_features]
 
         # Read-only views
         # Prevent directly appending to these lists since
@@ -180,8 +175,8 @@ class Request:
         return RequestStatus.get_finished_reason(self.status)
 
     def get_num_encoder_tokens(self, input_id: int) -> int:
-        assert input_id < len(self.mm_positions)
-        num_tokens = self.mm_positions[input_id].length
+        assert input_id < len(self.mm_features)
+        num_tokens = self.mm_features[input_id].mm_position.length
         return num_tokens
 
     def record_event(
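The scheduler hunks above key the per-step bookkeeping on `mm_features[i].identifier` rather than on list positions, so a request that repeats the same image schedules its encoder run only once. A toy sketch of that deduplication, with the encoder-cache checks and most budget logic from `Scheduler` left out (the function name and tuple layout are invented for illustration):

```python
def select_encoder_inputs(features, compute_budget):
    """features: list of (identifier, num_encoder_tokens) pairs."""
    scheduled_hashes: set[str] = set()
    to_schedule: list[int] = []
    for i, (identifier, num_tokens) in enumerate(features):
        if identifier in scheduled_hashes:
            # Same mm item appears twice in the request: encode it only once.
            continue
        if num_tokens > compute_budget:
            break
        compute_budget -= num_tokens
        scheduled_hashes.add(identifier)
        to_schedule.append(i)
    return to_schedule, compute_budget


# The duplicated "img-1" placeholder consumes the budget only once.
print(select_encoder_inputs([("img-1", 100), ("img-1", 100), ("img-2", 100)], 250))
```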
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 1cf56656d7adf..339b9937b73f4 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -10,8 +10,7 @@ import torch
 from typing_extensions import deprecated
 
 from vllm.lora.request import LoRARequest
-from vllm.multimodal.inputs import (MultiModalKwargsItem,
-                                    MultiModalKwargsItems, PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
@@ -31,9 +30,7 @@ class CachedRequestState:
 
     req_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: list[MultiModalKwargsItem]
-    mm_positions: list[PlaceholderRange]
-    mm_hashes: list[str]
+    mm_features: list[MultiModalFeatureSpec]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     generator: Optional[torch.Generator]
@@ -60,7 +57,8 @@
                 "removed in v0.13. Please use `mm_kwargs` instead.")
     def mm_inputs(self) -> list[MultiModalKwargsItems]:
         return [
-            MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
+            MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
+            if f.data is not None
         ]
 
     def get_token_id(self, idx: int) -> int:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ebb18e81c38a8..a17c8783d48fa 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -555,9 +555,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
-                mm_kwargs=new_req_data.mm_kwargs,
-                mm_positions=new_req_data.mm_positions,
-                mm_hashes=new_req_data.mm_hashes,
+                mm_features=new_req_data.mm_features,
                 sampling_params=sampling_params,
                 pooling_params=pooling_params,
                 generator=generator,
@@ -698,7 +696,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             second_per_grid_ts = []
             audio_feature_lengths = []
             use_audio_in_video = False
-            for mm_item in req_state.mm_kwargs:
+            for mm_feature in req_state.mm_features:
+                mm_item = mm_feature.data
+                if mm_item is None:
+                    continue
                 mm_input = mm_item.get_data()
                 if (t := mm_input.get("image_grid_thw")) is not None:
                     image_grid_thw.append(t.tolist())
@@ -731,7 +732,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         mm_kwargs = list[MultiModalKwargsItem]()
         for req in scheduler_output.scheduled_new_reqs:
-            mm_kwargs.extend(req.mm_kwargs)
+            for feature in req.mm_features:
+                if feature.data is not None:
+                    mm_kwargs.append(feature.data)
 
         # Input all modalities at once
         mm_kwargs_combined: BatchedTensorInputs = {}
@@ -1361,10 +1364,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = self.requests[req_id]
 
             for mm_input_id in encoder_input_ids:
-                mm_hash = req_state.mm_hashes[mm_input_id]
-                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
-                mm_hashes_pos.append(
-                    (mm_hash, req_state.mm_positions[mm_input_id]))
+                mm_feature = req_state.mm_features[mm_input_id]
+                mm_hash = mm_feature.identifier
+                mm_kwargs.append(mm_feature.data)
+                mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
 
         return mm_kwargs, mm_hashes_pos
 
@@ -1426,9 +1429,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = self.requests[req_id]
             num_computed_tokens = \
                 req_state.num_computed_tokens + shift_computed_tokens
-            mm_positions = req_state.mm_positions
-            mm_hashes = req_state.mm_hashes
-            for i, pos_info in enumerate(mm_positions):
+            for mm_feature in req_state.mm_features:
+                pos_info = mm_feature.mm_position
                 start_pos = pos_info.offset
                 num_encoder_tokens = pos_info.length
 
@@ -1451,7 +1453,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 )
                 assert start_idx < end_idx
 
-                mm_hash = mm_hashes[i]
+                mm_hash = mm_feature.identifier
                 encoder_output = self.encoder_cache.get(mm_hash, None)
                 assert encoder_output is not None,\
                     f"Encoder cache miss for {mm_hash}."
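The runner loops above now tolerate features whose `data` is `None` (metadata-only entries) by filtering them out before batching, mirroring the updated `mm_inputs` property in `CachedRequestState`. A stand-alone sketch of that filter, with a stand-in dataclass instead of the real `MultiModalFeatureSpec`:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FeatureStub:
    """Stand-in for MultiModalFeatureSpec with just the fields used here."""
    identifier: str
    data: Optional[Any]


def collect_mm_kwargs(features: list[FeatureStub]) -> list[Any]:
    # Keep only features that actually carry encoder inputs;
    # metadata-only entries are skipped, as in the runner loops above.
    return [f.data for f in features if f.data is not None]


print(collect_mm_kwargs([FeatureStub("a", {"pixel_values": "..."}),
                         FeatureStub("b", None)]))
```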
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 15af7ffac8095..43f12912707f1 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -387,9 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
-                mm_kwargs=new_req_data.mm_kwargs,
-                mm_positions=new_req_data.mm_positions,
-                mm_hashes=new_req_data.mm_hashes,
+                mm_features=new_req_data.mm_features,
                 sampling_params=sampling_params,
                 pooling_params=None,
                 generator=None,
@@ -822,10 +820,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = self.requests[req_id]
 
             for mm_input_id in encoder_input_ids:
-                mm_hash = req_state.mm_hashes[mm_input_id]
-                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
-                mm_hashes_pos.append(
-                    (mm_hash, req_state.mm_positions[mm_input_id]))
+                mm_feature = req_state.mm_features[mm_input_id]
+                mm_hash = mm_feature.identifier
+                mm_kwargs.append(mm_feature.data)
+                mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
 
         # Batch mm inputs as much as we can: if a request in the batch has
         # multiple modalities or a different modality than the previous one,
@@ -883,13 +881,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_id]
             req_state = self.requests[req_id]
             num_computed_tokens = req_state.num_computed_tokens
-            mm_positions = req_state.mm_positions
-            mm_hashes = req_state.mm_hashes
             # TODO unroll loop and assume/enforce --disable_chunked_mm_input
             # NOTE (NickLucche) here we diverge from logic in other runners, as
             # we assume to only have whole mm items to process. Hence we avoid
             # the intrinsic dynamism that `gather_mm_placeholders` introduces.
-            for i, pos_info in enumerate(mm_positions):
+            for mm_feature in req_state.mm_features:
+                pos_info = mm_feature.mm_position
                 start_pos = pos_info.offset
                 num_encoder_tokens = pos_info.length
 
@@ -904,8 +901,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     # The encoder output is already processed and stored
                     # in the decoder's KV cache.
                     continue
-
-                mm_hash = mm_hashes[i]
+                mm_hash = mm_feature.identifier
                 encoder_output = self.encoder_cache.get(mm_hash, None)
                 assert encoder_output is not None,\
                     f"Encoder cache miss for {mm_hash}."
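Finally, note that the legacy `Request.mm_kwargs` / `mm_hashes` / `mm_positions` fields are deleted outright rather than deprecated. Out-of-tree code that still wants the old parallel lists can rebuild them with the same comprehensions the removed lines used; a small helper sketch (the helper itself is hypothetical, the comprehensions are taken verbatim from the deleted code in `vllm/v1/request.py`):

```python
from vllm.v1.request import Request


def legacy_mm_views(request: Request):
    """Recreate the removed parallel-list views from request.mm_features."""
    mm_kwargs = [f.data for f in request.mm_features]
    mm_hashes = [f.identifier for f in request.mm_features]
    mm_positions = [f.mm_position for f in request.mm_features]
    return mm_kwargs, mm_hashes, mm_positions
```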