Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-19 04:15:01 +08:00)
[Core] Simplify mm processing cache (#22457)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 399d2a10e2
commit 8c9da6be22
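In short, this change strips ProcessingCache down to a thin wrapper around its LRU cache: entries are keyed directly by the multimodal content hashes that the processor already computes in `_hash_mm_items`, so the per-call re-hashing in `get`/`get_item`/`put` and the `ProcessingCacheItem`/`ProcessingCacheOptionalItem` wrappers go away, and the HF-processor helpers now pass the raw `BatchFeature` through instead of eagerly converting it to `MultiModalKwargs`. A rough sketch of the resulting cache shape, using a plain dict in place of vLLM's GiB-bounded LRU cache and placeholder types (not the actual vLLM classes):

from typing import Union

Item = dict  # placeholder for MultiModalKwargsItem
ItemOrHash = Union[Item, str]  # mirrors the new _CacheItemOrHash alias


class SimpleProcessingCache:
    """Sketch only: a dict keyed by content hash; the real cache is an LRU
    bounded by a size budget in GiB."""

    def __init__(self) -> None:
        self._cache: dict[str, Item] = {}
        # After the change, get/reset are just the underlying
        # cache's own methods re-exported.
        self.get = self._cache.get
        self.reset = self._cache.clear

    def put(self, key: str, value: Item) -> None:
        self._cache[key] = value


def fetch_or_mark_missing(cache: SimpleProcessingCache,
                          hashes: list[str]) -> list[ItemOrHash]:
    # Cache hits come back as items; misses stay as their hash string,
    # which is how the new _get_cache_missing_items marks work to redo.
    return [(h if (v := cache.get(h)) is None else v) for h in hashes]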
@@ -431,7 +431,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         tokenization_kwargs: Mapping[str, object],
         *,
         enable_hf_prompt_update: bool,
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+    ) -> tuple[list[int], BatchFeature, bool]:
         """
         Qwen2.5-Omni reimplements this function to handle text only.
         """
@@ -448,20 +448,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         else:
             prompt_ids = self._apply_hf_processor_tokens_only(prompt)
 
-        mm_kwargs = self._apply_hf_processor_mm_only(
+        mm_processed_data = self._apply_hf_processor_mm_only(
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )
 
-        return prompt_ids, mm_kwargs, False
+        return prompt_ids, mm_processed_data, False
 
     def _apply_hf_processor_mm_only(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ) -> MultiModalKwargs:
+    ) -> BatchFeature:
         """
         Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
         """
@@ -473,14 +473,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
             assert "audio" in mm_counts
             mm_counts["audio"] -= mm_counts["video"]
 
-        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
+        _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
             prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )
 
-        return mm_kwargs
+        return mm_processed_data
 
     def _validate_mm_placeholders(
         self,

@@ -22,7 +22,8 @@ from typing import Literal, Optional, Union
 import regex as re
 import torch
 from torch import nn
-from transformers import AutoModel, PretrainedConfig, PreTrainedModel
+from transformers import (AutoModel, BatchFeature, PretrainedConfig,
+                          PreTrainedModel)
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
 from vllm.attention import Attention
@@ -269,7 +270,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ):
+    ) -> tuple[list[int], BatchFeature, bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data
         together.

@@ -18,7 +18,7 @@ from vllm.inputs import InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                                encode_tokens)
-from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
+from vllm.utils import flatten_2d_lists, full_groupby
 
 from .cache import MultiModalCache
 from .hasher import MultiModalHasher
@@ -887,120 +887,19 @@ def find_mm_placeholders(
     return dict(full_groupby_modality(it))
 
 
-class ProcessingCacheOptionalItem(NamedTuple):
-    key: str
-    value: Optional[MultiModalKwargsItem]
-
-
-class ProcessingCacheItem(NamedTuple):
-    key: str
-    value: MultiModalKwargsItem
-
-
 class ProcessingCache(MultiModalCache):
 
-    def __init__(
-        self,
-        capacity_gb: float,
-        *,
-        debug_cache_hit_ratio_steps: Optional[int] = None,
-    ) -> None:
+    def __init__(self, capacity_gb: float) -> None:
         super().__init__()
 
-        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
-        self.debug_cache_hits = 0
-        self.debug_cache_total = 0
+        self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
 
-        self._cache = self.get_lru_cache(
-            capacity_gb,
-            MultiModalKwargsItem,
-            debug=bool(debug_cache_hit_ratio_steps),
-        )
+        self.get = self._cache.get
+        self.put = self._cache.put
+        self.reset = self._cache.clear
 
-    def _maybe_log_cache_stats(self) -> None:
-        steps = self.debug_cache_hit_ratio_steps
-        if not steps:
-            return
 
-        total = self.debug_cache_total
-        if total > 0 and total % steps == 0:
-            logger.debug("ProcessingCache: hit_ratio = %.2f",
-                         self.debug_cache_hits / total)
-            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
-                         self._cache.currsize / GiB_bytes,
-                         self._cache.maxsize / GiB_bytes)
+_CacheItemOrHash = Union[MultiModalKwargsItem, str]
 
-    def get(
-        self,
-        model_id: str,
-        modality: str,
-        input_item: object,
-        input_kwargs: Mapping[str, object],
-    ) -> Optional[MultiModalKwargsItem]:
-        """
-        Get a processed multi-modal item from the cache
-        according to its dependencies, including:
-
-        - The model ID
-        - The modality of the item
-        - The original data item passed to the HF processor
-        - The configuration options of the HF processor
-        """
-        self._maybe_log_cache_stats()
-
-        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: input_item},
-                                                 **input_kwargs)
-
-        if self.debug_cache_hit_ratio_steps:
-            if cache_key in self._cache:
-                self.debug_cache_hits += 1
-
-            self.debug_cache_total += 1
-
-        return self._cache.get(cache_key)
-
-    def get_item(
-        self,
-        model_id: str,
-        modality: str,
-        input_item: object,
-        input_kwargs: Mapping[str, object],
-    ) -> ProcessingCacheOptionalItem:
-        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: input_item},
-                                                 **input_kwargs)
-
-        return ProcessingCacheOptionalItem(
-            key=cache_key,
-            value=self._cache.get(cache_key),
-        )
-
-    def put(
-        self,
-        model_id: str,
-        modality: str,
-        input_item: object,
-        input_kwargs: Mapping[str, object],
-        output_kwargs: MultiModalKwargsItem,
-    ) -> None:
-        """
-        Put a processed multi-modal item into the cache
-        according to its dependencies
-        (see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
-        """
-        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: input_item},
-                                                 **input_kwargs)
-        self._cache[cache_key] = output_kwargs
-
-    def put_item(self, item: ProcessingCacheItem) -> None:
-        self._cache[item.key] = item.value
-
-    def reset(self) -> bool:
-        self._cache.clear()
-
-        return True
 
-
 class BaseProcessingInfo:
@@ -1279,7 +1178,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+    ) -> tuple[list[int], "BatchFeature", bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data
         together.
@@ -1298,11 +1197,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 
         prompt_ids, = processed_data.pop("input_ids").tolist()
 
-        mm_kwargs = MultiModalKwargs.from_hf_inputs(
-            processed_data,
-            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
-        )
-
         is_update_applied = self._hf_processor_applies_updates(
             prompt_text=prompt_text,
             mm_items=mm_items,
@@ -1310,11 +1204,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             tokenization_kwargs=tokenization_kwargs,
         )
 
-        return prompt_ids, mm_kwargs, is_update_applied
+        return prompt_ids, processed_data, is_update_applied
 
     def _apply_hf_processor_text_only(
-            self, prompt_text: str,
-            tokenization_kwargs: Mapping[str, object]) -> list[int]:
+        self,
+        prompt_text: str,
+        tokenization_kwargs: Mapping[str, object],
+    ) -> list[int]:
         """
         Apply the HF processor on the prompt text only.
 
@@ -1353,7 +1249,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ) -> MultiModalKwargs:
+    ) -> "BatchFeature":
         """
         Apply the HF processor on the multi-modal data only.
 
@@ -1364,14 +1260,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         """
         mm_counts = mm_items.get_all_counts()
 
-        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
+        _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
             prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )
 
-        return mm_kwargs
+        return mm_processed_data
 
     def _apply_hf_processor_main(
         self,
@@ -1381,7 +1277,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object],
         *,
         enable_hf_prompt_update: bool,
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+    ) -> tuple[list[int], "BatchFeature", bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data.
 
@@ -1407,52 +1303,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         else:
             prompt_ids = self._apply_hf_processor_tokens_only(prompt)
 
-        mm_kwargs = self._apply_hf_processor_mm_only(
+        mm_processed_data = self._apply_hf_processor_mm_only(
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )
 
-        return prompt_ids, mm_kwargs, False
+        return prompt_ids, mm_processed_data, False
 
     def _get_cache_missing_items(
         self,
         cache: ProcessingCache,
         mm_data_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
-            str, list[object]]]:
-        model_id = self.info.model_id
-
-        mm_cache_items = {
-            modality: [
-                cache.get_item(
-                    model_id, modality, item,
-                    dict(**hf_processor_mm_kwargs, **tokenization_kwargs))
-                for item in items
-            ]
-            for modality, items in mm_data_items.items()
+        mm_hashes: MultiModalHashes,
+    ) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
+        mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
+            modality: [(h if (v := cache.get(h)) is None else v)
+                       for h in hashes]
+            for modality, hashes in mm_hashes.items()
         }
 
         mm_missing_idxs = {
             modality: [
-                idx for idx, item in enumerate(cache_items)
-                if item.value is None
+                idx for idx, item_or_hash in enumerate(items_or_hashes)
+                if isinstance(item_or_hash, str)
             ]
-            for modality, cache_items in mm_cache_items.items()
+            for modality, items_or_hashes in mm_cache_items_or_hashes.items()
         }
         mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
 
-        return mm_cache_items, mm_missing_data
+        return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data)
 
     def _hash_mm_items(
-            self, mm_items: MultiModalDataItems,
-            hf_processor_mm_kwargs: Mapping[str, object],
-            tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes:
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        tokenization_kwargs: Mapping[str, object],
+    ) -> MultiModalHashes:
         """Create MM hashes to be returned (only used in V1)."""
         model_id = self.info.model_id
 
@@ -1470,34 +1360,25 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     def _merge_mm_kwargs(
         self,
         cache: ProcessingCache,
-        mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]],
-        mm_missing_data: dict[str, list[object]],
+        mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
         mm_missing_kwargs: MultiModalKwargs,
-    ) -> dict[str, list[ProcessingCacheItem]]:
-        mm_missing_next_idx = {modality: 0 for modality in mm_missing_data}
+    ) -> dict[str, list[MultiModalKwargsItem]]:
+        mm_missing_next_idx = defaultdict[str, int](lambda: 0)
 
-        merged_items = defaultdict[str, list[ProcessingCacheItem]](list)
-        for modality, cache_items in mm_cache_items.items():
-            for cache_item in cache_items:
-                if cache_item.value is None:
+        merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
+        for modality, items_or_hashes in mm_cache_items_or_hashes.items():
+            for item_or_hash in items_or_hashes:
+                if isinstance(item_or_hash, str):
                     kw_item = mm_missing_kwargs.get_item(
                         modality,
                         mm_missing_next_idx[modality],
                     )
-                    cache_item_new = ProcessingCacheItem(
-                        key=cache_item.key,
-                        value=kw_item,
-                    )
-
-                    cache.put_item(cache_item_new)
+                    cache.put(item_or_hash, kw_item)
                     mm_missing_next_idx[modality] += 1
                 else:
-                    cache_item_new = ProcessingCacheItem(
-                        key=cache_item.key,
-                        value=cache_item.value,
-                    )
+                    kw_item = item_or_hash
 
-                merged_items[modality].append(cache_item_new)
+                merged_items[modality].append(kw_item)
 
         return dict(merged_items)
 
@@ -1512,7 +1393,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
         (
             prompt_ids,
-            mm_kwargs,
+            mm_processed_data,
             is_update_applied,
         ) = self._apply_hf_processor_main(
             prompt=prompt,
@@ -1522,6 +1403,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             enable_hf_prompt_update=True,
         )
 
+        mm_kwargs = MultiModalKwargs.from_hf_inputs(
+            mm_processed_data,
+            self._get_mm_fields_config(mm_processed_data,
+                                       hf_processor_mm_kwargs),
+        )
+
         mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
                                          tokenization_kwargs)
                      if return_mm_hashes else None)
@@ -1553,49 +1440,52 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             return_mm_hashes=return_mm_hashes,
         )
 
+        mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
+                                        tokenization_kwargs)
         (
-            mm_cache_items,
-            mm_missing_data,
+            mm_cache_items_or_hashes,
+            mm_missing_data_items,
         ) = self._get_cache_missing_items(
             cache=cache,
             mm_data_items=mm_data_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-            tokenization_kwargs=tokenization_kwargs,
+            mm_hashes=mm_hashes,
        )
 
+        mm_hashes_to_return = mm_hashes if return_mm_hashes else None
+
         # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
         # so we can't apply prompt updates until the new multimodal
         # items are combined with the cached multimodal items
         (
             prompt_ids,
-            mm_missing_kwargs,
+            mm_missing_processed_data,
             is_update_applied,
         ) = self._apply_hf_processor_main(
             prompt=prompt,
-            mm_items=self._to_mm_items(mm_missing_data),
+            mm_items=mm_missing_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
             enable_hf_prompt_update=False,
         )
 
+        mm_missing_kwargs = MultiModalKwargs.from_hf_inputs(
+            mm_missing_processed_data,
+            self._get_mm_fields_config(mm_missing_processed_data,
+                                       hf_processor_mm_kwargs),
+        )
+
         mm_cache_items_merged = self._merge_mm_kwargs(
             cache,
-            mm_cache_items=mm_cache_items,
-            mm_missing_data=mm_missing_data,
+            mm_cache_items_or_hashes=mm_cache_items_or_hashes,
             mm_missing_kwargs=mm_missing_kwargs,
         )
 
         mm_kwargs = MultiModalKwargs.from_items([
-            item.value for cache_items in mm_cache_items_merged.values()
+            item for cache_items in mm_cache_items_merged.values()
             for item in cache_items
        ])
 
-        mm_hashes = {
-            modality: [item.key for item in cache_items]
-            for modality, cache_items in mm_cache_items_merged.items()
-        } if return_mm_hashes else None
-
-        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
+        return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied
 
     def _bind_and_group_updates(
         self,

@@ -312,25 +312,25 @@ class MsgpackDecoder:
         return arr.view(torch_dtype).view(shape)
 
     def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
-        decoded_items = []
-        for item in obj:
-            elems = []
-            for v in item:
-                v["data"] = self._decode_nested_tensors(v["data"])
-                # Reconstruct the field processor using MultiModalFieldConfig
-                factory_meth_name, *field_args = v["field"]
-                factory_meth = getattr(MultiModalFieldConfig,
-                                       factory_meth_name)
-
-                # Special case: decode the union "slices" field of
-                # MultiModalFlatField
-                if factory_meth_name == "flat":
-                    field_args[0] = self._decode_nested_slices(field_args[0])
-
-                v["field"] = factory_meth(None, *field_args).field
-                elems.append(MultiModalFieldElem(**v))
-            decoded_items.append(MultiModalKwargsItem.from_elems(elems))
-        return decoded_items
+        return [self._decode_mm_item(v) for v in obj]
+
+    def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem:
+        return MultiModalKwargsItem.from_elems(
+            [self._decode_mm_field_elem(v) for v in obj])
+
+    def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem:
+        obj["data"] = self._decode_nested_tensors(obj["data"])
+        # Reconstruct the field processor using MultiModalFieldConfig
+        factory_meth_name, *field_args = obj["field"]
+        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
+
+        # Special case: decode the union "slices" field of
+        # MultiModalFlatField
+        if factory_meth_name == "flat":
+            field_args[0] = self._decode_nested_slices(field_args[0])
+
+        obj["field"] = factory_meth(None, *field_args).field
+        return MultiModalFieldElem(**obj)
 
     def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
         if isinstance(obj, (int, float)):
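For orientation, the cache-aware processing path (`_get_cache_missing_items` plus `_merge_mm_kwargs`) now roughly follows the flow sketched below: hash every item up front, resolve each hash against the cache, run the HF processor only on the misses, then merge cached and fresh items in order while writing the new ones back under their precomputed hash. This is a simplified stand-in (the `apply_with_cache` and `process` names are hypothetical, and `cache` is the SimpleProcessingCache sketch above), not vLLM's actual implementation:

from collections import defaultdict
from typing import Callable, Union

Item = dict  # placeholder for MultiModalKwargsItem
ItemOrHash = Union[Item, str]


def apply_with_cache(
    cache,  # anything exposing get(hash) and put(hash, item)
    hashes_by_modality: dict[str, list[str]],
    data_by_modality: dict[str, list[object]],
    process: Callable[[dict[str, list[object]]], dict[str, list[Item]]],
) -> dict[str, list[Item]]:
    # 1. Resolve each precomputed hash; misses stay as the hash string
    #    (the _get_cache_missing_items step).
    resolved = {
        modality: [(h if (v := cache.get(h)) is None else v) for h in hashes]
        for modality, hashes in hashes_by_modality.items()
    }

    # 2. Run the expensive HF processing only on the missing raw items.
    missing = {
        modality: [data_by_modality[modality][i]
                   for i, x in enumerate(entries) if isinstance(x, str)]
        for modality, entries in resolved.items()
    }
    fresh = process(missing)

    # 3. Merge cached and fresh items in order, writing fresh ones back to
    #    the cache under their hash (the _merge_mm_kwargs step).
    next_idx: dict[str, int] = defaultdict(int)
    merged: dict[str, list[Item]] = defaultdict(list)
    for modality, entries in resolved.items():
        for entry in entries:
            if isinstance(entry, str):
                item = fresh[modality][next_idx[modality]]
                cache.put(entry, item)
                next_idx[modality] += 1
            else:
                item = entry
            merged[modality].append(item)

    return dict(merged)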