[Multimodal] Remove legacy multimodal fields in favor of MultiModalFeatureSpec (#24548)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
parent 72fc8aa412
commit 0377802c20
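This change collapses the three parallel per-request lists (mm_kwargs, mm_hashes, mm_positions) into a single list of MultiModalFeatureSpec objects. Below is a minimal sketch of the new access pattern, assuming only the fields exercised in this diff (data, modality, identifier, mm_position); the concrete values are illustrative, not taken from the change:

from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange

# One spec per multimodal item; the field values here are illustrative only.
feature = MultiModalFeatureSpec(
    data=None,                # per-item kwargs, formerly request.mm_kwargs[i]
    modality="image",
    identifier="hash-0",      # formerly request.mm_hashes[i]
    mm_position=PlaceholderRange(offset=0, length=16),  # formerly request.mm_positions[i]
)
mm_features = [feature]

# The legacy parallel lists can be rebuilt on demand where an API still expects them:
mm_hashes = [f.identifier for f in mm_features]
mm_positions = [f.mm_position for f in mm_features]
mm_kwargs = [f.data for f in mm_features if f.data is not None]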
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
 from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 
 
@@ -9,8 +10,17 @@ class MockRequest:
 
     def __init__(self, request_id, mm_hashes, token_counts):
         self.request_id = request_id
-        self.mm_hashes = mm_hashes
         self._token_counts = token_counts
+        self.mm_features = []
+        for i, mm_hash in enumerate(mm_hashes):
+            feature = MultiModalFeatureSpec(
+                data=None,
+                modality="image",
+                identifier=mm_hash,
+                mm_position=PlaceholderRange(offset=0,
+                                             length=self._token_counts[i]),
+            )
+            self.mm_features.append(feature)
 
     def get_num_encoder_tokens(self, input_id: int) -> int:
         return self._token_counts[input_id]
@@ -64,9 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         NewRequestData(
             req_id=req_id,
             prompt_token_ids=[1, 2, 3],
-            mm_kwargs=[],
-            mm_hashes=[],
-            mm_positions=[],
+            mm_features=[],
             sampling_params=SamplingParams(),
             pooling_params=PoolingParams(),
             block_ids=([0], ),  # block_ids should be tuple[list[int]]
@@ -203,9 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         prompt_token_ids=prompt_token_ids,
         sampling_params=_create_sampling_params(),
         pooling_params=None,
-        mm_kwargs=[],
-        mm_positions=[],
-        mm_hashes=[],
+        mm_features=[],
         block_ids=([], ),
         generator=None,
         num_computed_tokens=len(output_token_ids),
@@ -118,9 +118,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         NewRequestData(
             req_id=req_id,
             prompt_token_ids=[1, 2, 3],
-            mm_kwargs=[],
-            mm_hashes=[],
-            mm_positions=[],
+            mm_features=[],
             sampling_params=SamplingParams(),
             pooling_params=None,
             block_ids=([0], ),
@@ -300,11 +300,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
         total_need_load = 0
         for new_req in scheduler_output.scheduled_new_reqs:
             if new_req.req_id in self._requests_need_load:
-                meta.add_request(token_ids=new_req.prompt_token_ids,
-                                 block_ids=new_req.block_ids[0],
-                                 block_size=self._block_size,
-                                 is_store=False,
-                                 mm_hashes=new_req.mm_hashes)
+                meta.add_request(
+                    token_ids=new_req.prompt_token_ids,
+                    block_ids=new_req.block_ids[0],
+                    block_size=self._block_size,
+                    is_store=False,
+                    mm_hashes=[f.identifier for f in new_req.mm_features])
                 total_need_load += 1
             else:
                 # NOTE: here, we set the store and load being exclusive,
@@ -312,11 +313,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 # NOTE(rob): for this debug implementation, we only cache
                 # the original prompt tokens.
                 if not self._found_match_for_request(new_req):
-                    meta.add_request(token_ids=new_req.prompt_token_ids,
-                                     block_ids=new_req.block_ids[0],
-                                     block_size=self._block_size,
-                                     is_store=True,
-                                     mm_hashes=new_req.mm_hashes)
+                    meta.add_request(
+                        token_ids=new_req.prompt_token_ids,
+                        block_ids=new_req.block_ids[0],
+                        block_size=self._block_size,
+                        is_store=True,
+                        mm_hashes=[f.identifier for f in new_req.mm_features])
 
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
@@ -341,11 +343,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 # of the block_ids for the request.
                 block_ids = new_block_ids[0]
 
-                meta.add_request(token_ids=token_ids,
-                                 block_ids=block_ids,
-                                 block_size=self._block_size,
-                                 is_store=False,
-                                 mm_hashes=request.mm_hashes)
+                meta.add_request(
+                    token_ids=token_ids,
+                    block_ids=block_ids,
+                    block_size=self._block_size,
+                    is_store=False,
+                    mm_hashes=[f.identifier for f in request.mm_features])
                 total_need_load += 1
 
         assert total_need_load == len(self._requests_need_load)
@@ -364,10 +367,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
         """
         num_tokens_to_check = align_to_block_size(
            len(request.prompt_token_ids) - 1, self._block_size)
-        foldername = self._generate_foldername_debug(torch.tensor(
-            request.prompt_token_ids)[:num_tokens_to_check],
-                                                     request.mm_hashes,
-                                                     create_folder=False)
+        foldername = self._generate_foldername_debug(
+            torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
+            [f.identifier for f in request.mm_features],
+            create_folder=False)
         return os.path.exists(foldername)
 
     def _generate_foldername_debug(
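The connector's add_request() keeps its mm_hashes parameter; the call sites above now derive that list from the features on the fly. A minimal hedged sketch of the derivation (the helper name and the bare req argument are illustrative, not part of the change):

def legacy_mm_hashes(req) -> list[str]:
    """Rebuild the old per-request hash list from the request's
    MultiModalFeatureSpec entries, as the call sites above now do inline."""
    return [f.identifier for f in req.mm_features]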
@@ -86,7 +86,7 @@ class EncoderCacheManager:
         Returns:
             True if the encoder output for this input is already cached
         """
-        mm_hash = request.mm_hashes[input_id]
+        mm_hash = request.mm_features[input_id].identifier
         # Not cached at all
         if mm_hash not in self.cached:
             return False
@@ -167,7 +167,7 @@ class EncoderCacheManager:
         This method assumes can_allocate() returned True for the same input.
         """
 
-        mm_hash = request.mm_hashes[input_id]
+        mm_hash = request.mm_features[input_id].identifier
         request_id = request.request_id
         if mm_hash not in self.cached:
             self.cached[mm_hash] = set()
@@ -193,8 +193,8 @@ class EncoderCacheManager:
         """
         return {
             input_id
-            for input_id in range(len(request.mm_hashes))
-            if request.mm_hashes[input_id] in self.cached
+            for input_id in range(len(request.mm_features))
+            if request.mm_features[input_id].identifier in self.cached
         }
 
     def free_encoder_input(self, request: Request, input_id: int) -> None:
@@ -208,7 +208,7 @@ class EncoderCacheManager:
         `can_allocate`).
         """
         req_id = request.request_id
-        mm_hash = request.mm_hashes[input_id]
+        mm_hash = request.mm_features[input_id].identifier
         # The mm_hash not in cache or the req_id set is empty
         if not self.cached.get(mm_hash, None):
             return
@@ -418,9 +418,9 @@ def need_extra_keys(request: Request) -> bool:
     # Multimodal requests need to include the MM hash.
     # LoRA requests need to include the LoRA ID.
     # Request with provided cache salt need to include the salt.
-    return bool(request.mm_hashes) or (request.lora_request
-                                       is not None) or (request.cache_salt
-                                                        is not None)
+    return bool(request.mm_features) or (request.lora_request
+                                         is not None) or (request.cache_salt
+                                                          is not None)
 
 
 def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
@@ -442,32 +442,28 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
     """
     extra_keys: list[Any] = []
 
-    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
-    if not mm_positions:
+    mm_features = request.mm_features
+    if not mm_features:
         return extra_keys, start_mm_idx
 
-    if mm_positions and len(mm_positions) != len(mm_hashes):
-        raise ValueError(
-            "The number of multi-modal positions and hashes must match. This "
-            "is likely because you did not enable MM hashing. "
-            "Please set `mm_processor_cache_gb > 0`.")
-
-    # Note that we assume mm_positions is sorted by offset.
+    # Note that we assume mm_features are sorted by mm_position.offset.
     # We do not need to check all mm inputs if the start token index is out of
     # range. This usually happens in the late prefill phase and decoding phase.
-    if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
+    last_pos = mm_features[-1].mm_position
+    if last_pos.offset + last_pos.length < start_token_idx:
         return extra_keys, start_mm_idx
 
     # Support start_mm_idx == -1 to indicate the last mm input.
     if start_mm_idx < 0:
-        assert -start_mm_idx <= len(mm_positions)
-        start_mm_idx = len(mm_positions) + start_mm_idx
+        assert -start_mm_idx <= len(mm_features)
+        start_mm_idx = len(mm_features) + start_mm_idx
 
     curr_mm_idx = start_mm_idx
-    while mm_positions and curr_mm_idx < len(mm_positions):
-        assert mm_hashes[curr_mm_idx] is not None
-        offset = mm_positions[curr_mm_idx].offset
-        length = mm_positions[curr_mm_idx].length
+    while mm_features and curr_mm_idx < len(mm_features):
+        mm_feature = mm_features[curr_mm_idx]
+        assert mm_feature.identifier is not None
+        offset = mm_feature.mm_position.offset
+        length = mm_feature.mm_position.length
         if end_token_idx > offset:
             if start_token_idx > offset + length:
                 # This block has passed the current mm input.
@@ -475,7 +471,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
                 continue
 
             # The block contains the current mm input.
-            extra_keys.append(mm_hashes[curr_mm_idx])
+            extra_keys.append(mm_feature.identifier)
 
             if end_token_idx >= offset + length:
                 # If this block contains the end of the current mm input,
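The rewritten _gen_mm_extra_hash_keys walks mm_features, which are assumed to be sorted by mm_position.offset, and collects the identifier of every feature whose placeholder range intersects the block's token window. A standalone sketch of that selection rule, with a hypothetical helper name and without vLLM's cursor bookkeeping:

from vllm.multimodal.inputs import MultiModalFeatureSpec

def overlapping_identifiers(mm_features: list[MultiModalFeatureSpec],
                            start_token_idx: int,
                            end_token_idx: int) -> list[str]:
    """Hypothetical helper mirroring the loop above: collect identifiers of
    features whose placeholder range intersects the token window."""
    keys: list[str] = []
    for f in mm_features:  # assumed sorted by mm_position.offset
        offset = f.mm_position.offset
        length = f.mm_position.length
        if end_token_idx <= offset:
            break          # the window ends before this feature starts
        if start_token_idx > offset + length:
            continue       # the window starts after this feature ends
        keys.append(f.identifier)
    return keys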
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
     from vllm.distributed.kv_transfer.kv_connector.v1.base import (
         KVConnectorMetadata)
     from vllm.lora.request import LoRARequest
-    from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+    from vllm.multimodal.inputs import MultiModalFeatureSpec
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
     from vllm.v1.request import Request
@@ -27,9 +27,7 @@ class NewRequestData:
 
     req_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: list[MultiModalKwargsItem]
-    mm_hashes: list[str]
-    mm_positions: list[PlaceholderRange]
+    mm_features: list[MultiModalFeatureSpec]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     block_ids: tuple[list[int], ...]
@@ -45,9 +43,7 @@ class NewRequestData:
         return cls(
             req_id=request.request_id,
             prompt_token_ids=request.prompt_token_ids,
-            mm_kwargs=request.mm_kwargs,
-            mm_hashes=request.mm_hashes,
-            mm_positions=request.mm_positions,
+            mm_features=request.mm_features,
             sampling_params=request.sampling_params,
             pooling_params=request.pooling_params,
             block_ids=block_ids,
@@ -59,9 +55,7 @@ class NewRequestData:
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
                 f"prompt_token_ids={self.prompt_token_ids},"
-                f"mm_kwargs={self.mm_kwargs},"
-                f"mm_hashes={self.mm_hashes},"
-                f"mm_positions={self.mm_positions},"
+                f"mm_features={self.mm_features},"
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"
@@ -73,9 +67,7 @@ class NewRequestData:
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
                 f"prompt_token_ids_len={len(self.prompt_token_ids)},"
-                f"mm_kwargs={self.mm_kwargs},"
-                f"mm_hashes={self.mm_hashes},"
-                f"mm_positions={self.mm_positions},"
+                f"mm_features={self.mm_features},"
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"
@@ -736,18 +736,18 @@ class Scheduler(SchedulerInterface):
         if num_new_tokens == 0 or not request.has_encoder_inputs:
             return [], num_new_tokens, encoder_compute_budget
         encoder_inputs_to_schedule: list[int] = []
-        mm_positions = request.mm_positions
-        assert mm_positions is not None
-        assert len(mm_positions) > 0
+        mm_features = request.mm_features
+        assert mm_features is not None
+        assert len(mm_features) > 0
 
         # NOTE: since scheduler operates on the request level (possibly with
         # multiple encoder inputs per request), we need to create temporary
         # trackers for accounting at the encoder input level.
         mm_hashes_to_schedule = set()
         num_tokens_to_schedule = 0
-        for i, pos_info in enumerate(mm_positions):
-            start_pos = pos_info.offset
-            num_encoder_tokens = pos_info.length
+        for i, mm_feature in enumerate(mm_features):
+            start_pos = mm_feature.mm_position.offset
+            num_encoder_tokens = mm_feature.mm_position.length
 
             # The encoder output is needed if the two ranges overlap:
             # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -778,7 +778,7 @@ class Scheduler(SchedulerInterface):
             if not self.is_encoder_decoder:
                 # We are not using the encoder cache for encoder-decoder models,
                 # yet.
-                if request.mm_hashes[i] in mm_hashes_to_schedule:
+                if request.mm_features[i].identifier in mm_hashes_to_schedule:
                     # The same encoder input has already been scheduled in the
                     # current step.
                     continue
@@ -820,7 +820,7 @@ class Scheduler(SchedulerInterface):
 
             num_tokens_to_schedule += num_encoder_tokens
             encoder_compute_budget -= num_encoder_tokens
-            mm_hashes_to_schedule.add(request.mm_hashes[i])
+            mm_hashes_to_schedule.add(request.mm_features[i].identifier)
             encoder_inputs_to_schedule.append(i)
 
         return (
@@ -1048,9 +1048,9 @@ class Scheduler(SchedulerInterface):
         # Here, we use list(set) to avoid modifying the set while iterating
         # over it.
         for input_id in list(cached_encoder_input_ids):
-            mm_positions = request.mm_positions[input_id]
-            start_pos = mm_positions.offset
-            num_tokens = mm_positions.length
+            mm_feature = request.mm_features[input_id]
+            start_pos = mm_feature.mm_position.offset
+            num_tokens = mm_feature.mm_position.length
             if self.is_encoder_decoder and request.num_computed_tokens > 0:
                 # With Whisper, as soon as we've generated a single token,
                 # we know we're done with the encoder input. Cross Attention
@@ -91,11 +91,6 @@ class Request:
         self.mm_features = mm_features or []
         self.num_encoder_inputs = len(self.mm_features)
         self.has_encoder_inputs = self.num_encoder_inputs > 0
-        # TODO(sfeng33): Remove these legacy fields after clearing out all
-        # references in scheduler and model runner
-        self.mm_positions = [f.mm_position for f in self.mm_features]
-        self.mm_kwargs = [f.data for f in self.mm_features]
-        self.mm_hashes = [f.identifier for f in self.mm_features]
 
         # Read-only views
         # Prevent directly appending to these lists since
@@ -180,8 +175,8 @@ class Request:
         return RequestStatus.get_finished_reason(self.status)
 
     def get_num_encoder_tokens(self, input_id: int) -> int:
-        assert input_id < len(self.mm_positions)
-        num_tokens = self.mm_positions[input_id].length
+        assert input_id < len(self.mm_features)
+        num_tokens = self.mm_features[input_id].mm_position.length
         return num_tokens
 
     def record_event(
@@ -10,8 +10,7 @@ import torch
 from typing_extensions import deprecated
 
 from vllm.lora.request import LoRARequest
-from vllm.multimodal.inputs import (MultiModalKwargsItem,
-                                    MultiModalKwargsItems, PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
@@ -31,9 +30,7 @@ class CachedRequestState:
 
     req_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: list[MultiModalKwargsItem]
-    mm_positions: list[PlaceholderRange]
-    mm_hashes: list[str]
+    mm_features: list[MultiModalFeatureSpec]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     generator: Optional[torch.Generator]
@@ -60,7 +57,8 @@ class CachedRequestState:
                 "removed in v0.13. Please use `mm_kwargs` instead.")
     def mm_inputs(self) -> list[MultiModalKwargsItems]:
         return [
-            MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
+            MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
+            if f.data is not None
         ]
 
     def get_token_id(self, idx: int) -> int:
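For callers still on the deprecated mm_inputs property, the returned list is now derived from mm_features and silently skips features whose data is None. A hedged, self-contained restatement of that behaviour (the helper name is hypothetical):

from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems

def legacy_mm_inputs(
        mm_features: list[MultiModalFeatureSpec]) -> list[MultiModalKwargsItems]:
    """Hypothetical helper mirroring the deprecated property above: features
    whose data is None (placeholders without materialized kwargs) are skipped
    rather than wrapped."""
    return [
        MultiModalKwargsItems.from_seq([f.data])
        for f in mm_features
        if f.data is not None
    ]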
@@ -555,9 +555,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
-                mm_kwargs=new_req_data.mm_kwargs,
-                mm_positions=new_req_data.mm_positions,
-                mm_hashes=new_req_data.mm_hashes,
+                mm_features=new_req_data.mm_features,
                 sampling_params=sampling_params,
                 pooling_params=pooling_params,
                 generator=generator,
@@ -698,7 +696,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             second_per_grid_ts = []
             audio_feature_lengths = []
             use_audio_in_video = False
-            for mm_item in req_state.mm_kwargs:
+            for mm_feature in req_state.mm_features:
+                mm_item = mm_feature.data
+                if mm_item is None:
+                    continue
                 mm_input = mm_item.get_data()
                 if (t := mm_input.get("image_grid_thw")) is not None:
                     image_grid_thw.append(t.tolist())
@@ -731,7 +732,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         mm_kwargs = list[MultiModalKwargsItem]()
         for req in scheduler_output.scheduled_new_reqs:
-            mm_kwargs.extend(req.mm_kwargs)
+            for feature in req.mm_features:
+                if feature.data is not None:
+                    mm_kwargs.append(feature.data)
 
         # Input all modalities at once
         mm_kwargs_combined: BatchedTensorInputs = {}
@@ -1361,10 +1364,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = self.requests[req_id]
 
             for mm_input_id in encoder_input_ids:
-                mm_hash = req_state.mm_hashes[mm_input_id]
-                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
-                mm_hashes_pos.append(
-                    (mm_hash, req_state.mm_positions[mm_input_id]))
+                mm_feature = req_state.mm_features[mm_input_id]
+                mm_hash = mm_feature.identifier
+                mm_kwargs.append(mm_feature.data)
+                mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
 
         return mm_kwargs, mm_hashes_pos
 
@@ -1426,9 +1429,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = self.requests[req_id]
             num_computed_tokens = \
                 req_state.num_computed_tokens + shift_computed_tokens
-            mm_positions = req_state.mm_positions
-            mm_hashes = req_state.mm_hashes
-            for i, pos_info in enumerate(mm_positions):
+            for mm_feature in req_state.mm_features:
+                pos_info = mm_feature.mm_position
                 start_pos = pos_info.offset
                 num_encoder_tokens = pos_info.length
 
@@ -1451,7 +1453,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 )
                 assert start_idx < end_idx
 
-                mm_hash = mm_hashes[i]
+                mm_hash = mm_feature.identifier
                 encoder_output = self.encoder_cache.get(mm_hash, None)
                 assert encoder_output is not None,\
                     f"Encoder cache miss for {mm_hash}."
@@ -387,9 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
-                mm_kwargs=new_req_data.mm_kwargs,
-                mm_positions=new_req_data.mm_positions,
-                mm_hashes=new_req_data.mm_hashes,
+                mm_features=new_req_data.mm_features,
                 sampling_params=sampling_params,
                 pooling_params=None,
                 generator=None,
@@ -822,10 +820,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = self.requests[req_id]
 
             for mm_input_id in encoder_input_ids:
-                mm_hash = req_state.mm_hashes[mm_input_id]
-                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
-                mm_hashes_pos.append(
-                    (mm_hash, req_state.mm_positions[mm_input_id]))
+                mm_feature = req_state.mm_features[mm_input_id]
+                mm_hash = mm_feature.identifier
+                mm_kwargs.append(mm_feature.data)
+                mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
 
         # Batch mm inputs as much as we can: if a request in the batch has
         # multiple modalities or a different modality than the previous one,
@@ -883,13 +881,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_id]
             req_state = self.requests[req_id]
             num_computed_tokens = req_state.num_computed_tokens
-            mm_positions = req_state.mm_positions
-            mm_hashes = req_state.mm_hashes
             # TODO unroll loop and assume/enforce --disable_chunked_mm_input
             # NOTE (NickLucche) here we diverge from logic in other runners, as
             # we assume to only have whole mm items to process. Hence we avoid
             # the intrinsic dynamism that `gather_mm_placeholders` introduces.
-            for i, pos_info in enumerate(mm_positions):
+            for mm_feature in req_state.mm_features:
+                pos_info = mm_feature.mm_position
                 start_pos = pos_info.offset
                 num_encoder_tokens = pos_info.length
 
@@ -904,8 +901,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     # The encoder output is already processed and stored
                     # in the decoder's KV cache.
                     continue
 
-                mm_hash = mm_hashes[i]
+                mm_hash = mm_feature.identifier
                 encoder_output = self.encoder_cache.get(mm_hash, None)
                 assert encoder_output is not None,\
                     f"Encoder cache miss for {mm_hash}."