[Refactor] Get prompt updates earlier (#23097)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 5a30bd10d8
commit d3f71f1224
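In broad terms, the refactor below changes `_apply_hf_processor` / `_cached_apply_hf_processor` from returning a loose `(prompt_ids, mm_kwargs, mm_hashes, is_update_applied)` 4-tuple to returning `(prompt_ids, mm_info, is_update_applied)`, where `mm_info` bundles the kwargs, hashes, and the prompt updates that are now computed earlier. A minimal sketch of the caller-facing change, using plain dicts as stand-ins for the real vLLM types:

from typing import NamedTuple, Optional


class MultiModalProcessingInfo(NamedTuple):
    # Stand-in field types; the real class holds MultiModalKwargsItems,
    # Optional[MultiModalHashes] and MultiModalPromptUpdates.
    kwargs: dict
    hashes: Optional[dict]
    prompt_updates: dict


def new_style() -> tuple[list[int], MultiModalProcessingInfo, bool]:
    # After the refactor: the multimodal outputs travel together in one bundle
    # instead of a (prompt_ids, mm_kwargs, mm_hashes, is_update_applied) tuple.
    mm_info = MultiModalProcessingInfo(
        kwargs={"image": []},
        hashes=None,
        prompt_updates={"image": []},
    )
    return [1, 2, 3], mm_info, False


prompt_ids, mm_info, is_update_applied = new_style()
print(mm_info.kwargs, mm_info.hashes, mm_info.prompt_updates)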
@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, MultiModalHashes,
+                                        BaseProcessingInfo,
+                                        MultiModalProcessingInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
@@ -291,8 +292,7 @@ class DeepseekVL2MultiModalProcessor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         # The processor logic is different for len(images) <= 2 vs > 2
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only
@@ -20,8 +20,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargsItems
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    MultiModalDataItems)
-from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.processing import (MultiModalProcessingInfo,
+                                        PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.transformers_utils.tokenizer import AnyTokenizer

 from .intern_vit import InternVisionModel
@@ -480,8 +481,7 @@ class H2OVLMultiModalProcessor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         # The processor logic is different for len(images) <= 1 vs > 1
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only
@@ -39,7 +39,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, MultiModalHashes,
+                                        BaseProcessingInfo,
+                                        MultiModalProcessingInfo,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -309,14 +310,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
-        (
-            prompt_ids,
-            mm_kwargs,
-            mm_hashes,
-            _,
-        ) = super()._cached_apply_hf_processor(
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
+        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -325,7 +320,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
         )

         # NOTE: The tokens are already inserted by the chat template
-        return prompt_ids, mm_kwargs, mm_hashes, True
+        return prompt_ids, mm_info, True


 @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
                                    ModalityDataItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        MultiModalPromptUpdates,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -88,10 +89,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
     video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
     video_grid_sizes = video_grid_thw.prod(-1)

-    # vllm use `second_per_grid_ts` to compute multimodal rotary embedding
-    video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
-    if video_second_per_grid is not None:
-        hf_inputs["second_per_grid_ts"] = video_second_per_grid
     num_videos = len(video_grid_sizes)
+
     return dict(
         input_audio_features=MultiModalFieldConfig.flat_from_sizes(
@@ -109,6 +107,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
             "video", video_grid_sizes),
         video_grid_thw=MultiModalFieldConfig.batched("video"),
         second_per_grid_ts=MultiModalFieldConfig.batched("video"),
+        use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
     )


@@ -251,6 +250,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         if ('audio_feature_lengths' not in hf_inputs
                 and feature_attention_mask is not None):
             hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1)
+
+        video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
+        if video_second_per_grid is not None:
+            hf_inputs["second_per_grid_ts"] = video_second_per_grid
+
+        use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
+        hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video)
+
         return hf_inputs

     def _get_mm_fields_config(
@@ -263,27 +270,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
     def _maybe_apply_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
         prompt_ids: list[int],
         mm_kwargs: MultiModalKwargsItems,
+        mm_prompt_updates: MultiModalPromptUpdates,
         is_update_applied: bool,
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         """
         Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
         """
-        unbound_prompt_updates = self._get_prompt_updates(
-            mm_items,
-            hf_processor_mm_kwargs,
-            mm_kwargs,
-        )
-        mm_prompt_updates = self._bind_and_group_updates(
-            unbound_prompt_updates)
-
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

-        use_audio_in_video = hf_processor_mm_kwargs.get(
-            "use_audio_in_video", False)
+        use_audio_in_video = (all(
+            item["use_audio_in_video"].data
+            for item in mm_kwargs["video"]) if "video" in mm_kwargs else False)

         if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
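The override above also changes where `use_audio_in_video` comes from: it used to be read from `hf_processor_mm_kwargs`, which is not available once outputs are served from the processing cache, and is now stored on each cached video item (via the `use_audio_in_video=MultiModalFieldConfig.shared(...)` field config earlier) and recovered from `mm_kwargs`. A rough sketch of that recovery, with plain dicts and tensors standing in for the real kwargs items (the real items expose the flag through `.data`):

import torch

# Hypothetical cached per-video items; each carries its own copy of the flag.
mm_kwargs = {
    "video": [
        {"use_audio_in_video": torch.tensor(True)},
        {"use_audio_in_video": torch.tensor(True)},
    ],
}

# Mirrors the expression in the diff: True only if every video item agrees,
# and False when there are no videos at all.
use_audio_in_video = (all(item["use_audio_in_video"]
                          for item in mm_kwargs["video"])
                      if "video" in mm_kwargs else False)
print(use_audio_in_video)  # True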
@@ -316,9 +316,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         tokenizer = self.info.get_tokenizer()
         prompt = decode_tokens(tokenizer, prompt_ids)

-        if use_audio_in_video:
-            mm_kwargs["use_audio_in_video"] = True
-
         return prompt_ids, prompt, mm_placeholders

     def _get_prompt_updates(
@@ -35,7 +35,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, MultiModalHashes,
+                                        BaseProcessingInfo,
+                                        MultiModalProcessingInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -289,10 +290,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
-        prompt_ids, mm_kwargs, mm_hashes, _ = super(
-        )._cached_apply_hf_processor(
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
+        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -301,7 +300,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
         )

         # NOTE: The tokens are already inserted by the chat template
-        return prompt_ids, mm_kwargs, mm_hashes, True
+        return prompt_ids, mm_info, True

     def _get_data_parser(self) -> MultiModalDataParser:
         sampling_rate = self.info.get_hf_processor().sampling_rate
@@ -989,6 +989,18 @@ A collection of hashes with a similar structure as
 [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
 """

+MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]]
+"""
+A collection of prompt updates with a similar structure as
+[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
+"""
+
+
+class MultiModalProcessingInfo(NamedTuple):
+    kwargs: MultiModalKwargsItems
+    hashes: Optional[MultiModalHashes]
+    prompt_updates: MultiModalPromptUpdates
+

 class BaseMultiModalProcessor(ABC, Generic[_I]):
     """
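The new `MultiModalPromptUpdates` alias introduced above is just a modality-keyed mapping of bound prompt updates, mirroring how `MultiModalKwargsItems` is keyed by modality. A shape-only sketch, with a hypothetical `_Update` class standing in for `BoundPromptUpdate`:

from typing import Sequence


class _Update:
    """Hypothetical stand-in for BoundPromptUpdate."""

    def __init__(self, target: str) -> None:
        self.target = target


# Same modality-keyed layout as MultiModalKwargsItems: the updates for each
# modality are grouped under that modality's key.
MultiModalPromptUpdates = dict[str, Sequence[_Update]]

mm_prompt_updates: MultiModalPromptUpdates = {
    "image": [_Update("<image>")],
    "audio": [_Update("<audio>")],
}
print({modality: len(updates) for modality, updates in mm_prompt_updates.items()})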
@@ -1363,7 +1375,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         cache: ProcessingCache,
         mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
         mm_missing_kwargs: MultiModalKwargsItems,
-    ) -> dict[str, list[MultiModalKwargsItem]]:
+    ) -> MultiModalKwargsItems:
         mm_missing_next_idx = defaultdict[str, int](lambda: 0)

         merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
@@ -1379,7 +1391,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

             merged_items[modality].append(kw_item)

-        return dict(merged_items)
+        return MultiModalKwargsItems(merged_items)

     def _apply_hf_processor(
         self,
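For context on the `_merge_mm_kwargs` hunks above: the method walks the cached-or-missing entries per modality and fills each cache miss from the freshly computed items in order; it now wraps the result in `MultiModalKwargsItems` directly instead of returning a plain dict. A rough, assumption-laden sketch of that merge pattern, using plain strings instead of the real item and hash types:

from collections import defaultdict

# Hypothetical inputs: two items were cache hits, one ("hash-1") was a miss.
mm_cache_items_or_hashes = {"image": ["<cached-0>", "hash-1", "<cached-2>"]}
mm_missing_kwargs = {"image": ["<fresh-1>"]}

mm_missing_next_idx = defaultdict(int)  # next missing item to consume, per modality
merged_items = defaultdict(list)

for modality, items_or_hashes in mm_cache_items_or_hashes.items():
    for item_or_hash in items_or_hashes:
        if item_or_hash.startswith("hash-"):
            # Cache miss: take the next freshly computed item for this modality.
            item = mm_missing_kwargs[modality][mm_missing_next_idx[modality]]
            mm_missing_next_idx[modality] += 1
        else:
            item = item_or_hash
        merged_items[modality].append(item)

print(dict(merged_items))  # {'image': ['<cached-0>', '<fresh-1>', '<cached-2>']}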
@@ -1389,8 +1401,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         (
             prompt_ids,
             mm_processed_data,
@@ -1413,7 +1424,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                                          tokenization_kwargs)
                      if return_mm_hashes else None)

-        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
+        unbound_prompt_updates = self._get_prompt_updates(
+            mm_data_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        mm_prompt_updates = self._bind_and_group_updates(
+            unbound_prompt_updates)
+
+        mm_info = MultiModalProcessingInfo(
+            kwargs=mm_kwargs,
+            hashes=mm_hashes,
+            prompt_updates=mm_prompt_updates,
+        )
+
+        return prompt_ids, mm_info, is_update_applied

     def _cached_apply_hf_processor(
         self,
@@ -1423,8 +1448,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         """
         Apply the HF processor on the full prompt text,
         caching the results and reusing cached results.
@@ -1475,18 +1499,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                                  hf_processor_mm_kwargs),
         )

-        mm_cache_items_merged = self._merge_mm_kwargs(
+        mm_kwargs = self._merge_mm_kwargs(
             cache,
             mm_cache_items_or_hashes=mm_cache_items_or_hashes,
             mm_missing_kwargs=mm_missing_kwargs,
         )

-        mm_kwargs = MultiModalKwargsItems.from_seq([
-            item for cache_items in mm_cache_items_merged.values()
-            for item in cache_items
-        ])
+        unbound_prompt_updates = self._get_prompt_updates(
+            mm_data_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        mm_prompt_updates = self._bind_and_group_updates(
+            unbound_prompt_updates)

-        return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied
+        mm_info = MultiModalProcessingInfo(
+            kwargs=mm_kwargs,
+            hashes=mm_hashes_to_return,
+            prompt_updates=mm_prompt_updates,
+        )
+
+        return prompt_ids, mm_info, is_update_applied

     def _bind_and_group_updates(
         self,
@@ -1626,19 +1659,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     def _maybe_apply_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
         prompt_ids: list[int],
         mm_kwargs: MultiModalKwargsItems,
+        mm_prompt_updates: MultiModalPromptUpdates,
         is_update_applied: bool,
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
-        unbound_prompt_updates = self._get_prompt_updates(
-            mm_items,
-            hf_processor_mm_kwargs,
-            mm_kwargs,
-        )
-        mm_prompt_updates = self._bind_and_group_updates(
-            unbound_prompt_updates)
-
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

@@ -1694,8 +1719,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

         (
             prompt_ids,
-            mm_kwargs,
-            mm_hashes,
+            mm_info,
             is_update_applied,
         ) = self._cached_apply_hf_processor(
             prompt,
@@ -1708,9 +1732,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         # NOTE: tokenization_kwargs are not required to init processor
         prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
             mm_items=mm_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             prompt_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
+            mm_kwargs=mm_info.kwargs,
+            mm_prompt_updates=mm_info.prompt_updates,
             is_update_applied=is_update_applied,
         )

@@ -1723,8 +1747,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
-            mm_hashes=mm_hashes,
+            mm_kwargs=mm_info.kwargs,
+            mm_hashes=mm_info.hashes,
             mm_placeholders=mm_placeholder_ranges,
         )