mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-17 05:17:03 +08:00
[VLM] Support caching in merged multi-modal processor (#11396)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
5ce4627a7e
commit
101418096f
@ -191,6 +191,7 @@ def linkcode_resolve(domain, info):
|
||||
|
||||
# Mock out external dependencies here, otherwise the autodoc pages may be blank.
|
||||
autodoc_mock_imports = [
|
||||
"blake3",
|
||||
"compressed_tensors",
|
||||
"cpuinfo",
|
||||
"cv2",
|
||||
@ -207,7 +208,7 @@ autodoc_mock_imports = [
|
||||
"tensorizer",
|
||||
"pynvml",
|
||||
"outlines",
|
||||
"xgrammar,"
|
||||
"xgrammar",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"gguf",
|
||||
|
||||
@ -45,31 +45,23 @@ adding_multimodal_plugin
|
||||
### Base Classes
|
||||
|
||||
```{eval-rst}
|
||||
.. autodata:: vllm.multimodal.NestedTensors
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autodata:: vllm.multimodal.BatchedTensorInputs
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
|
||||
.. automodule:: vllm.multimodal.base
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autodata:: vllm.multimodal.MultiModalDataDict
|
||||
```
|
||||
### Input Classes
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.multimodal.MultiModalKwargs
|
||||
.. automodule:: vllm.multimodal.inputs
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
### Audio Classes
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.multimodal.MultiModalPlugin
|
||||
.. automodule:: vllm.multimodal.audio
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
@ -81,3 +73,11 @@ adding_multimodal_plugin
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
### Video Classes
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: vllm.multimodal.video
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
@ -755,8 +755,7 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal
|
||||
```
|
||||
|
||||
```{note}
|
||||
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo ({code}`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`)
|
||||
and pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
|
||||
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
|
||||
```
|
||||
|
||||
```{note}
|
||||
|
||||
@ -91,5 +91,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 3072
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 765
|
||||
assert embeddings.usage.total_tokens == 765
|
||||
assert embeddings.usage.prompt_tokens == 764
|
||||
assert embeddings.usage.total_tokens == 764
|
||||
|
||||
@ -30,7 +30,7 @@ def get_max_qwen2_vl_image_tokens():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
|
||||
({}, 1225),
|
||||
({}, 16384),
|
||||
({
|
||||
MIN_PIXELS: 64**2,
|
||||
MAX_PIXELS: 512**2
|
||||
|
||||
@ -201,6 +201,7 @@ VLM_TEST_SETTINGS = {
|
||||
vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output,
|
||||
num_logprobs=10,
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[large_gpu_mark(min_gb=48)],
|
||||
),
|
||||
"glm4": VLMTestInfo(
|
||||
models=["THUDM/glm-4v-9b"],
|
||||
@ -212,7 +213,7 @@ VLM_TEST_SETTINGS = {
|
||||
dtype="bfloat16",
|
||||
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
|
||||
patch_hf_runner=model_utils.glm_patch_hf_runner,
|
||||
marks=[large_gpu_mark(min_gb=48)],
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"h2ovl": VLMTestInfo(
|
||||
models = [
|
||||
@ -261,6 +262,7 @@ VLM_TEST_SETTINGS = {
|
||||
dtype="bfloat16",
|
||||
use_tokenizer_eos=True,
|
||||
patch_hf_runner=model_utils.internvl_patch_hf_runner,
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"llava_next": VLMTestInfo(
|
||||
models=["llava-hf/llava-v1.6-mistral-7b-hf"],
|
||||
|
||||
@ -1,12 +1,20 @@
|
||||
from functools import partial
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
|
||||
find_text_matches, find_token_matches,
|
||||
iter_placeholders, iter_token_matches,
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
|
||||
_PlaceholderInfo, find_text_matches,
|
||||
find_token_matches, iter_placeholders,
|
||||
iter_token_matches,
|
||||
replace_text_matches,
|
||||
replace_token_matches)
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import full_groupby
|
||||
|
||||
@ -457,6 +465,7 @@ def test_find_replace_tokens(
|
||||
),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
def test_iter_placeholders(
|
||||
repl_by_key,
|
||||
prompt,
|
||||
@ -475,11 +484,199 @@ def test_iter_placeholders(
|
||||
prompt_repls,
|
||||
prompt,
|
||||
# Effectively match all occurrences in the prompt
|
||||
{key: 3 for key in repl_by_key},
|
||||
))
|
||||
{key: 3
|
||||
for key in repl_by_key},
|
||||
))
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
|
||||
|
||||
def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int):
|
||||
w, h = rng.randint(min_wh, max_wh, size=(2, ))
|
||||
arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8)
|
||||
return Image.fromarray(arr)
|
||||
|
||||
|
||||
def _rand_video(
|
||||
rng: np.random.RandomState,
|
||||
min_frames: int,
|
||||
max_frames: int,
|
||||
min_wh: int,
|
||||
max_wh: int,
|
||||
):
|
||||
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
|
||||
num_frames = rng.randint(min_frames, max_frames)
|
||||
num_frames = (num_frames // 2) * 2
|
||||
|
||||
w, h = rng.randint(min_wh, max_wh, size=(2, ))
|
||||
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def _rand_audio(
|
||||
rng: np.random.RandomState,
|
||||
min_len: int,
|
||||
max_len: int,
|
||||
sr: int,
|
||||
):
|
||||
audio_len = rng.randint(min_len, max_len)
|
||||
return rng.rand(audio_len), sr
|
||||
|
||||
|
||||
def _test_processing_cache_correctness(
|
||||
model_id: str,
|
||||
modalities: set[str],
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
|
||||
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
|
||||
else:
|
||||
hf_overrides = {}
|
||||
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
hf_overrides=hf_overrides,
|
||||
)
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
|
||||
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
ctx = InputProcessingContext(
|
||||
model_config,
|
||||
tokenizer=cached_get_tokenizer(model_config.tokenizer),
|
||||
)
|
||||
# Ensure that it can fit all of the data
|
||||
cache = ProcessingCache(capacity=1 << 30)
|
||||
|
||||
baseline_processor = processor_factory(ctx, cache=None)
|
||||
cached_processor = processor_factory(ctx, cache=cache)
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
|
||||
input_to_hit = {
|
||||
"image": Image.new("RGB", size=(128, 128)),
|
||||
"video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
|
||||
"audio": (np.zeros((512, )), 16000),
|
||||
}
|
||||
input_factory = {
|
||||
"image":
|
||||
partial(_rand_img, rng, min_wh=128, max_wh=256),
|
||||
"video":
|
||||
partial(_rand_video,
|
||||
rng,
|
||||
min_frames=2,
|
||||
max_frames=8,
|
||||
min_wh=128,
|
||||
max_wh=256),
|
||||
"audio":
|
||||
partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
|
||||
}
|
||||
input_max_count = {
|
||||
"image": 3,
|
||||
"video": 3,
|
||||
"audio": 3,
|
||||
}
|
||||
|
||||
for batch_idx in range(num_batches):
|
||||
mm_data = {
|
||||
k:
|
||||
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
|
||||
for _ in range(rng.randint(input_max_count[k]))]
|
||||
for k in modalities
|
||||
}
|
||||
|
||||
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
||||
prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text
|
||||
|
||||
# Drop unnecessary keys and test single -> multi conversion
|
||||
if rng.rand() < simplify_rate:
|
||||
for k in list(mm_data.keys()):
|
||||
if not mm_data[k]:
|
||||
del mm_data[k]
|
||||
elif len(mm_data[k]) == 1:
|
||||
mm_data[k] = mm_data[k][0]
|
||||
|
||||
baseline_result = baseline_processor.apply(
|
||||
prompt,
|
||||
mm_data=mm_data,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
cached_result = cached_processor.apply(
|
||||
prompt,
|
||||
mm_data=mm_data,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert baseline_result == cached_result, (
|
||||
f"Failed ({batch_idx=}, {mm_data=})")
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(("model_id", "modalities"), [
|
||||
("llava-hf/llava-1.5-7b-hf", {"image"}),
|
||||
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}),
|
||||
("mistral-community/pixtral-12b", {"image"}),
|
||||
("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}),
|
||||
("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}),
|
||||
("fixie-ai/ultravox-v0_3", {"audio"}),
|
||||
])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
||||
# yapf: enable
|
||||
def test_processing_cache_correctness(
|
||||
model_id: str,
|
||||
modalities: set[str],
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
_test_processing_cache_correctness(
|
||||
model_id,
|
||||
modalities,
|
||||
hit_rate=hit_rate,
|
||||
num_batches=num_batches,
|
||||
simplify_rate=simplify_rate,
|
||||
)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(("model_id", "modalities"), [
|
||||
("microsoft/Phi-3-vision-128k-instruct", {"image"}),
|
||||
])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
||||
# yapf: enable
|
||||
def test_processing_cache_correctness_phi3v(
|
||||
model_id: str,
|
||||
modalities: set[str],
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
# HACK - this is an attempted workaround for the following bug
|
||||
# https://github.com/huggingface/transformers/issues/34307
|
||||
from transformers import AutoImageProcessor # noqa: F401
|
||||
from transformers import AutoProcessor # noqa: F401
|
||||
|
||||
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
|
||||
|
||||
_test_processing_cache_correctness(
|
||||
model_id,
|
||||
modalities,
|
||||
hit_rate=hit_rate,
|
||||
num_batches=num_batches,
|
||||
simplify_rate=simplify_rate,
|
||||
)
|
||||
|
||||
@ -99,6 +99,9 @@ class InputContext:
|
||||
|
||||
merged_kwargs = {**base_kwargs, **kwargs}
|
||||
|
||||
if isinstance(typ, type):
|
||||
merged_kwargs["processor_cls"] = typ
|
||||
|
||||
hf_processor = cached_get_processor(
|
||||
self.model_config.model,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
@ -132,10 +135,13 @@ class InputProcessingContext(InputContext):
|
||||
def call_hf_processor(
|
||||
self,
|
||||
hf_processor: ProcessorMixin,
|
||||
prompt: str,
|
||||
processor_data: Mapping[str, object],
|
||||
inference_kwargs: Mapping[str, object],
|
||||
data: Mapping[str, object],
|
||||
kwargs: Mapping[str, object] = {},
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Call :code:`hf_processor` on the prompt :code:`data`
|
||||
(text, image, audio...) with configurable options :code:`kwargs`.
|
||||
"""
|
||||
assert callable(hf_processor)
|
||||
|
||||
base_kwargs = self.model_config.mm_processor_kwargs
|
||||
@ -144,21 +150,15 @@ class InputProcessingContext(InputContext):
|
||||
|
||||
merged_kwargs = resolve_mm_processor_kwargs(
|
||||
base_kwargs,
|
||||
inference_kwargs,
|
||||
kwargs,
|
||||
hf_processor,
|
||||
requires_kw_only=False,
|
||||
allow_var_kwargs=True,
|
||||
)
|
||||
|
||||
try:
|
||||
return hf_processor(
|
||||
text=prompt,
|
||||
**processor_data,
|
||||
**merged_kwargs,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return hf_processor(**data, **merged_kwargs, return_tensors="pt")
|
||||
except Exception as exc:
|
||||
data = dict(text=prompt, **processor_data)
|
||||
msg = (f"Failed to apply {type(hf_processor).__name__} "
|
||||
f"on data={data} with kwargs={merged_kwargs}")
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from functools import cached_property
|
||||
from types import MethodType
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
@ -7,7 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
|
||||
PixtralVisionConfig, PretrainedConfig,
|
||||
ProcessorMixin, SiglipVisionConfig)
|
||||
SiglipVisionConfig)
|
||||
from transformers.models.llava import LlavaProcessor
|
||||
from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
@ -21,10 +20,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
|
||||
MultiModalFieldConfig, MultiModalInputsV2,
|
||||
MultiModalKwargs, NestedTensors)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
ProcessorInputs, PromptReplacement,
|
||||
full_groupby_modality)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
||||
@ -116,36 +117,54 @@ def get_max_llava_image_tokens(ctx: InputContext):
|
||||
|
||||
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
|
||||
if getattr(hf_processor, "__is_patched__", False):
|
||||
return # Already patched
|
||||
|
||||
image_processor = hf_processor.image_processor # type: ignore
|
||||
orig_preprocess = image_processor.preprocess
|
||||
|
||||
def preprocess(__self, *args, **kwargs):
|
||||
hf_inputs = orig_preprocess(*args, **kwargs)
|
||||
hf_inputs["is_pixtral"] = torch.tensor(True)
|
||||
return hf_inputs
|
||||
|
||||
image_processor.preprocess = MethodType(preprocess, image_processor)
|
||||
|
||||
hf_processor.__is_patched__ = True # type: ignore
|
||||
|
||||
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
|
||||
hf_processor = self.ctx.get_hf_processor(
|
||||
(LlavaProcessor, PixtralProcessor))
|
||||
return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
|
||||
|
||||
if isinstance(hf_processor, PixtralProcessor):
|
||||
self._patch_pixtral_processor(hf_processor)
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
return hf_processor
|
||||
# NOTE: pixel_values=None for MLlavaProcessor
|
||||
pixel_values = processed_outputs.get("pixel_values")
|
||||
if pixel_values is not None:
|
||||
images = mm_data["images"]
|
||||
assert isinstance(images, list)
|
||||
|
||||
if isinstance(self._get_hf_processor(), PixtralProcessor):
|
||||
# Original output: (1, num_images, C, H, W)
|
||||
# New output: (num_images, C, H, W)
|
||||
assert (isinstance(pixel_values, list)
|
||||
and len(pixel_values) == 1
|
||||
and isinstance(pixel_values[0], list)
|
||||
and len(pixel_values[0]) == len(images))
|
||||
|
||||
processed_outputs["pixel_values"] = pixel_values[0]
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_inputs: BatchFeature,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
image_token_id = hf_config.image_token_index
|
||||
@ -200,7 +219,7 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
) -> ProcessorInputs:
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts["image"]
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
data = dummy_image_for_clip(vision_config, num_images)
|
||||
@ -218,7 +237,6 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=data,
|
||||
mm_processor_kwargs={},
|
||||
)
|
||||
|
||||
|
||||
@ -379,7 +397,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[LlavaImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
@ -390,33 +407,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
assert isinstance(is_pixtral, torch.Tensor)
|
||||
if is_pixtral.any():
|
||||
images = pixel_values
|
||||
|
||||
def flatten_to_3d_tensors(item):
|
||||
if isinstance(item, torch.Tensor):
|
||||
if item.dim() >= 3:
|
||||
return [t for t in item.view(-1, *item.shape[-3:])]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected tensor dimension: {item.dim()}")
|
||||
elif isinstance(item, list):
|
||||
return [
|
||||
t for subitem in item
|
||||
for t in flatten_to_3d_tensors(subitem)
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unexpected type: {type(item)}")
|
||||
|
||||
# Restructure the batched images into a list of lists of images
|
||||
images = flatten_to_3d_tensors(pixel_values)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=images,
|
||||
)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
@ -586,19 +576,71 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
|
||||
def _get_hf_processor(self) -> ProcessorMixin:
|
||||
try:
|
||||
from mantis.models.mllava import MLlavaProcessor
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
"You need to `pip install "
|
||||
"git+https://github.com/TIGER-AI-Lab/Mantis.git` "
|
||||
"to use this model") from exc
|
||||
def _get_hf_processor(self):
|
||||
return self.ctx.get_hf_processor(LlavaProcessor)
|
||||
|
||||
processor = MLlavaProcessor.from_pretrained(
|
||||
self.ctx.model_config.tokenizer)
|
||||
assert isinstance(processor, ProcessorMixin)
|
||||
return processor
|
||||
def apply(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
image_token_id = hf_config.image_token_index
|
||||
max_image_tokens = get_max_llava_image_tokens(self.ctx)
|
||||
|
||||
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
mm_items = self._get_mm_items(mm_data)
|
||||
mm_item_counts = mm_items.get_item_counts()
|
||||
mm_kwargs = result["mm_kwargs"]
|
||||
|
||||
# We reimplement the functionality of MLlavaProcessor from
|
||||
# https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
def get_replacement_mantis(item_idx: int):
|
||||
return "".join([
|
||||
f"(image {item_idx+1}: <Image>", # 7 tokens
|
||||
"<image>" * max_image_tokens,
|
||||
"</Image>)", # 3 tokens
|
||||
])
|
||||
|
||||
mantis_repls = self._bind_prompt_replacements([
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id] * max_image_tokens,
|
||||
replacement=get_replacement_mantis,
|
||||
)
|
||||
])
|
||||
|
||||
prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
|
||||
result["prompt_token_ids"],
|
||||
mantis_repls,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
unbound_orig_repls = self._get_prompt_replacements(
|
||||
mm_items,
|
||||
hf_processor_mm_kwargs,
|
||||
mm_kwargs,
|
||||
)
|
||||
orig_repls = self._bind_prompt_replacements(unbound_orig_repls)
|
||||
|
||||
all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
|
||||
mm_item_counts)
|
||||
assert len(all_placeholders) == mm_item_counts.get("image", 0)
|
||||
|
||||
mm_placeholders = {
|
||||
modality: [item.to_range() for item in items]
|
||||
for modality, items in full_groupby_modality(all_placeholders)
|
||||
}
|
||||
|
||||
return MultiModalInputsV2(
|
||||
type="multimodal",
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_ids,
|
||||
mm_kwargs=mm_kwargs,
|
||||
mm_placeholders=mm_placeholders,
|
||||
)
|
||||
|
||||
|
||||
# To use this model, please use
|
||||
|
||||
@ -12,9 +12,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -32,10 +32,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
|
||||
MultiModalFieldConfig, MultiModalInputsV2,
|
||||
MultiModalKwargs, NestedTensors,
|
||||
PlaceholderRange)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
ProcessorInputs, PromptReplacement,
|
||||
_BoundPromptReplacement,
|
||||
_PlaceholderInfo)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
@ -306,11 +310,11 @@ def get_max_phi3v_image_tokens(
|
||||
*,
|
||||
num_crops: Optional[int] = None,
|
||||
) -> int:
|
||||
mm_processor_kwargs = {}
|
||||
hf_processor_mm_kwargs = {}
|
||||
if num_crops:
|
||||
mm_processor_kwargs["num_crops"] = num_crops
|
||||
hf_processor_mm_kwargs["num_crops"] = num_crops
|
||||
|
||||
processor = ctx.get_hf_processor(**mm_processor_kwargs)
|
||||
processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
return processor.calc_num_image_tokens_from_image_size(
|
||||
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
@ -331,39 +335,50 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
hf_processor: ProcessorMixin,
|
||||
prompt: str,
|
||||
processor_data: Mapping[str, object],
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
hf_processor,
|
||||
prompt=prompt,
|
||||
processor_data=processor_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
input_ids = processed_outputs["input_ids"]
|
||||
assert isinstance(input_ids, torch.Tensor)
|
||||
|
||||
# Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
|
||||
# which will cause OverflowError when decoding the prompt_ids.
|
||||
# Therefore, we need to do an early replacement here
|
||||
token_ids = processed_outputs['input_ids']
|
||||
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
|
||||
processed_outputs['input_ids'] = token_ids
|
||||
input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_sizes=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_inputs: BatchFeature,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
|
||||
image_processor = hf_processor.image_processor # type: ignore
|
||||
|
||||
mm_config = self.ctx.get_mm_config()
|
||||
max_images = mm_config.limit_per_prompt.get("image", 1)
|
||||
tokenizer = self._get_tokenizer()
|
||||
bos_token_id = tokenizer.bos_token_id
|
||||
assert isinstance(bos_token_id, int)
|
||||
|
||||
def get_replacement_phi3v(item_idx: int):
|
||||
image_size = mm_items.get_image_size(item_idx)
|
||||
@ -372,21 +387,44 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
height=image_size.height,
|
||||
)
|
||||
|
||||
return [_IMAGE_TOKEN_ID] * num_tokens
|
||||
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=image_token,
|
||||
replacement=get_replacement_phi3v,
|
||||
) for image_token in image_tokens[:max_images]
|
||||
) for image_token in image_tokens[:len(mm_items.images)]
|
||||
]
|
||||
|
||||
def _apply_prompt_replacements(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
prompt_repls: Sequence[_BoundPromptReplacement],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
|
||||
token_ids, text, placeholders = super()._apply_prompt_replacements(
|
||||
token_ids=token_ids,
|
||||
prompt_repls=prompt_repls,
|
||||
mm_item_counts=mm_item_counts,
|
||||
)
|
||||
|
||||
# Keep the behavior in line with HF processor
|
||||
if text.startswith("<s> <|image|>"):
|
||||
text = text.replace("<s> <|image|>", "<s><|image|>", 1)
|
||||
token_ids = [token_ids[0], *token_ids[2:]]
|
||||
placeholders = [
|
||||
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
|
||||
for p in placeholders
|
||||
]
|
||||
|
||||
return token_ids, text, placeholders
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts["image"]
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
data = dummy_image_for_clip(
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
@ -401,9 +439,28 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
return ProcessorInputs(
|
||||
prompt_text="".join(image_tokens[:num_images]),
|
||||
mm_data=data,
|
||||
mm_processor_kwargs={},
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
# Only <|image|> tokens should be considered as placeholders,
|
||||
# so we ignore the trailing bos_token_id
|
||||
result["mm_placeholders"] = {
|
||||
modality: [
|
||||
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
|
||||
for p in ps
|
||||
]
|
||||
for modality, ps in result["mm_placeholders"].items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
|
||||
|
||||
@ -225,7 +225,7 @@ class VisualAttentionBlock(nn.Module):
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@ -266,7 +266,7 @@ class TransformerBlock(nn.Module):
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -26,7 +26,7 @@ from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature, ProcessorMixin
|
||||
from transformers import BatchFeature
|
||||
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
|
||||
Qwen2AudioEncoder,
|
||||
Qwen2AudioProcessor)
|
||||
@ -38,10 +38,10 @@ from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
|
||||
MultiModalKwargs, NestedTensors)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
ProcessorInputs, PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
@ -73,7 +73,7 @@ class Qwen2AudioMultiModalProjector(nn.Module):
|
||||
|
||||
|
||||
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
|
||||
feat_lengths = (input_lengths - 1) // 2 + 1
|
||||
output_lengths = (feat_lengths - 2) // 2 + 1
|
||||
return feat_lengths, output_lengths
|
||||
@ -88,13 +88,18 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
|
||||
|
||||
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def _get_hf_processor(self) -> Qwen2AudioProcessor:
|
||||
def _get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
# Ignored in initialization
|
||||
sampling_rate: Optional[int] = None,
|
||||
) -> Qwen2AudioProcessor:
|
||||
return self.ctx.get_hf_processor(Qwen2AudioProcessor)
|
||||
|
||||
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
|
||||
return self._get_hf_processor().feature_extractor # type: ignore
|
||||
|
||||
def _get_processor_data(
|
||||
def _get_hf_mm_data(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
@ -102,50 +107,61 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
mm_items.resample_audios(feature_extractor.sampling_rate)
|
||||
|
||||
return super()._get_processor_data(mm_items)
|
||||
return super()._get_hf_mm_data(mm_items)
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
hf_processor: ProcessorMixin,
|
||||
prompt: str,
|
||||
processor_data: Mapping[str, object],
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processor_data = dict(processor_data)
|
||||
audios = processor_data.pop("audios", [])
|
||||
mm_data = dict(mm_data)
|
||||
audios = mm_data.pop("audios", [])
|
||||
|
||||
if audios:
|
||||
processor_data["audios"] = audios
|
||||
mm_data["audios"] = audios
|
||||
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
mm_processor_kwargs = dict(
|
||||
**mm_processor_kwargs,
|
||||
mm_kwargs = dict(
|
||||
**mm_kwargs,
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
else:
|
||||
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
|
||||
pass
|
||||
|
||||
return super()._call_hf_processor(
|
||||
hf_processor,
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
processor_data=processor_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """
    Describe how each multi-modal output key maps onto audio items.

    Both ``input_features`` and ``feature_attention_mask`` are declared as
    batched fields of the ``"audio"`` modality. Neither argument is
    inspected by this implementation.
    """
    audio_keys = ("input_features", "feature_attention_mask")

    return {key: MultiModalFieldConfig.batched("audio") for key in audio_keys}
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_inputs: BatchFeature,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
|
||||
placeholder = hf_config.audio_token_index
|
||||
|
||||
feature_attention_mask = hf_inputs.get("feature_attention_mask")
|
||||
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
|
||||
if feature_attention_mask is None:
|
||||
audio_output_lengths = []
|
||||
else:
|
||||
assert isinstance(feature_attention_mask, torch.Tensor)
|
||||
_, audio_output_lengths = _get_feat_extract_output_lengths(
|
||||
feature_attention_mask.sum(-1))
|
||||
|
||||
@ -168,14 +184,13 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
sampling_rate = feature_extractor.sampling_rate
|
||||
audio_len = feature_extractor.chunk_length * sampling_rate
|
||||
|
||||
audio_count = mm_counts["audio"]
|
||||
audio_count = mm_counts.get("audio", 0)
|
||||
audio = np.zeros(audio_len)
|
||||
data = {"audio": [audio] * audio_count}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="<|AUDIO|>" * audio_count,
|
||||
mm_data=data,
|
||||
mm_processor_kwargs={},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -22,9 +22,10 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||
from functools import cached_property, partial
|
||||
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
Tuple, Type, TypedDict, Union)
|
||||
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
|
||||
Set, Tuple, Type, TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -54,10 +55,11 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
|
||||
MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
ProcessorInputs, PromptReplacement)
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
@ -229,9 +231,9 @@ class Qwen2VisionAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: Optional[int] = None,
|
||||
num_heads: Optional[int] = None,
|
||||
projection_size: Optional[int] = None,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
projection_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
@ -264,7 +266,7 @@ class Qwen2VisionAttention(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor = None,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
@ -347,7 +349,7 @@ class Qwen2VisionBlock(nn.Module):
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
act_layer: Type[nn.Module] = QuickGELU,
|
||||
norm_layer: Type[nn.Module] = None,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
@ -384,7 +386,7 @@ class Qwen2VisionPatchEmbed(nn.Module):
|
||||
self,
|
||||
patch_size: int = 14,
|
||||
temporal_patch_size: int = 2,
|
||||
in_chans: int = 3,
|
||||
in_channels: int = 3,
|
||||
embed_dim: int = 1152,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -392,8 +394,8 @@ class Qwen2VisionPatchEmbed(nn.Module):
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
||||
self.proj = nn.Conv3d(in_chans,
|
||||
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
||||
self.proj = nn.Conv3d(in_channels,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
@ -413,7 +415,7 @@ class Qwen2VisionPatchMerger(nn.Module):
|
||||
self,
|
||||
d_model: int,
|
||||
context_dim: int,
|
||||
norm_layer: Type[nn.Module] = None,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
spatial_merge_size: int = 2,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
@ -489,15 +491,15 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
patch_size: int = vision_config.patch_size
|
||||
temporal_patch_size: int = vision_config.temporal_patch_size
|
||||
spatial_merge_size: int = vision_config.spatial_merge_size
|
||||
in_chans: int = vision_config.in_chans
|
||||
hidden_size: int = vision_config.hidden_size
|
||||
embed_dim: int = vision_config.embed_dim
|
||||
depth: int = vision_config.depth
|
||||
num_heads: int = vision_config.num_heads
|
||||
mlp_ratio: float = vision_config.mlp_ratio
|
||||
patch_size = vision_config.patch_size
|
||||
temporal_patch_size = vision_config.temporal_patch_size
|
||||
spatial_merge_size = vision_config.spatial_merge_size
|
||||
in_channels = vision_config.in_channels
|
||||
hidden_size = vision_config.hidden_size
|
||||
embed_dim = vision_config.embed_dim
|
||||
depth = vision_config.depth
|
||||
num_heads = vision_config.num_heads
|
||||
mlp_ratio = vision_config.mlp_ratio
|
||||
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.num_heads = num_heads
|
||||
@ -506,7 +508,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
self.patch_embed = Qwen2VisionPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
in_chans=in_chans,
|
||||
in_channels=in_channels,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
@ -733,8 +735,12 @@ class Qwen2VLMultiModalDataItems(MultiModalDataItems):
|
||||
if k == "video":
|
||||
# Special case since even a single item can be a list
|
||||
multi_data[k] = ( # type: ignore[index]
|
||||
v if (isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
|
||||
or is_list_of(v, list)) else [v]
|
||||
v if (
|
||||
isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
|
||||
or is_list_of(v, list)
|
||||
or isinstance(v[0], (np.ndarray, torch.Tensor))
|
||||
and v[0].ndim == 4
|
||||
) else [v]
|
||||
)
|
||||
elif k in ("image", "audio"):
|
||||
multi_data[k] = ( # type: ignore[index]
|
||||
@ -754,6 +760,12 @@ class Qwen2VLMultiModalDataItems(MultiModalDataItems):
|
||||
for m, items in self.items()
|
||||
}
|
||||
|
||||
def has_embedding_inputs(self) -> bool:
    """
    Return ``True`` if, for any modality, the stored items are a ``dict``
    or contain at least one :class:`torch.Tensor`.
    """
    for items in self.values():
        if isinstance(items, dict):
            return True

        for item in items:
            if isinstance(item, torch.Tensor):
                return True

    return False
|
||||
|
||||
|
||||
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
@ -784,7 +796,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
return hf_processor
|
||||
|
||||
def _get_processor_data(
|
||||
def _get_hf_mm_data(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
@ -805,7 +817,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
and v[0].ndim == 2):
|
||||
# Pass through embedding inputs (multi)
|
||||
passthrough_data[f"{k}_embeds"] = v
|
||||
else:
|
||||
elif len(v) > 0:
|
||||
# Map keys to plural form, e.g.: image -> images
|
||||
processor_data[f"{k}s"] = v
|
||||
else:
|
||||
@ -816,8 +828,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_inputs: BatchFeature,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_processor = _get_image_processor(hf_processor)
|
||||
@ -831,7 +843,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
merge_length = image_processor.merge_size**2
|
||||
|
||||
def get_replacement_qwen2vl(item_idx: int, modality: str):
|
||||
grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx]
|
||||
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
|
||||
assert isinstance(grid_thw, torch.Tensor)
|
||||
|
||||
num_tokens = grid_thw.prod() // merge_length
|
||||
return placeholder[modality] * num_tokens
|
||||
|
||||
@ -844,11 +858,40 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
) for modality in ("image", "video")
|
||||
]
|
||||
|
||||
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """
    Describe how each multi-modal output key maps onto image/video items.

    The ``*_grid_thw`` keys are batched (one row per item). The pixel-value
    and embedding keys are flat: item ``i`` occupies a contiguous slice along
    the first dimension whose length is the product of the entries in row
    ``i`` of the corresponding ``*_grid_thw`` tensor.
    """

    def item_slices(grid_key: str) -> list[slice]:
        # A missing grid is treated as "no items" (an empty (0, 3) tensor).
        grid_thw = hf_inputs.get(grid_key, torch.empty((0, 3)))

        # Cumulative per-item sizes give the slice boundaries.
        bounds = [0] + grid_thw.prod(-1).cumsum_(0).tolist()

        return [
            slice(start, stop) for start, stop in zip(bounds, bounds[1:])
        ]

    image_slices = item_slices("image_grid_thw")
    video_slices = item_slices("video_grid_thw")

    return {
        "pixel_values": MultiModalFieldConfig.flat("image", image_slices),
        "image_embeds": MultiModalFieldConfig.flat("image", image_slices),
        "image_grid_thw": MultiModalFieldConfig.batched("image"),
        "pixel_values_videos": MultiModalFieldConfig.flat(
            "video", video_slices),
        "video_embeds": MultiModalFieldConfig.flat("video", video_slices),
        "video_grid_thw": MultiModalFieldConfig.batched("video"),
    }
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts["image"]
|
||||
num_images = mm_counts.get("image", 0)
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_token: str = hf_processor.image_token
|
||||
image_processor = _get_image_processor(hf_processor)
|
||||
@ -869,7 +912,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=data,
|
||||
mm_processor_kwargs={},
|
||||
)
|
||||
|
||||
|
||||
@ -950,9 +992,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return None
|
||||
return quant_config
|
||||
|
||||
def _validate_and_reshape_mm_tensor(self,
|
||||
mm_input: Union[torch.Tensor,
|
||||
List[torch.Tensor]],
|
||||
def _validate_and_reshape_mm_tensor(self, mm_input: object,
|
||||
name: str) -> torch.Tensor:
|
||||
if not isinstance(mm_input, (torch.Tensor, list)):
|
||||
raise ValueError(f"Incorrect type of {name}. "
|
||||
@ -962,7 +1002,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return mm_input
|
||||
if mm_input.ndim != 3:
|
||||
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
||||
f"Got ndim: {mm_input.ndim}")
|
||||
f"Got ndim: {mm_input.ndim} "
|
||||
f"(shape={mm_input.shape})")
|
||||
return torch.concat(list(mm_input))
|
||||
else:
|
||||
return torch.concat(mm_input)
|
||||
|
||||
@ -23,10 +23,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
|
||||
MultiModalKwargs, NestedTensors)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
ProcessorInputs, PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||
from vllm.utils import is_list_of
|
||||
@ -72,11 +73,19 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
|
||||
|
||||
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def _get_hf_processor(
    self,
    *,
    # Ignored in initialization
    sampling_rate: Optional[int] = None,
) -> ProcessorMixin:
    """
    Return the HuggingFace processor obtained from ``self.ctx``.

    ``sampling_rate`` is accepted as a keyword-only argument but is not
    used by this method.
    """
    return self.ctx.get_hf_processor()
|
||||
|
||||
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
|
||||
hf_processor = self._get_hf_processor()
|
||||
return hf_processor.audio_processor.feature_extractor # type: ignore
|
||||
|
||||
def _get_processor_data(
|
||||
def _get_hf_mm_data(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
@ -84,33 +93,41 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
mm_items.resample_audios(feature_extractor.sampling_rate)
|
||||
|
||||
return super()._get_processor_data(mm_items)
|
||||
return super()._get_hf_mm_data(mm_items)
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
hf_processor: ProcessorMixin,
|
||||
prompt: str,
|
||||
processor_data: Mapping[str, object],
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processor_data = dict(processor_data)
|
||||
audios = processor_data.pop("audios", [])
|
||||
# Text-only input not supported in composite processor
|
||||
if not mm_data:
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
prompt_ids = tokenizer.encode(
|
||||
prompt,
|
||||
add_special_tokens=False, # type: ignore
|
||||
)
|
||||
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||
|
||||
mm_data = dict(mm_data)
|
||||
audios = mm_data.pop("audios", [])
|
||||
|
||||
if not audios:
|
||||
return super()._call_hf_processor(
|
||||
hf_processor,
|
||||
prompt=prompt,
|
||||
processor_data=processor_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
mm_processor_kwargs = dict(
|
||||
**mm_processor_kwargs,
|
||||
mm_kwargs = dict(
|
||||
**mm_kwargs,
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
|
||||
# Already resampled by _get_processor_data
|
||||
# Already resampled by _get_hf_mm_data
|
||||
assert is_list_of(audios, np.ndarray)
|
||||
|
||||
# Ultravox processor doesn't support multiple inputs,
|
||||
@ -119,13 +136,12 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
shared_outputs = {}
|
||||
for audio in audios:
|
||||
# NOTE: Ultravox processor accepts "audio" instead of "audios"
|
||||
item_processor_data = dict(**processor_data, audio=audio)
|
||||
item_processor_data = dict(**mm_data, audio=audio)
|
||||
|
||||
item_outputs = super()._call_hf_processor(
|
||||
hf_processor,
|
||||
prompt=prompt,
|
||||
processor_data=item_processor_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
mm_data=item_processor_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
audio_features.append(item_outputs.pop("audio_values")[0])
|
||||
@ -139,17 +155,28 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
)
|
||||
return BatchFeature(combined_outputs)
|
||||
|
||||
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """
    Describe how each multi-modal output key maps onto audio items.

    ``audio_features``, ``audio_token_len`` and ``audio_embeds`` are all
    declared as batched fields of the ``"audio"`` modality. Neither argument
    is inspected by this implementation.
    """
    audio_keys = ("audio_features", "audio_token_len", "audio_embeds")

    return {key: MultiModalFieldConfig.batched("audio") for key in audio_keys}
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_inputs: BatchFeature,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self._get_hf_processor()
|
||||
placeholder = hf_processor.audio_token_replacement # type: ignore
|
||||
|
||||
def get_replacement_ultravox(item_idx: int):
|
||||
audio_token_len = hf_inputs["audio_token_len"][item_idx]
|
||||
audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
|
||||
return placeholder * audio_token_len
|
||||
|
||||
return [
|
||||
@ -168,14 +195,13 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
sampling_rate = feature_extractor.sampling_rate
|
||||
audio_len = feature_extractor.chunk_length * sampling_rate
|
||||
|
||||
audio_count = mm_counts["audio"]
|
||||
audio_count = mm_counts.get("audio", 0)
|
||||
audio = np.zeros(audio_len)
|
||||
data = {"audio": [audio] * audio_count}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="<|audio|>" * audio_count,
|
||||
mm_data=data,
|
||||
mm_processor_kwargs={},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -297,35 +297,37 @@ class MultiModalPlaceholderMap:
|
||||
``MultiModalPlaceholderMap`` that relates the multi-modal embedding
|
||||
vectors to their corresponding placeholders.
|
||||
|
||||
Consider the following scenarios:
|
||||
Examples:
|
||||
|
||||
Prompt: |AAAA BBBB What's in these images?|
|
||||
Positions: |.................................|
|
||||
.. code-block::
|
||||
|
||||
images = [A, B]
|
||||
src_ranges = [(0, 4), (4, 8)]
|
||||
dest_ranges = [(0, 4), (5, 9)]
|
||||
Prompt: |AAAA BBBB What's in these images?|
|
||||
Positions: |.................................|
|
||||
|
||||
Prompt: |AAAA BBBB What's in these images?|
|
||||
Positions: | ..... |
|
||||
images = [A, B]
|
||||
src_ranges = [(0, 4), (4, 8)]
|
||||
dest_ranges = [(0, 4), (5, 9)]
|
||||
|
||||
images = [A, B]
|
||||
src_ranges = [(2, 4), (4, 6)]
|
||||
dest_ranges = [(0, 2), (3, 5)]
|
||||
Prompt: |AAAA BBBB What's in these images?|
|
||||
Positions: | ..... |
|
||||
|
||||
Prompt: |AAAA BBBB What's in these images?|
|
||||
Positions: | ......... |
|
||||
images = [A, B]
|
||||
src_ranges = [(2, 4), (4, 6)]
|
||||
dest_ranges = [(0, 2), (3, 5)]
|
||||
|
||||
images = [B]
|
||||
src_ranges = [(0, 4)]
|
||||
dest_ranges = [(0, 4)]
|
||||
Prompt: |AAAA BBBB What's in these images?|
|
||||
Positions: | ......... |
|
||||
|
||||
Prompt: |AAAA BBBB What's in these images?|
|
||||
Positions: | .......................|
|
||||
images = [B]
|
||||
src_ranges = [(0, 4)]
|
||||
dest_ranges = [(0, 4)]
|
||||
|
||||
images = []
|
||||
src_ranges = []
|
||||
dest_ranges = []
|
||||
Prompt: |AAAA BBBB What's in these images?|
|
||||
Positions: | .......................|
|
||||
|
||||
images = []
|
||||
src_ranges = []
|
||||
dest_ranges = []
|
||||
"""
|
||||
seq_mm_data = seq_group.multi_modal_data
|
||||
seq_mm_placeholders = seq_group.multi_modal_placeholders
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple,
|
||||
TypedDict, TypeVar, Union, cast, final)
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, Union, cast,
|
||||
final)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL.Image import Image
|
||||
from typing_extensions import NotRequired, TypeAlias
|
||||
from transformers import BatchFeature
|
||||
from typing_extensions import NotRequired, TypeAlias, assert_never
|
||||
|
||||
from vllm.utils import JSONTree, is_list_of, json_map_leaves
|
||||
|
||||
@ -44,7 +48,7 @@ item, which can be passed to a HuggingFace :code:`AudioProcessor`.
|
||||
"""
|
||||
# yapf: enable
|
||||
|
||||
MultiModalData: TypeAlias = Union[_T, List[_T]]
|
||||
MultiModalData: TypeAlias = Union[_T, list[_T]]
|
||||
"""
|
||||
Either a single data item, or a list of data items.
|
||||
|
||||
@ -79,13 +83,135 @@ Note:
|
||||
"""
|
||||
|
||||
|
||||
class ImageSize(NamedTuple):
    """The size of an image as a ``(width, height)`` named tuple."""
    width: int
    height: int
|
||||
|
||||
|
||||
class MultiModalDataItems(UserDict[str, list[Any]]):
    """
    A normalized form of :class:`MultiModalDataDict` in which the value
    stored under each modality key is a collection of items rather than
    possibly a single item.
    """

    @staticmethod
    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
        """
        Build a :class:`MultiModalDataItems` from a
        :class:`MultiModalDataDict`, wrapping lone items in a list.
        """
        normalized = MultiModalDataItems()

        for modality, value in data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if modality == "video":
                # Special case since even a single item can be a list
                is_multi = (
                    isinstance(value, torch.Tensor)
                    or is_list_of(value, list)
                    or (isinstance(value[0], (np.ndarray, torch.Tensor))
                        and value[0].ndim == 4))
            elif modality in ("image", "audio"):
                is_multi = isinstance(value, (torch.Tensor, list))
            else:
                is_multi = isinstance(value, list)

            normalized[modality] = (  # type: ignore[assignment]
                value if is_multi else [value])

        return normalized

    # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
    # `self.images` doesn't update this dictionary, which may be confusing
    # We annotate the getter methods as `Sequence` to prevent others from
    # trying to update the list in this way
    @property
    def images(self) -> Sequence[ImageItem]:
        """The items stored under ``"image"`` (empty if the key is absent)."""
        return self.get("image", [])

    @property
    def videos(self) -> Sequence[VideoItem]:
        """The items stored under ``"video"`` (empty if the key is absent)."""
        return self.get("video", [])

    @property
    def audios(self) -> Sequence[AudioItem]:
        """The items stored under ``"audio"`` (empty if the key is absent)."""
        return self.get("audio", [])

    def get_item_counts(self) -> Mapping[str, int]:
        """Return the number of items stored for each modality."""
        counts = {}
        for modality, items in self.items():
            counts[modality] = len(items)

        return counts

    def has_embedding_inputs(self) -> bool:
        """Return ``True`` if any stored item is a :class:`torch.Tensor`."""
        for items in self.values():
            for item in items:
                if isinstance(item, torch.Tensor):
                    return True

        return False

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Return the size of the image at ``item_idx``.

        For array/tensor images, the shape is unpacked as three dimensions
        with the height and width taken from the last two.
        """
        image = self.images[item_idx]

        if isinstance(image, (np.ndarray, torch.Tensor)):
            _, height, width = image.shape
            return ImageSize(width, height)
        if isinstance(image, Image):
            return ImageSize(*image.size)

        assert_never(image)

    def get_audio_with_sr(
        self,
        item_idx: int,
        *,
        default_sr: float,
    ) -> tuple[np.ndarray, float]:
        """
        Return the audio at ``item_idx`` together with a sampling rate.

        A tuple item is returned as-is; list and NumPy array items are paired
        with ``default_sr`` (lists are converted to NumPy arrays first).
        """
        audio = self.audios[item_idx]

        if isinstance(audio, np.ndarray):
            return audio, default_sr
        if isinstance(audio, list):
            return np.array(audio), default_sr
        if isinstance(audio, tuple):
            return audio

        assert_never(audio)

    def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
        """
        If :code:`drop_sr=True`, the audio items in this dictionary are updated
        to be NumPy arrays which implicitly means that their sampling rate is
        the same as the model's expected sampling rate; otherwise, they remain
        as :code:`(audio, new_sr)` tuples.
        """
        # Avoid circular import
        from .audio import resample_audio

        if not self.audios:
            return

        def resample_item(item_idx: int):
            # Items without an explicit rate are taken to be at `new_sr`.
            audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr)
            resampled = resample_audio(audio, orig_sr=sr, target_sr=new_sr)

            return resampled if drop_sr else (resampled, new_sr)

        num_audios = len(self.audios)
        self["audio"] = [resample_item(idx) for idx in range(num_audios)]
|
||||
|
||||
|
||||
class PlaceholderRange(TypedDict):
|
||||
"""
|
||||
Placeholder location information for multi-modal data.
|
||||
|
||||
For example:
|
||||
Prompt: AAAA BBBB What is in these images?
|
||||
Example:
|
||||
|
||||
Prompt: :code:`AAAA BBBB What is in these images?`
|
||||
|
||||
Images A and B will have:
|
||||
|
||||
.. code-block::
|
||||
|
||||
A: { "offset": 0, "length": 4 }
|
||||
B: { "offset": 5, "length": 4 }
|
||||
"""
|
||||
@ -97,25 +223,256 @@ class PlaceholderRange(TypedDict):
|
||||
"""The length of the placeholder."""
|
||||
|
||||
|
||||
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor,
|
||||
Tuple[torch.Tensor, ...]]
|
||||
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
|
||||
tuple[torch.Tensor, ...]]
|
||||
"""
|
||||
Uses a list instead of a tensor if the dimensions of each element do not match.
|
||||
"""
|
||||
|
||||
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
|
||||
|
||||
def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between :data:`NestedTensors` objects.

    Two tensors are equal when they have the same shape and all of their
    elements compare equal. Two lists (or two tuples) are equal when they
    have the same length and their elements are pairwise equal according to
    this function. A tensor never equals a container, and a list never
    equals a tuple.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)):
            return False

        # Compare shapes first: ``a == b`` broadcasts, so tensors of shape
        # (1,) and (2,) could otherwise compare equal, and tensors with
        # incompatible shapes would raise instead of returning False.
        return a.shape == b.shape and bool((a == b).all().item())

    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            # ``zip`` stops at the shorter input, so the lengths must be
            # checked explicitly to avoid treating a prefix as equal.
            return (isinstance(a, seq_type) and isinstance(b, seq_type)
                    and len(a) == len(b)
                    and all(
                        nested_tensors_equal(a_, b_)
                        for a_, b_ in zip(a, b)))

    # Both a and b are scalars
    return a == b
|
||||
|
||||
|
||||
BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
|
||||
"""
|
||||
A dictionary containing nested tensors which have been batched via
|
||||
:meth:`MultiModalKwargs.batch`.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MultiModalFieldItem:
    """
    Contains metadata and data in :class:`MultiModalKwargs`
    corresponding to a data item in :class:`MultiModalDataItems`.
    """
    # The field definition that produced this item.
    field: "BaseMultiModalField"
    # The (possibly nested) tensor data belonging to this item.
    data: NestedTensors

    def __eq__(self, other: object) -> bool:
        """
        Two items are equal when they are of the same class, their fields are
        equal, and their data match under :func:`nested_tensors_equal`.
        """
        if isinstance(other, self.__class__):
            same_field = self.field == other.field
            return same_field and nested_tensors_equal(self.data, other.data)

        return False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """Abstract base class for a field in :class:`MultiModalKwargs`."""
    key: str
    modality: str

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """Combine the data of several items into a single value."""
        raise NotImplementedError

    def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
        """Wrap ``data`` in a :class:`MultiModalFieldItem` tied to this field."""
        return MultiModalFieldItem(field=self, data=data)

    def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
        """
        Merge multiple instances of :class:`MultiModalFieldItem` together.

        Raises:
            ValueError: If the items in ``batch`` do not all share the same
                field.
        """
        fields = []
        payloads = []
        for item in batch:
            fields.append(item.field)
            payloads.append(item.data)

        if len(set(fields)) > 1:
            raise ValueError(f"Cannot merge different {fields=}")

        return self._build_item(self._reduce_data(payloads))
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an item is obtained by
    directly indexing into the first dimension of the underlying data.
    """

    def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
        """Create one item per element obtained by iterating over ``batch``."""
        items = []
        for element in batch:
            items.append(self._build_item(element))

        return items

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Stack the entries into one tensor when they are all tensors of the
        same shape; otherwise return ``batch`` unchanged.
        """
        if len(batch) == 0:
            return batch
        if not is_list_of(batch, torch.Tensor, check="all"):
            return batch

        expected_shape = batch[0].shape
        if any(tensor.shape != expected_shape for tensor in batch):
            return batch

        return torch.stack(batch)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an item is obtained by
    slicing along the first dimension of the underlying data.
    """

    def build_items(
        self,
        batch: NestedTensors,
        slices: Sequence[slice],
    ) -> list[MultiModalFieldItem]:
        """Create one item per entry of ``slices``, each holding ``batch[slice]``."""
        items = []
        for slice_ in slices:
            items.append(self._build_item(batch[slice_]))
        return items

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Concatenate the items along the first dimension when they are all
        tensors whose trailing dimensions agree; otherwise flatten the
        items into a single list of their elements.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            trailing_dims = batch[0].shape[1:]
            shapes_agree = all(item.shape[1:] == trailing_dims
                               for item in batch)
            if shapes_agree:
                return torch.concat(batch)

        flattened = []
        for item in batch:
            for elem in item:
                flattened.append(elem)
        return flattened
|
||||
|
||||
|
||||
class MultiModalFieldConfig:
    """
    Recipe for turning one batched keyword argument into per-item
    :class:`MultiModalFieldItem` objects.

    It stores the field class to instantiate, the modality, and any extra
    keyword arguments forwarded to that field's ``build_items``.
    """

    @staticmethod
    def batched(modality: str):
        """Return a config backed by :class:`MultiModalBatchedField`."""
        return MultiModalFieldConfig(
            field_cls=MultiModalBatchedField,
            modality=modality,
        )

    @staticmethod
    def flat(modality: str, slices: Sequence[slice]):
        """
        Return a config backed by :class:`MultiModalFlatField`, where
        ``slices`` gives the slice of the batch belonging to each item.
        """
        return MultiModalFieldConfig(
            field_cls=MultiModalFlatField,
            modality=modality,
            slices=slices,
        )

    def __init__(
        self,
        field_cls: type[BaseMultiModalField],
        modality: str,
        **field_config: Any,
    ) -> None:
        super().__init__()

        self._field_cls = field_cls
        self._modality = modality
        # Extra keyword arguments passed through to ``field.build_items``.
        self._field_config = field_config

    def build_items(
        self,
        key: str,
        batch: NestedTensors,
    ) -> list[MultiModalFieldItem]:
        """Instantiate the field for ``key`` and split ``batch`` into items."""
        field_cls = self._field_cls
        bound_field = field_cls(key=key, modality=self._modality)
        extra_kwargs = self._field_config

        return bound_field.build_items(batch, **extra_kwargs)  # type: ignore
|
||||
|
||||
|
||||
class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
"""
|
||||
A dictionary that represents the keyword arguments to
|
||||
:meth:`~torch.nn.Module.forward`.
|
||||
|
||||
The metadata :code:`items_by_key` defines how to split batched keyword
|
||||
arguments corresponding to each data item in :class:`MultiModalDataItems`:
|
||||
|
||||
- For a keyword argument, we can access the :code:`i` th item in the batch
|
||||
via :code:`items_by_key[key][i]`.
|
||||
- We can gather the keyword arguments belonging to a modality by finding
|
||||
the keys with items that belong to that modality, then accessing
|
||||
the :code:`i` th item in the batch for each such key.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Different keys may belong to different modalities
|
||||
items_by_key={
|
||||
"pixel_values": [a, b, c, d], # "image" modality
|
||||
"image_grid_thw": [e, f, g, h], # "image" modality
|
||||
"pixel_values_video": [h, i, j], # "video" modality
|
||||
"video_grid_thw": [k, l, m], # "video" modality
|
||||
}
|
||||
|
||||
- The keyword arguments belonging to the first image are
|
||||
:code:`{"pixel_values": a, "image_grid_thw": e}`.
|
||||
- The keyword arguments belonging to the second video are
|
||||
:code:`{"pixel_values_video": i, "video_grid_thw": l}`.
|
||||
"""
|
||||
|
||||
@staticmethod
def from_hf_inputs(
    hf_inputs: BatchFeature,
    config_by_key: Mapping[str, MultiModalFieldConfig],
    *,
    enable_sanity_checks: bool = False,
):
    """
    Build a :class:`MultiModalKwargs` from HF processor outputs.

    Each key in ``config_by_key`` that is present in ``hf_inputs`` is split
    into per-item fields using its config.
    """
    # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
    # We assume that those fields are not used in vLLM
    items_by_key = {}
    for key, config in config_by_key.items():
        batch = hf_inputs.get(key)
        if batch is None:
            continue

        items_by_key[key] = config.build_items(key, batch)

    return MultiModalKwargs.from_items_by_key(
        items_by_key,
        enable_sanity_checks=enable_sanity_checks,
    )
|
||||
|
||||
@staticmethod
def from_items_by_key(
    items_by_key: Mapping[str, list[MultiModalFieldItem]],
    *,
    enable_sanity_checks: bool = False,
) -> "MultiModalKwargs":
    """
    Build a :class:`MultiModalKwargs` from per-key lists of field items.

    For every key, the items are merged via the ``reduce`` method of the
    first item's field, and the merged data becomes the value stored under
    that key. The original ``items_by_key`` mapping is kept as metadata.

    Keys whose item list is empty contribute no data entry: there is no
    item to take the field from, and indexing ``items[0]`` would raise
    :class:`IndexError`.
    """
    data = {
        key: items[0].field.reduce(items).data
        for key, items in items_by_key.items() if len(items) > 0
    }

    return MultiModalKwargs(data,
                            items_by_key=items_by_key,
                            enable_sanity_checks=enable_sanity_checks)
|
||||
|
||||
def __init__(
    self,
    data: Mapping[str, NestedTensors],
    *,
    items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
    enable_sanity_checks: bool = False,
) -> None:
    """
    Store the keyword arguments together with their per-item metadata.

    Args:
        data: The keyword arguments themselves.
        items_by_key: For each key, the items that make up its value.
            The default mapping is only read here, never mutated.
        enable_sanity_checks: If set, assert that all keys of a modality
            have the same number of items.
    """
    super().__init__(data)

    # Shallow copy to avoid footgun in case a defaultdict is passed in
    self._items_by_key = dict(items_by_key)

    # Index which keys carry items of which modality.
    keys_by_modality = defaultdict[str, set[str]](set)
    for key, items in items_by_key.items():
        for item in items:
            keys_by_modality[item.field.modality].add(key)

    self._keys_by_modality = dict(keys_by_modality)

    if not enable_sanity_checks:
        return

    for modality, keys in keys_by_modality.items():
        batch_sizes = {k: len(items_by_key[k]) for k in keys}
        # Every key of one modality must describe the same number of items.
        sizes_agree = len(set(batch_sizes.values())) <= 1
        assert sizes_agree, dict(modality=modality,
                                 batch_sizes=batch_sizes,
                                 items_by_key=items_by_key)
|
||||
|
||||
@staticmethod
|
||||
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
|
||||
"""
|
||||
@ -139,7 +496,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
# Only tensors (not lists) can be stacked.
|
||||
return stacked
|
||||
|
||||
tensors_ = cast(List[torch.Tensor], stacked)
|
||||
tensors_ = cast(list[torch.Tensor], stacked)
|
||||
if any(t.shape != tensors_[0].shape for t in tensors_):
|
||||
# The tensors have incompatible shapes and can't be stacked.
|
||||
return tensors_
|
||||
@ -147,7 +504,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
return torch.stack(tensors_)
|
||||
|
||||
@staticmethod
|
||||
def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs:
|
||||
def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
|
||||
"""
|
||||
Batch multiple inputs together into a dictionary.
|
||||
|
||||
@ -162,7 +519,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
|
||||
# We need to consider the case where each item in the batch
|
||||
# contains different modalities (i.e. different keys).
|
||||
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
|
||||
item_lists = defaultdict[str, list[NestedTensors]](list)
|
||||
|
||||
for inputs in inputs_list:
|
||||
for k, v in inputs.items():
|
||||
@ -188,6 +545,57 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
|
||||
return cast(BatchedTensorInputs, json_mapped)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
    """
    Two instances are equal when they are of the same class, carry equal
    per-item metadata, have the same keys, and every value compares equal
    under ``nested_tensors_equal``.
    """
    if not isinstance(other, self.__class__):
        return False
    if self._items_by_key != other._items_by_key:
        return False

    own_keys = self.keys()
    if not (own_keys == other.keys()):
        return False

    for name in own_keys:
        if not nested_tensors_equal(self[name], other[name]):
            return False

    return True
|
||||
|
||||
def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
    """Return the ``item_index``-th field item recorded for ``key``."""
    items_for_key = self._items_by_key[key]
    return items_for_key[item_index]
|
||||
|
||||
def get_items_by_modality(
    self,
    modality: str,
    item_index: int,
) -> Mapping[str, MultiModalFieldItem]:
    """
    Get the keyword arguments corresponding to an item identified by
    its modality and index.

    Only keys that are still present in this dictionary are returned.
    Raises :class:`KeyError` if no key is registered for ``modality``.
    """
    gathered = {}
    for key in self._keys_by_modality[modality]:
        if key not in self:
            continue

        gathered[key] = self.get_item(key, item_index)

    return gathered
|
||||
|
||||
@staticmethod
def from_items_by_modality(
    items_by_modality: Mapping[str, list[Mapping[str,
                                                 MultiModalFieldItem]]],
    *,
    enable_sanity_checks: bool = False,
) -> "MultiModalKwargs":
    """
    Construct a new :class:`MultiModalKwargs` from multiple items returned
    by :meth:`get_items_by_modality`.

    The per-item mappings are regrouped by key, preserving the order in
    which the items appear, and then passed to :meth:`from_items_by_key`.
    """
    items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)

    for modal_items in items_by_modality.values():
        for item_kwargs in modal_items:
            for key, field_item in item_kwargs.items():
                items_by_key[key].append(field_item)

    return MultiModalKwargs.from_items_by_key(
        items_by_key,
        enable_sanity_checks=enable_sanity_checks,
    )
|
||||
|
||||
|
||||
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
|
||||
"""
|
||||
@ -207,16 +615,16 @@ class MultiModalInputsV2(TypedDict):
|
||||
prompt: str
|
||||
"""The processed prompt text."""
|
||||
|
||||
prompt_token_ids: List[int]
|
||||
prompt_token_ids: list[int]
|
||||
"""The processed token IDs which includes placeholder tokens."""
|
||||
|
||||
token_type_ids: NotRequired[List[int]]
|
||||
token_type_ids: NotRequired[list[int]]
|
||||
"""The token type IDs of the prompt."""
|
||||
|
||||
mm_kwargs: MultiModalKwargs
|
||||
"""Keyword arguments to be directly passed to the model after batching."""
|
||||
|
||||
mm_hashes: NotRequired[List[str]]
|
||||
mm_hashes: NotRequired[list[str]]
|
||||
"""The hashes of the multi-modal data."""
|
||||
|
||||
mm_placeholders: MultiModalPlaceholderDict
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import pickle
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
@ -8,19 +8,18 @@ from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from blake3 import blake3
|
||||
from PIL.Image import Image
|
||||
from transformers import BatchFeature, ProcessorMixin
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.inputs import DummyData, InputProcessingContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
|
||||
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of
|
||||
|
||||
from .audio import resample_audio
|
||||
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
|
||||
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
|
||||
VideoItem)
|
||||
from .inputs import (MultiModalDataDict, MultiModalDataItems,
|
||||
MultiModalFieldConfig, MultiModalFieldItem,
|
||||
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -201,111 +200,6 @@ class _BoundPromptReplacement:
|
||||
return bound_replacement
|
||||
|
||||
|
||||
class ImageSize(NamedTuple):
    """Dimensions of an image, ordered as ``(width, height)``."""
    width: int
    height: int
|
||||
|
||||
|
||||
class MultiModalDataItems(UserDict[str, list[Any]]):
    """
    As :class:`MultiModalDataDict`, but normalized such that each entry
    corresponds to a list.
    """

    @staticmethod
    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.

        Values that already hold multiple items are kept as they are;
        anything else is wrapped in a single-element list.
        """
        normalized = MultiModalDataItems()

        for modality, value in data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if modality == "video":
                # Special case since even a single item can be a list
                holds_many = (isinstance(value, torch.Tensor)
                              or is_list_of(value, list))
            elif modality in ("image", "audio"):
                holds_many = isinstance(value, (torch.Tensor, list))
            else:
                holds_many = isinstance(value, list)

            normalized[modality] = (  # type: ignore[index]
                value if holds_many else [value])

        return normalized

    # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
    # `self.images` doesn't update this dictionary, which may be confusing
    # We annotate the getter methods as `Sequence` to prevent others from
    # trying to update the list in this way
    @property
    def images(self) -> Sequence[ImageItem]:
        """The items stored under ``"image"`` (empty if absent)."""
        return self.get("image", [])

    @property
    def videos(self) -> Sequence[VideoItem]:
        """The items stored under ``"video"`` (empty if absent)."""
        return self.get("video", [])

    @property
    def audios(self) -> Sequence[AudioItem]:
        """The items stored under ``"audio"`` (empty if absent)."""
        return self.get("audio", [])

    def get_item_counts(self) -> Mapping[str, int]:
        """Return the number of items held for each modality."""
        counts = {}
        for modality, items in self.items():
            counts[modality] = len(items)
        return counts

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Return the size of the ``item_idx``-th image.

        For arrays and tensors, the last two of three dimensions are read
        as height and width.
        """
        image = self.images[item_idx]

        if isinstance(image, Image):
            width, height = image.size
            return ImageSize(width, height)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            _, height, width = image.shape
            return ImageSize(width, height)

        assert_never(image)

    def get_audio_with_sr(
        self,
        item_idx: int,
        *,
        default_sr: float,
    ) -> tuple[np.ndarray, float]:
        """
        Return the ``item_idx``-th audio together with its sampling rate.

        Tuples are returned as stored; lists and arrays carry no sampling
        rate of their own and are paired with ``default_sr``.
        """
        audio = self.audios[item_idx]

        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), default_sr
        if isinstance(audio, np.ndarray):
            return audio, default_sr

        assert_never(audio)

    def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
        """
        If :code:`drop_sr=True`, the audio items in this dictionary are updated
        to be NumPy arrays which implicitly means that their sampling rate is
        the same as the model's expected sampling rate; otherwise, they remain
        as :code:`(audio, new_sr)` tuples.
        """
        if not self.audios:
            return

        def _resample(item_idx: int):
            samples, orig_sr = self.get_audio_with_sr(item_idx,
                                                      default_sr=new_sr)
            samples = resample_audio(samples,
                                     orig_sr=orig_sr,
                                     target_sr=new_sr)
            if drop_sr:
                return samples
            return (samples, new_sr)

        num_audios = len(self.audios)
        self["audio"] = [_resample(item_idx) for item_idx in range(num_audios)]
|
||||
|
||||
|
||||
class _TokenMatch(NamedTuple):
|
||||
start_idx: int
|
||||
end_idx: int
|
||||
@ -583,11 +477,124 @@ def iter_placeholders(
|
||||
)
|
||||
|
||||
|
||||
class ProcessorInputs(NamedTuple):
|
||||
"""Keyword arguments to :meth:`BaseMultiModalProcessor`"""
|
||||
@dataclass
|
||||
class ProcessorInputs:
|
||||
"""Keyword arguments to :meth:`BaseMultiModalProcessor`."""
|
||||
prompt_text: str
|
||||
mm_data: MultiModalDataDict
|
||||
mm_processor_kwargs: Mapping[str, object]
|
||||
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ProcessingCache:
    """
    LRU cache of processed multi-modal items, keyed by a hash of everything
    the processed output depends on (see :meth:`get`).
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, Mapping[str,
                                            MultiModalFieldItem]](capacity)

    def _maybe_log_cache_stats(self) -> None:
        """Log the hit ratio every ``debug_cache_hit_ratio_steps`` lookups."""
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        cache_stats = self._cache.stat()
        if cache_stats.total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         cache_stats.hit_ratio)

    def _serialize_item(self, obj: object) -> bytes:
        """
        Convert a leaf object into the bytes that are fed to the hasher.

        Images and arrays are prefixed with a small header (mode and size,
        or dtype and shape). Without it, inputs that share the same raw
        buffer but differ in layout (e.g. a 2x8 and a 4x4 array) would
        hash to the same cache key and return the wrong cached output.
        """
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj
        if isinstance(obj, Image):
            header = f"{obj.mode}{obj.size}".encode("utf-8")
            return header + obj.tobytes()

        # Convertible to NumPy arrays
        if isinstance(obj, torch.Tensor):
            # `Tensor.numpy()` rejects tensors that track gradients or
            # live on another device, so normalize first.
            obj = obj.detach().cpu().numpy()
        if isinstance(obj, (int, float)):
            obj = np.array(obj)
        if isinstance(obj, np.ndarray):
            header = f"{obj.dtype.str}{obj.shape}".encode("utf-8")
            return header + obj.tobytes()

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    def _item_to_bytes(
        self,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        """
        Yield ``(key_bytes, value_bytes)`` pairs for every leaf of ``obj``.

        Lists, tuples and dicts are traversed recursively, extending the
        key with the index or dict key of each child.
        """
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from self._item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from self._item_to_bytes(f"{key}.{k}", v)
        else:
            key_bytes = self._serialize_item(key)
            value_bytes = self._serialize_item(obj)
            yield key_bytes, value_bytes

    def _hash_kwargs(self, **kwargs: object) -> str:
        """Return a hex digest over all keyword arguments."""
        hasher = blake3()

        for k, v in kwargs.items():
            for k_bytes, v_bytes in self._item_to_bytes(k, v):
                for chunk in (k_bytes, v_bytes):
                    # Length-prefix each chunk so that the boundary between
                    # keys and values cannot shift without changing the hash.
                    hasher.update(len(chunk).to_bytes(8, "little"))
                    hasher.update(chunk)

        return hasher.hexdigest()

    def _make_cache_key(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        """Build the cache key shared by :meth:`get` and :meth:`put`."""
        return self._hash_kwargs(model_id=model_id,
                                 **{modality: input_item},
                                 **input_kwargs)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[Mapping[str, MultiModalFieldItem]]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = self._make_cache_key(model_id, modality, input_item,
                                         input_kwargs)
        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: Mapping[str, MultiModalFieldItem],
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = self._make_cache_key(model_id, modality, input_item,
                                         input_kwargs)
        self._cache.put(cache_key, output_kwargs)
|
||||
|
||||
|
||||
class BaseMultiModalProcessor(ABC):
|
||||
@ -595,18 +602,24 @@ class BaseMultiModalProcessor(ABC):
|
||||
Abstract base class to process multi-modal inputs to be used in vLLM.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: InputProcessingContext) -> None:
|
||||
def __init__(self,
|
||||
ctx: InputProcessingContext,
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.ctx = ctx
|
||||
self.cache = cache
|
||||
self.enable_sanity_checks = enable_sanity_checks
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
return self.apply(prompt, mm_data, mm_processor_kwargs)
|
||||
return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
def _get_hf_processor(self) -> ProcessorMixin:
|
||||
"""
|
||||
@ -624,12 +637,21 @@ class BaseMultiModalProcessor(ABC):
|
||||
) -> MultiModalDataItems:
|
||||
return MultiModalDataItems.from_dict(mm_data)
|
||||
|
||||
@abstractmethod
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """
    Given the HF-processed data, output the metadata of each field.

    The returned mapping goes from a keyword-argument name in
    ``hf_inputs`` to the :class:`MultiModalFieldConfig` describing how
    that value is split into per-item fields.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_inputs: BatchFeature,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
"""
|
||||
Given the original multi-modal items for this modality
|
||||
@ -651,7 +673,7 @@ class BaseMultiModalProcessor(ABC):
|
||||
return list(
|
||||
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
|
||||
|
||||
def _get_processor_data(
|
||||
def _get_hf_mm_data(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
@ -669,7 +691,7 @@ class BaseMultiModalProcessor(ABC):
|
||||
and v[0].ndim == 2):
|
||||
# Pass through embedding inputs (multi)
|
||||
passthrough_data[f"{k}_embeds"] = v
|
||||
else:
|
||||
elif len(v) > 0:
|
||||
# Map keys to plural form, e.g.: image -> images
|
||||
processor_data[f"{k}s"] = v
|
||||
else:
|
||||
@ -679,39 +701,181 @@ class BaseMultiModalProcessor(ABC):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
hf_processor: ProcessorMixin,
|
||||
prompt: str,
|
||||
processor_data: Mapping[str, object],
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
# Not to be confused with `mm_data` in `self.apply`.
|
||||
# This refers to the data to be passed to HF processor.
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
return self.ctx.call_hf_processor(
|
||||
hf_processor,
|
||||
prompt,
|
||||
processor_data,
|
||||
mm_processor_kwargs,
|
||||
self._get_hf_processor(**mm_kwargs),
|
||||
dict(text=prompt, **mm_data),
|
||||
mm_kwargs,
|
||||
)
|
||||
|
||||
def _apply_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_text: str,
|
||||
mm_items: MultiModalDataItems,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
# some mm_processor_kwargs may be used in processor initialization
|
||||
# instead of processor call
|
||||
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> tuple[list[int], MultiModalKwargs]:
|
||||
"""
|
||||
Apply the HF processor on the full prompt text and multi-modal data.
|
||||
"""
|
||||
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
|
||||
|
||||
processor_data, passthrough_data = self._get_processor_data(mm_items)
|
||||
|
||||
hf_inputs = self._call_hf_processor(
|
||||
hf_processor,
|
||||
prompt=prompt,
|
||||
processor_data=processor_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
processed_data = self._call_hf_processor(
|
||||
prompt=prompt_text,
|
||||
mm_data=processor_data,
|
||||
mm_kwargs=hf_processor_mm_kwargs,
|
||||
)
|
||||
hf_inputs.update(passthrough_data)
|
||||
processed_data.update(passthrough_data)
|
||||
|
||||
return hf_inputs
|
||||
prompt_ids, = processed_data.pop("input_ids").tolist()
|
||||
|
||||
mm_kwargs = MultiModalKwargs.from_hf_inputs(
|
||||
processed_data,
|
||||
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
|
||||
enable_sanity_checks=self.enable_sanity_checks,
|
||||
)
|
||||
|
||||
return prompt_ids, mm_kwargs
|
||||
|
||||
def _apply_hf_processor_missing(
    self,
    prompt_text: str,
    mm_missing_data_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
):
    """
    Apply the HF processor on the full prompt text, but only on the
    multi-modal data that are missing from the cache.

    Note: We pass prompt text and multi-modal data into the HF processor
    in separate calls to avoid HF prompt replacement being done for
    cached items; instead, we rely on our own prompt replacement logic
    for the full text.

    Returns the token IDs of ``prompt_text`` and the processed keyword
    arguments of the missing items.
    """
    missing_counts = mm_missing_data_items.get_item_counts()

    # Call 1: the real prompt text, without any multi-modal data.
    text_only_items = MultiModalDataItems({})
    prompt_ids, _ = self._apply_hf_processor(
        prompt_text=prompt_text,
        mm_items=text_only_items,
        hf_processor_mm_kwargs={},
    )

    # Call 2: the missing items, paired with a dummy prompt.
    # Some HF processors (e.g. Qwen2-VL) expect corresponding
    # multi-modal tokens to be in the prompt text
    dummy_prompt_text = self._get_dummy_mm_inputs(missing_counts).prompt_text
    _, mm_missing_kwargs = self._apply_hf_processor(
        prompt_text=dummy_prompt_text,
        mm_items=mm_missing_data_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
    )

    return prompt_ids, mm_missing_kwargs
|
||||
|
||||
def _cached_apply_hf_processor(
    self,
    prompt_text: str,
    mm_data_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
    """
    Apply the HF processor on the full prompt text,
    caching the results and reusing cached results.

    Falls back to :meth:`_apply_hf_processor` when no cache is configured
    or when the data contain embedding inputs.
    """
    cache = self.cache
    model_id = self.ctx.model_config.model

    if cache is None or mm_data_items.has_embedding_inputs():
        return self._apply_hf_processor(
            prompt_text=prompt_text,
            mm_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

    # Step 1: look up every item; a `None` entry marks a cache miss.
    lookups_by_modality = {}
    for modality, items in mm_data_items.items():
        lookups_by_modality[modality] = [
            cache.get(model_id, modality, item, hf_processor_mm_kwargs)
            for item in items
        ]

    # Step 2: collect the original data of the missed items, in order.
    mm_missing_data = {}
    for modality, lookups in lookups_by_modality.items():
        original_items = mm_data_items[modality]
        mm_missing_data[modality] = [
            original_items[idx] for idx, hit in enumerate(lookups)
            if hit is None
        ]
    mm_missing_data_items = self._get_mm_items(mm_missing_data)

    # Step 3: run the HF processor on the missed items only.
    prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
        prompt_text=prompt_text,
        mm_missing_data_items=mm_missing_data_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
    )

    # Step 4: merge cached and fresh results in the original item order.
    # The counter tracks the next unconsumed fresh result per modality.
    mm_missing_next_idx = dict.fromkeys(mm_missing_data_items, 0)

    mm_merged_field_items = dict[str, list[Mapping[str,
                                                   MultiModalFieldItem]]]()
    for modality, lookups in lookups_by_modality.items():
        merged = list[Mapping[str, MultiModalFieldItem]]()

        for idx, field_items in enumerate(lookups):
            if field_items is None:
                next_missing = mm_missing_next_idx[modality]
                field_items = mm_missing_kwargs.get_items_by_modality(
                    modality,
                    next_missing,
                )

                # Remember the fresh result for subsequent requests.
                cache.put(
                    model_id,
                    modality,
                    mm_data_items[modality][idx],
                    hf_processor_mm_kwargs,
                    field_items,
                )

                mm_missing_next_idx[modality] = next_missing + 1

            merged.append(field_items)

        mm_merged_field_items[modality] = merged

    if self.enable_sanity_checks:
        # Every fresh result must have been consumed exactly once.
        mm_missing_counts = mm_missing_data_items.get_item_counts()
        all_consumed = all(
            item_count == mm_missing_counts[modality]
            for modality, item_count in mm_missing_next_idx.items())
        assert all_consumed, dict(mm_missing_next_idx=mm_missing_next_idx,
                                  mm_missing_counts=mm_missing_counts)

    mm_kwargs = MultiModalKwargs.from_items_by_modality(
        mm_merged_field_items,
        enable_sanity_checks=self.enable_sanity_checks,
    )

    if self.enable_sanity_checks:
        # Every original item must be retrievable from the merged result.
        for modality, item_count in mm_data_items.get_item_counts().items():
            for item_idx in range(item_count):
                try:
                    mm_kwargs.get_items_by_modality(modality, item_idx)
                except Exception as e:
                    # Make it easy to set a breakpoint in the debugger
                    raise e

    return prompt_ids, mm_kwargs
|
||||
|
||||
def _bind_prompt_replacements(
|
||||
self,
|
||||
@ -730,6 +894,10 @@ class BaseMultiModalProcessor(ABC):
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
token_matches = find_token_matches(token_ids, prompt_repls)
|
||||
mm_match_counts = {
|
||||
modality: len(matches)
|
||||
for modality, matches in full_groupby_modality(token_matches)
|
||||
}
|
||||
|
||||
# If the search text does not represent a special token,
|
||||
# it may have different token IDs in the prompt, because
|
||||
@ -742,8 +910,8 @@ class BaseMultiModalProcessor(ABC):
|
||||
# of the search text in the prompt, we instead perform string
|
||||
# replacement on the decoded token IDs, then encode them back.
|
||||
if all(
|
||||
len(matches) >= mm_item_counts[modality]
|
||||
for modality, matches in full_groupby_modality(token_matches)
|
||||
mm_match_counts.get(modality, 0) >= item_count
|
||||
for modality, item_count in mm_item_counts.items()
|
||||
): # yapf: disable
|
||||
token_ids = replace_token_matches(
|
||||
token_ids,
|
||||
@ -775,7 +943,7 @@ class BaseMultiModalProcessor(ABC):
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
"""
|
||||
Process multi-modal inputs to be used in vLLM.
|
||||
@ -792,20 +960,24 @@ class BaseMultiModalProcessor(ABC):
|
||||
"""
|
||||
mm_items = self._get_mm_items(mm_data)
|
||||
|
||||
hf_inputs = self._apply_hf_processor(prompt_text, mm_items,
|
||||
mm_processor_kwargs)
|
||||
prompt_ids, = hf_inputs.pop("input_ids").tolist()
|
||||
mm_kwargs = MultiModalKwargs(hf_inputs)
|
||||
prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
|
||||
prompt_text,
|
||||
mm_items,
|
||||
hf_processor_mm_kwargs,
|
||||
)
|
||||
|
||||
prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
|
||||
mm_processor_kwargs)
|
||||
all_prompt_repls = self._bind_prompt_replacements(prompt_repls)
|
||||
unbound_prompt_repls = self._get_prompt_replacements(
|
||||
mm_items,
|
||||
hf_processor_mm_kwargs,
|
||||
mm_kwargs,
|
||||
)
|
||||
prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls)
|
||||
|
||||
# If HF processor already inserts placeholder tokens,
|
||||
# there is no need for us to insert them
|
||||
mm_item_counts = mm_items.get_item_counts()
|
||||
all_placeholders = self._find_placeholders(all_prompt_repls,
|
||||
prompt_ids, mm_item_counts)
|
||||
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
|
||||
mm_item_counts)
|
||||
|
||||
if all_placeholders:
|
||||
tokenizer = self._get_tokenizer()
|
||||
@ -817,7 +989,7 @@ class BaseMultiModalProcessor(ABC):
|
||||
all_placeholders,
|
||||
) = self._apply_prompt_replacements(
|
||||
prompt_ids,
|
||||
all_prompt_repls,
|
||||
prompt_repls,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
@ -855,23 +1027,29 @@ class BaseMultiModalProcessor(ABC):
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
processor_inputs = self._get_dummy_mm_inputs(mm_counts)
|
||||
mm_inputs = self.apply(*processor_inputs)
|
||||
mm_inputs = self.apply(
|
||||
prompt_text=processor_inputs.prompt_text,
|
||||
mm_data=processor_inputs.mm_data,
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
)
|
||||
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
|
||||
total_placeholders_by_modality = dict[str, int]()
|
||||
for modality, placeholders in placeholders_by_modality.items():
|
||||
num_placeholders = sum(item["length"] for item in placeholders)
|
||||
max_tokens = mm_max_tokens[modality]
|
||||
|
||||
if num_placeholders != max_tokens:
|
||||
logger.warning(
|
||||
"The processed dummy data has a total of %d placeholder "
|
||||
"tokens for the '%s' modality, which is not the expected "
|
||||
"%d tokens.", num_placeholders, modality, max_tokens)
|
||||
|
||||
total_placeholders_by_modality[modality] = num_placeholders
|
||||
total_placeholders_by_modality = {
|
||||
modality: sum(item["length"] for item in placeholders)
|
||||
for modality, placeholders in placeholders_by_modality.items()
|
||||
}
|
||||
expected_placeholders_by_modality = {
|
||||
modality: mm_max_tokens[modality]
|
||||
for modality in placeholders_by_modality
|
||||
}
|
||||
if total_placeholders_by_modality != expected_placeholders_by_modality:
|
||||
raise AssertionError(
|
||||
f"The processed dummy data has a total of "
|
||||
f"{total_placeholders_by_modality} placeholder tokens, which "
|
||||
f"is not the expected {expected_placeholders_by_modality} "
|
||||
"tokens.")
|
||||
|
||||
total_len = len(prompt_token_ids)
|
||||
if total_len > seq_len:
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
import functools
|
||||
from collections import UserDict
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol,
|
||||
Sequence, Type, TypeVar)
|
||||
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.logger import init_logger
|
||||
@ -15,7 +14,7 @@ from .audio import AudioPlugin
|
||||
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
|
||||
from .image import ImagePlugin
|
||||
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
|
||||
from .processing import BaseMultiModalProcessor
|
||||
from .processing import BaseMultiModalProcessor, ProcessingCache
|
||||
from .video import VideoPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -23,15 +22,22 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# TODO: Tune the MM cache size
|
||||
MM_CACHE_SIZE = 256
|
||||
|
||||
N = TypeVar("N", bound=Type[nn.Module])
|
||||
|
||||
MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext],
|
||||
BaseMultiModalProcessor]
|
||||
"""
|
||||
Constructs a :class:`MultiModalProcessor` instance from the context.
|
||||
|
||||
The processing metadata should be derived from the context.
|
||||
"""
|
||||
class MultiModalProcessorFactory(Protocol):
    """
    Callable that builds a :class:`BaseMultiModalProcessor` from the context.

    Args:
        ctx: The input processing context the processor is built from.
        cache: Optional processing cache, keyword-only. ``None`` (the
            default) means no cache is supplied to the processor.

    Returns:
        The constructed :class:`BaseMultiModalProcessor`.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor:
        ...
|
||||
|
||||
|
||||
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
|
||||
@ -71,6 +77,8 @@ class MultiModalRegistry:
|
||||
|
||||
self._limits_by_model = _MultiModalLimits()
|
||||
|
||||
self._processing_cache = ProcessingCache(MM_CACHE_SIZE)
|
||||
|
||||
def register_plugin(self, plugin: MultiModalPlugin) -> None:
|
||||
"""
|
||||
Register a multi-modal plugin so it can be recognized by vLLM.
|
||||
@ -328,15 +336,18 @@ class MultiModalRegistry:
|
||||
|
||||
return wrapper
|
||||
|
||||
def has_processor(self, model_config: "ModelConfig") -> bool:
|
||||
"""
|
||||
Test whether a multi-modal processor is defined for a specific model.
|
||||
"""
|
||||
def _get_model_cls(self, model_config: "ModelConfig"):
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
return model_cls in self._processor_factories
|
||||
return model_cls
|
||||
|
||||
def has_processor(self, model_config: "ModelConfig") -> bool:
|
||||
"""
|
||||
Test whether a multi-modal processor is defined for a specific model.
|
||||
"""
|
||||
return self._get_model_cls(model_config) in self._processor_factories
|
||||
|
||||
def create_processor(
|
||||
self,
|
||||
@ -346,12 +357,11 @@ class MultiModalRegistry:
|
||||
"""
|
||||
Create a multi-modal processor for a specific model and tokenizer.
|
||||
"""
|
||||
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
model_cls = self._get_model_cls(model_config)
|
||||
processor_factory = self._processor_factories[model_cls]
|
||||
|
||||
ctx = InputProcessingContext(model_config, tokenizer)
|
||||
return processor_factory(ctx)
|
||||
cache = (None if model_config.disable_mm_preprocessor_cache else
|
||||
self._processing_cache)
|
||||
|
||||
return processor_factory(ctx, cache=cache)
|
||||
|
||||
@ -1,25 +1,31 @@
|
||||
from functools import lru_cache
|
||||
from typing import Any, cast
|
||||
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
|
||||
def get_processor(
|
||||
processor_name: str,
|
||||
*args: Any,
|
||||
trust_remote_code: bool = False,
|
||||
processor_cls: type[ProcessorMixin] = ProcessorMixin,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Load a processor for the given model name via HuggingFace."""
|
||||
# don't put this import at the top level
|
||||
# it will call torch.cuda.device_count()
|
||||
from transformers import AutoProcessor
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
processor_factory = (AutoProcessor
|
||||
if processor_cls == ProcessorMixin else processor_cls)
|
||||
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
processor = processor_factory.from_pretrained(
|
||||
processor_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
except ValueError as e:
|
||||
# If the error pertains to the processor class not existing or not
|
||||
# currently being imported, suggest using the --trust-remote-code flag.
|
||||
|
||||
@ -25,11 +25,11 @@ import warnings
|
||||
import weakref
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||
from collections import OrderedDict, UserDict, defaultdict
|
||||
from collections.abc import Iterable, Mapping
|
||||
from collections.abc import Hashable, Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache, partial, wraps
|
||||
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
|
||||
Dict, Generator, Generic, Hashable, List, Literal,
|
||||
Dict, Generator, Generic, List, Literal, NamedTuple,
|
||||
Optional, Tuple, Type, TypeVar, Union, overload)
|
||||
from uuid import uuid4
|
||||
|
||||
@ -194,13 +194,29 @@ class Counter:
|
||||
self.counter = 0
|
||||
|
||||
|
||||
class CacheInfo(NamedTuple):
    """Immutable snapshot of a cache's lookup statistics."""

    # Number of lookups counted as hits.
    hits: int
    # Number of lookups the hits are measured against.
    total: int

    @property
    def hit_ratio(self) -> float:
        """Return ``hits / total``, or ``0.0`` when ``total`` is zero."""
        if self.total == 0:
            # Nothing recorded yet: avoid ZeroDivisionError and return a
            # real float so the result type matches the annotation.
            return 0.0

        return self.hits / self.total
|
||||
|
||||
|
||||
class LRUCache(Generic[_K, _V]):
|
||||
"""Note: This class is not thread safe!"""
|
||||
|
||||
def __init__(self, capacity: int) -> None:
    """Create an empty cache with the given capacity and zeroed counters."""
    # Lookup statistics, reported through ``stat()``.
    self._hits, self._total = 0, 0

    self.capacity = capacity
    self.pinned_items = set[_K]()
    # Ordered mapping that holds the cached entries.
    self.cache = OrderedDict[_K, _V]()
|
||||
|
||||
def __contains__(self, key: _K) -> bool:
    """Report whether ``key`` is stored; the hit counters are not updated."""
    is_present = key in self.cache
    return is_present
|
||||
|
||||
@ -218,6 +234,9 @@ class LRUCache(Generic[_K, _V]):
|
||||
def __delitem__(self, key: _K) -> None:
    """Support ``del cache[key]`` by delegating to ``pop``."""
    # The value returned by ``pop`` is intentionally discarded.
    self.pop(key)
|
||||
|
||||
def stat(self) -> CacheInfo:
    """Return the current hit and lookup counters as a :class:`CacheInfo`."""
    hit_count, lookup_count = self._hits, self._total
    return CacheInfo(hits=hit_count, total=lookup_count)
|
||||
|
||||
def touch(self, key: _K) -> None:
    """Move ``key`` to the end of the ordered cache.

    Raises:
        KeyError: If ``key`` is not present in the cache.
    """
    entries = self.cache
    entries.move_to_end(key)
|
||||
|
||||
@ -226,8 +245,12 @@ class LRUCache(Generic[_K, _V]):
|
||||
if key in self.cache:
|
||||
value = self.cache[key]
|
||||
self.cache.move_to_end(key)
|
||||
|
||||
self._hits += 1
|
||||
else:
|
||||
value = default
|
||||
|
||||
self._total += 1
|
||||
return value
|
||||
|
||||
def put(self, key: _K, value: _V) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user