[VLM] Support caching in merged multi-modal processor (#11396)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-28 01:22:48 +08:00 committed by GitHub
parent 5ce4627a7e
commit 101418096f
20 changed files with 1459 additions and 452 deletions
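
For context, the caching added by this PR memoizes per-item HF processor outputs so that identical images, videos, or audios across batches skip re-preprocessing. A minimal sketch of the idea, with a hypothetical `SimpleProcessingCache` (the real `ProcessingCache` lives in `vllm.multimodal.processing`, is keyed differently, and appears to size its capacity in bytes rather than entries):

```python
# Illustrative sketch only: memoize per-item processor outputs under a
# content hash so repeated multi-modal items skip HF preprocessing.
import pickle
from collections import OrderedDict

from blake3 import blake3


class SimpleProcessingCache:

    def __init__(self, capacity: int) -> None:
        self._cache = OrderedDict()  # key -> processed outputs
        self._capacity = capacity  # max entries (simplified bookkeeping)

    @staticmethod
    def _hash(model_id: str, modality: str, item, kwargs: dict) -> str:
        # pickle is a stand-in; the real cache may serialize items differently
        return blake3(pickle.dumps(
            (model_id, modality, item, kwargs))).hexdigest()

    def get_or_compute(self, model_id, modality, item, kwargs, compute):
        key = self._hash(model_id, modality, item, kwargs)
        if key in self._cache:
            self._cache.move_to_end(key)  # LRU bookkeeping
            return self._cache[key]
        value = compute(item)
        self._cache[key] = value
        if len(self._cache) > self._capacity:
            self._cache.popitem(last=False)  # evict least recently used
        return value
```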

View File

@ -191,6 +191,7 @@ def linkcode_resolve(domain, info):
# Mock out external dependencies here, otherwise the autodoc pages may be blank.
autodoc_mock_imports = [
"blake3",
"compressed_tensors",
"cpuinfo",
"cv2",
@ -207,7 +208,7 @@ autodoc_mock_imports = [
"tensorizer",
"pynvml",
"outlines",
"xgrammar,"
"xgrammar",
"librosa",
"soundfile",
"gguf",

View File

@ -45,31 +45,23 @@ adding_multimodal_plugin
### Base Classes
```{eval-rst}
.. autodata:: vllm.multimodal.NestedTensors
```
```{eval-rst}
.. autodata:: vllm.multimodal.BatchedTensorInputs
```
```{eval-rst}
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
.. automodule:: vllm.multimodal.base
:members:
:show-inheritance:
```
```{eval-rst}
.. autodata:: vllm.multimodal.MultiModalDataDict
```
### Input Classes
```{eval-rst}
.. autoclass:: vllm.multimodal.MultiModalKwargs
.. automodule:: vllm.multimodal.inputs
:members:
:show-inheritance:
```
### Audio Classes
```{eval-rst}
.. autoclass:: vllm.multimodal.MultiModalPlugin
.. automodule:: vllm.multimodal.audio
:members:
:show-inheritance:
```
@ -81,3 +73,11 @@ adding_multimodal_plugin
:members:
:show-inheritance:
```
### Video Classes
```{eval-rst}
.. automodule:: vllm.multimodal.video
:members:
:show-inheritance:
```

View File

@ -755,8 +755,7 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal
```
```{note}
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo ({code}`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`)
and pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```
```{note}

View File

@ -91,5 +91,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 3072
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 765
assert embeddings.usage.total_tokens == 765
assert embeddings.usage.prompt_tokens == 764
assert embeddings.usage.total_tokens == 764

View File

@ -30,7 +30,7 @@ def get_max_qwen2_vl_image_tokens():
@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
({}, 1225),
({}, 16384),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2

View File

@ -201,6 +201,7 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output,
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[large_gpu_mark(min_gb=48)],
),
"glm4": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
@ -212,7 +213,7 @@ VLM_TEST_SETTINGS = {
dtype="bfloat16",
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
patch_hf_runner=model_utils.glm_patch_hf_runner,
marks=[large_gpu_mark(min_gb=48)],
marks=[large_gpu_mark(min_gb=32)],
),
"h2ovl": VLMTestInfo(
models = [
@ -261,6 +262,7 @@ VLM_TEST_SETTINGS = {
dtype="bfloat16",
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"llava_next": VLMTestInfo(
models=["llava-hf/llava-v1.6-mistral-7b-hf"],

View File

@ -1,12 +1,20 @@
from functools import partial
from typing import cast
import numpy as np
import pytest
from PIL import Image
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
find_text_matches, find_token_matches,
iter_placeholders, iter_token_matches,
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
_PlaceholderInfo, find_text_matches,
find_token_matches, iter_placeholders,
iter_token_matches,
replace_text_matches,
replace_token_matches)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby
@ -457,6 +465,7 @@ def test_find_replace_tokens(
),
]
)
# yapf: enable
def test_iter_placeholders(
repl_by_key,
prompt,
@ -475,11 +484,199 @@ def test_iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
{key: 3 for key in repl_by_key},
))
{key: 3
for key in repl_by_key},
))
# Only displayed on error
print("result:", result)
# Manually constructed results
assert result == expected
def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int):
w, h = rng.randint(min_wh, max_wh, size=(2, ))
arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8)
return Image.fromarray(arr)
def _rand_video(
rng: np.random.RandomState,
min_frames: int,
max_frames: int,
min_wh: int,
max_wh: int,
):
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
num_frames = rng.randint(min_frames, max_frames)
num_frames = (num_frames // 2) * 2
w, h = rng.randint(min_wh, max_wh, size=(2, ))
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)
def _rand_audio(
rng: np.random.RandomState,
min_len: int,
max_len: int,
sr: int,
):
audio_len = rng.randint(min_len, max_len)
return rng.rand(audio_len), sr
def _test_processing_cache_correctness(
model_id: str,
modalities: set[str],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
else:
hf_overrides = {}
model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=True,
seed=0,
dtype="float16",
revision=None,
hf_overrides=hf_overrides,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)
baseline_processor = processor_factory(ctx, cache=None)
cached_processor = processor_factory(ctx, cache=cache)
rng = np.random.RandomState(0)
input_to_hit = {
"image": Image.new("RGB", size=(128, 128)),
"video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
"audio": (np.zeros((512, )), 16000),
}
input_factory = {
"image":
partial(_rand_img, rng, min_wh=128, max_wh=256),
"video":
partial(_rand_video,
rng,
min_frames=2,
max_frames=8,
min_wh=128,
max_wh=256),
"audio":
partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
}
input_max_count = {
"image": 3,
"video": 3,
"audio": 3,
}
for batch_idx in range(num_batches):
mm_data = {
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(input_max_count[k]))]
for k in modalities
}
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text
# Drop unnecessary keys and test single -> multi conversion
if rng.rand() < simplify_rate:
for k in list(mm_data.keys()):
if not mm_data[k]:
del mm_data[k]
elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0]
baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert baseline_result == cached_result, (
f"Failed ({batch_idx=}, {mm_data=})")
# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
("llava-hf/llava-1.5-7b-hf", {"image"}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}),
("mistral-community/pixtral-12b", {"image"}),
("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}),
("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}),
("fixie-ai/ultravox-v0_3", {"audio"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
model_id: str,
modalities: set[str],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
_test_processing_cache_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
)
# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
("microsoft/Phi-3-vision-128k-instruct", {"image"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness_phi3v(
model_id: str,
modalities: set[str],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
# HACK - this is an attempted workaround for the following bug
# https://github.com/huggingface/transformers/issues/34307
from transformers import AutoImageProcessor # noqa: F401
from transformers import AutoProcessor # noqa: F401
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
_test_processing_cache_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
)
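
The `simplify_rate` branch above exercises the single -> multi conversion: a bare item and a one-element list are meant to be interchangeable `mm_data` forms after normalization. A hedged restatement of that invariant (`processor`, `prompt`, and `image` assumed in scope):

```python
# Normalization (MultiModalDataItems.from_dict) is meant to make these
# two mm_data forms produce identical results:
single = processor.apply(prompt, mm_data={"image": image},
                         hf_processor_mm_kwargs={})
multi = processor.apply(prompt, mm_data={"image": [image]},
                        hf_processor_mm_kwargs={})
assert single == multi
```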

View File

@ -99,6 +99,9 @@ class InputContext:
merged_kwargs = {**base_kwargs, **kwargs}
if isinstance(typ, type):
merged_kwargs["processor_cls"] = typ
hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
@ -132,10 +135,13 @@ class InputProcessingContext(InputContext):
def call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
inference_kwargs: Mapping[str, object],
data: Mapping[str, object],
kwargs: Mapping[str, object] = {},
) -> BatchFeature:
"""
Call :code:`hf_processor` on the prompt :code:`data`
(text, image, audio...) with configurable options :code:`kwargs`.
"""
assert callable(hf_processor)
base_kwargs = self.model_config.mm_processor_kwargs
@ -144,21 +150,15 @@ class InputProcessingContext(InputContext):
merged_kwargs = resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
kwargs,
hf_processor,
requires_kw_only=False,
allow_var_kwargs=True,
)
try:
return hf_processor(
text=prompt,
**processor_data,
**merged_kwargs,
return_tensors="pt",
)
return hf_processor(**data, **merged_kwargs, return_tensors="pt")
except Exception as exc:
data = dict(text=prompt, **processor_data)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")

View File

@ -1,5 +1,4 @@
from functools import cached_property
from types import MethodType
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
Tuple, TypedDict, Union)
@ -7,7 +6,7 @@ import torch
import torch.nn as nn
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
ProcessorMixin, SiglipVisionConfig)
SiglipVisionConfig)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
@ -21,10 +20,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
ProcessorInputs, PromptReplacement,
full_groupby_modality)
from vllm.sequence import IntermediateTensors
from .clip import (CLIPVisionModel, dummy_image_for_clip,
@ -116,36 +117,54 @@ def get_max_llava_image_tokens(ctx: InputContext):
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False):
return # Already patched
image_processor = hf_processor.image_processor # type: ignore
orig_preprocess = image_processor.preprocess
def preprocess(__self, *args, **kwargs):
hf_inputs = orig_preprocess(*args, **kwargs)
hf_inputs["is_pixtral"] = torch.tensor(True)
return hf_inputs
image_processor.preprocess = MethodType(preprocess, image_processor)
hf_processor.__is_patched__ = True # type: ignore
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor(
(LlavaProcessor, PixtralProcessor))
return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
return hf_processor
# NOTE: pixel_values=None for MLlavaProcessor
pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None:
images = mm_data["images"]
assert isinstance(images, list)
if isinstance(self._get_hf_processor(), PixtralProcessor):
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
assert (isinstance(pixel_values, list)
and len(pixel_values) == 1
and isinstance(pixel_values[0], list)
and len(pixel_values[0]) == len(images))
processed_outputs["pixel_values"] = pixel_values[0]
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
@ -200,7 +219,7 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
num_images = mm_counts.get("image", 0)
if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
@ -218,7 +237,6 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
mm_processor_kwargs={},
)
@ -379,7 +397,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
@ -390,33 +407,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
assert isinstance(is_pixtral, torch.Tensor)
if is_pixtral.any():
images = pixel_values
def flatten_to_3d_tensors(item):
if isinstance(item, torch.Tensor):
if item.dim() >= 3:
return [t for t in item.view(-1, *item.shape[-3:])]
else:
raise ValueError(
f"Unexpected tensor dimension: {item.dim()}")
elif isinstance(item, list):
return [
t for subitem in item
for t in flatten_to_3d_tensors(subitem)
]
else:
raise ValueError(f"Unexpected type: {type(item)}")
# Restructure the batched images into a list of lists of images
images = flatten_to_3d_tensors(pixel_values)
return LlavaImagePixelInputs(
type="pixel_values",
data=images,
)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
@ -586,19 +576,71 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def _get_hf_processor(self) -> ProcessorMixin:
try:
from mantis.models.mllava import MLlavaProcessor
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"You need to `pip install "
"git+https://github.com/TIGER-AI-Lab/Mantis.git` "
"to use this model") from exc
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor)
processor = MLlavaProcessor.from_pretrained(
self.ctx.model_config.tokenizer)
assert isinstance(processor, ProcessorMixin)
return processor
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
max_image_tokens = get_max_llava_image_tokens(self.ctx)
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
mm_items = self._get_mm_items(mm_data)
mm_item_counts = mm_items.get_item_counts()
mm_kwargs = result["mm_kwargs"]
# We reimplement the functionality of MLlavaProcessor from
# https://github.com/TIGER-AI-Lab/Mantis.git
def get_replacement_mantis(item_idx: int):
return "".join([
f"(image {item_idx+1}: <Image>", # 7 tokens
"<image>" * max_image_tokens,
"</Image>)", # 3 tokens
])
mantis_repls = self._bind_prompt_replacements([
PromptReplacement(
modality="image",
target=[image_token_id] * max_image_tokens,
replacement=get_replacement_mantis,
)
])
prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
result["prompt_token_ids"],
mantis_repls,
mm_item_counts,
)
unbound_orig_repls = self._get_prompt_replacements(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
orig_repls = self._bind_prompt_replacements(unbound_orig_repls)
all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
mm_item_counts)
assert len(all_placeholders) == mm_item_counts.get("image", 0)
mm_placeholders = {
modality: [item.to_range() for item in items]
for modality, items in full_groupby_modality(all_placeholders)
}
return MultiModalInputsV2(
type="multimodal",
prompt=prompt_text,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders,
)
# To use this model, please use
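
To make the Mantis override concrete, each run of `max_image_tokens` image tokens is rewritten into the Mantis-style wrapper. A small illustration, assuming `max_image_tokens = 3` for brevity (the real value comes from `get_max_llava_image_tokens(ctx)`):

```python
max_image_tokens = 3  # hypothetical; normally derived from the vision config

def get_replacement_mantis(item_idx: int) -> str:
    return "".join([
        f"(image {item_idx + 1}: <Image>",
        "<image>" * max_image_tokens,
        "</Image>)",
    ])

print(get_replacement_mantis(0))
# (image 1: <Image><image><image><image></Image>)
```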

View File

@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -32,10 +32,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors,
PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
ProcessorInputs, PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -306,11 +310,11 @@ def get_max_phi3v_image_tokens(
*,
num_crops: Optional[int] = None,
) -> int:
mm_processor_kwargs = {}
hf_processor_mm_kwargs = {}
if num_crops:
mm_processor_kwargs["num_crops"] = num_crops
hf_processor_mm_kwargs["num_crops"] = num_crops
processor = ctx.get_hf_processor(**mm_processor_kwargs)
processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)
return processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
@ -331,39 +335,50 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
input_ids = processed_outputs["input_ids"]
assert isinstance(input_ids, torch.Tensor)
# The Phi3v processor inserts -1, -2, etc. as image placeholders in
# prompt_ids, which would cause an OverflowError when decoding them.
# Therefore, we need to do an early replacement here
token_ids = processed_outputs['input_ids']
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
processed_outputs['input_ids'] = token_ids
input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
tokenizer = self._get_tokenizer()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
@ -372,21 +387,44 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
height=image_size.height,
)
return [_IMAGE_TOKEN_ID] * num_tokens
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:max_images]
) for image_token in image_tokens[:len(mm_items.images)]
]
def _apply_prompt_replacements(
self,
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
prompt_repls=prompt_repls,
mm_item_counts=mm_item_counts,
)
# Keep the behavior in line with HF processor
if text.startswith("<s> <|image|>"):
text = text.replace("<s> <|image|>", "<s><|image|>", 1)
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = [
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
for p in placeholders
]
return token_ids, text, placeholders
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts["image"]
num_images = mm_counts.get("image", 0)
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
@ -401,9 +439,28 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
mm_processor_kwargs={},
)
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
# Only <|image|> tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
return result
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
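
To spell out the off-by-one bookkeeping: each replacement now ends with `bos_token_id` (see `[_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]` above), so `apply` trims one position from every placeholder range to keep only the `<|image|>` tokens. A small illustration with hypothetical lengths:

```python
from vllm.multimodal.inputs import PlaceholderRange

# Raw placeholder spans the image tokens plus the trailing BOS token.
raw = PlaceholderRange(offset=1, length=757)  # hypothetical values
fixed = PlaceholderRange(offset=raw["offset"], length=raw["length"] - 1)
assert fixed == PlaceholderRange(offset=1, length=756)
```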

View File

@ -225,7 +225,7 @@ class VisualAttentionBlock(nn.Module):
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
norm_layer: Callable = nn.LayerNorm,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@ -266,7 +266,7 @@ class TransformerBlock(nn.Module):
layers: int,
heads: int,
mlp_ratio: float = 4.0,
norm_layer: Callable = nn.LayerNorm,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()

View File

@ -26,7 +26,7 @@ from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature, ProcessorMixin
from transformers import BatchFeature
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
Qwen2AudioEncoder,
Qwen2AudioProcessor)
@ -38,10 +38,10 @@ from vllm.inputs import InputContext
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
ProcessorInputs, PromptReplacement)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
@ -73,7 +73,7 @@ class Qwen2AudioMultiModalProjector(nn.Module):
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
feat_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (feat_lengths - 2) // 2 + 1
return feat_lengths, output_lengths
@ -88,13 +88,18 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
def _get_hf_processor(self) -> Qwen2AudioProcessor:
def _get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor)
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().feature_extractor # type: ignore
def _get_processor_data(
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
@ -102,50 +107,61 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)
return super()._get_processor_data(mm_items)
return super()._get_hf_mm_data(mm_items)
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processor_data = dict(processor_data)
audios = processor_data.pop("audios", [])
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
if audios:
processor_data["audios"] = audios
mm_data["audios"] = audios
feature_extractor = self._get_feature_extractor()
mm_processor_kwargs = dict(
**mm_processor_kwargs,
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
else:
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
pass
return super()._call_hf_processor(
hf_processor,
processed_outputs = super()._call_hf_processor(
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
input_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
placeholder = hf_config.audio_token_index
feature_attention_mask = hf_inputs.get("feature_attention_mask")
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
if feature_attention_mask is None:
audio_output_lengths = []
else:
assert isinstance(feature_attention_mask, torch.Tensor)
_, audio_output_lengths = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1))
@ -168,14 +184,13 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
audio_count = mm_counts["audio"]
audio_count = mm_counts.get("audio", 0)
audio = np.zeros(audio_len)
data = {"audio": [audio] * audio_count}
return ProcessorInputs(
prompt_text="<|AUDIO|>" * audio_count,
mm_data=data,
mm_processor_kwargs={},
)
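
The two integer divisions in `_get_feat_extract_output_lengths` mirror the audio encoder's strided downsampling. A quick check (3000 mel frames for a 30 s Whisper-style chunk is an assumption about the feature extractor):

```python
import torch

def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths

feat, out = _get_feat_extract_output_lengths(torch.tensor([3000]))
print(feat.item(), out.item())  # 1500 750
```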

View File

@ -22,9 +22,10 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import cached_property, partial
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, Type, TypedDict, Union)
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, Type, TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -54,10 +55,11 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
ProcessorInputs, PromptReplacement)
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@ -229,9 +231,9 @@ class Qwen2VisionAttention(nn.Module):
def __init__(
self,
embed_dim: Optional[int] = None,
num_heads: Optional[int] = None,
projection_size: Optional[int] = None,
embed_dim: int,
num_heads: int,
projection_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
@ -264,7 +266,7 @@ class Qwen2VisionAttention(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
rotary_pos_emb: torch.Tensor,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@ -347,7 +349,7 @@ class Qwen2VisionBlock(nn.Module):
num_heads: int,
mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU,
norm_layer: Type[nn.Module] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
@ -384,7 +386,7 @@ class Qwen2VisionPatchEmbed(nn.Module):
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_chans: int = 3,
in_channels: int = 3,
embed_dim: int = 1152,
) -> None:
super().__init__()
@ -392,8 +394,8 @@ class Qwen2VisionPatchEmbed(nn.Module):
self.temporal_patch_size = temporal_patch_size
self.embed_dim = embed_dim
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = nn.Conv3d(in_chans,
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d(in_channels,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
@ -413,7 +415,7 @@ class Qwen2VisionPatchMerger(nn.Module):
self,
d_model: int,
context_dim: int,
norm_layer: Type[nn.Module] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -489,15 +491,15 @@ class Qwen2VisionTransformer(nn.Module):
) -> None:
super().__init__()
patch_size: int = vision_config.patch_size
temporal_patch_size: int = vision_config.temporal_patch_size
spatial_merge_size: int = vision_config.spatial_merge_size
in_chans: int = vision_config.in_chans
hidden_size: int = vision_config.hidden_size
embed_dim: int = vision_config.embed_dim
depth: int = vision_config.depth
num_heads: int = vision_config.num_heads
mlp_ratio: float = vision_config.mlp_ratio
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
spatial_merge_size = vision_config.spatial_merge_size
in_channels = vision_config.in_channels
hidden_size = vision_config.hidden_size
embed_dim = vision_config.embed_dim
depth = vision_config.depth
num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio
self.spatial_merge_size = spatial_merge_size
self.num_heads = num_heads
@ -506,7 +508,7 @@ class Qwen2VisionTransformer(nn.Module):
self.patch_embed = Qwen2VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_chans=in_chans,
in_channels=in_channels,
embed_dim=embed_dim,
)
@ -733,8 +735,12 @@ class Qwen2VLMultiModalDataItems(MultiModalDataItems):
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if (isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
or is_list_of(v, list)) else [v]
v if (
isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
or is_list_of(v, list)
or isinstance(v[0], (np.ndarray, torch.Tensor))
and v[0].ndim == 4
) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
@ -754,6 +760,12 @@ class Qwen2VLMultiModalDataItems(MultiModalDataItems):
for m, items in self.items()
}
def has_embedding_inputs(self) -> bool:
return any(
isinstance(items, dict) or any(
isinstance(item, torch.Tensor) for item in items)
for items in self.values())
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
@ -784,7 +796,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
return hf_processor
def _get_processor_data(
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
@ -805,7 +817,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
else:
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
@ -816,8 +828,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
@ -831,7 +843,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
merge_length = image_processor.merge_size**2
def get_replacement_qwen2vl(item_idx: int, modality: str):
grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx]
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
assert isinstance(grid_thw, torch.Tensor)
num_tokens = grid_thw.prod() // merge_length
return placeholder[modality] * num_tokens
@ -844,11 +858,40 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
) for modality in ("image", "video")
]
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist()
image_slices = [
slice(image_slice_idxs[i], image_slice_idxs[i + 1])
for i in range(len(image_grid_thw))
]
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist()
video_slices = [
slice(video_slice_idxs[i], video_slice_idxs[i + 1])
for i in range(len(video_grid_thw))
]
return dict(
pixel_values=MultiModalFieldConfig.flat("image", image_slices),
image_embeds=MultiModalFieldConfig.flat("image", image_slices),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat(
"video", video_slices),
video_embeds=MultiModalFieldConfig.flat("video", video_slices),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts["image"]
num_images = mm_counts.get("image", 0)
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token
image_processor = _get_image_processor(hf_processor)
@ -869,7 +912,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
mm_processor_kwargs={},
)
@ -950,9 +992,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
return quant_config
def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
List[torch.Tensor]],
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
@ -962,7 +1002,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim}")
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return torch.concat(list(mm_input))
else:
return torch.concat(mm_input)
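
The flat-field slices in `_get_mm_fields_config` partition the flattened `pixel_values` rows by per-item patch counts derived from `grid_thw`. A worked example with hypothetical grids:

```python
import torch

# Image 0 contributes 1*2*2 = 4 rows; image 1 contributes 1*4*4 = 16.
image_grid_thw = torch.tensor([[1, 2, 2], [1, 4, 4]])

idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist()
slices = [slice(idxs[i], idxs[i + 1]) for i in range(len(image_grid_thw))]
print(slices)  # [slice(0, 4, None), slice(4, 20, None)]
```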

View File

@ -23,10 +23,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
ProcessorInputs, PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
@ -72,11 +73,19 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def _get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> ProcessorMixin:
return self.ctx.get_hf_processor()
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor()
return hf_processor.audio_processor.feature_extractor # type: ignore
def _get_processor_data(
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
@ -84,33 +93,41 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)
return super()._get_processor_data(mm_items)
return super()._get_hf_mm_data(mm_items)
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processor_data = dict(processor_data)
audios = processor_data.pop("audios", [])
# Text-only input not supported in composite processor
if not mm_data:
tokenizer = self._get_tokenizer()
prompt_ids = tokenizer.encode(
prompt,
add_special_tokens=False, # type: ignore
)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
if not audios:
return super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
feature_extractor = self._get_feature_extractor()
mm_processor_kwargs = dict(
**mm_processor_kwargs,
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
# Already resampled by _get_processor_data
# Already resampled by _get_hf_mm_data
assert is_list_of(audios, np.ndarray)
# Ultravox processor doesn't support multiple inputs,
@ -119,13 +136,12 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
shared_outputs = {}
for audio in audios:
# NOTE: Ultravox processor accepts "audio" instead of "audios"
item_processor_data = dict(**processor_data, audio=audio)
item_processor_data = dict(**mm_data, audio=audio)
item_outputs = super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=item_processor_data,
mm_processor_kwargs=mm_processor_kwargs,
mm_data=item_processor_data,
mm_kwargs=mm_kwargs,
)
audio_features.append(item_outputs.pop("audio_values")[0])
@ -139,17 +155,28 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
)
return BatchFeature(combined_outputs)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
audio_features=MultiModalFieldConfig.batched("audio"),
audio_token_len=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
placeholder = hf_processor.audio_token_replacement # type: ignore
def get_replacement_ultravox(item_idx: int):
audio_token_len = hf_inputs["audio_token_len"][item_idx]
audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
return placeholder * audio_token_len
return [
@ -168,14 +195,13 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
audio_count = mm_counts["audio"]
audio_count = mm_counts.get("audio", 0)
audio = np.zeros(audio_len)
data = {"audio": [audio] * audio_count}
return ProcessorInputs(
prompt_text="<|audio|>" * audio_count,
mm_data=data,
mm_processor_kwargs={},
)

View File

@ -297,35 +297,37 @@ class MultiModalPlaceholderMap:
``MultiModalPlaceholderMap`` that relates the multi-modal embedding
vectors to their corresponding placeholders.
Consider the following scenarios:
Examples:
Prompt: |AAAA BBBB What's in these images?|
Positions: |.................................|
.. code-block::
images = [A, B]
src_ranges = [(0, 4), (4, 8)]
dest_ranges = [(0, 4), (5, 9)]
Prompt: |AAAA BBBB What's in these images?|
Positions: |.................................|
Prompt: |AAAA BBBB What's in these images?|
Positions: | ..... |
images = [A, B]
src_ranges = [(0, 4), (4, 8)]
dest_ranges = [(0, 4), (5, 9)]
images = [A, B]
src_ranges = [(2, 4), (4, 6)]
dest_ranges = [(0, 2), (3, 5)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | ..... |
Prompt: |AAAA BBBB What's in these images?|
Positions: | ......... |
images = [A, B]
src_ranges = [(2, 4), (4, 6)]
dest_ranges = [(0, 2), (3, 5)]
images = [B]
src_ranges = [(0, 4)]
dest_ranges = [(0, 4)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | ......... |
Prompt: |AAAA BBBB What's in these images?|
Positions: | .......................|
images = [B]
src_ranges = [(0, 4)]
dest_ranges = [(0, 4)]
images = []
src_ranges = []
dest_ranges = []
Prompt: |AAAA BBBB What's in these images?|
Positions: | .......................|
images = []
src_ranges = []
dest_ranges = []
"""
seq_mm_data = seq_group.multi_modal_data
seq_mm_placeholders = seq_group.multi_modal_placeholders
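
The `src_ranges`/`dest_ranges` in these examples can be derived by intersecting each placeholder with the scheduled token window. A hedged sketch (helper name and exact semantics are assumptions; the real logic lives in `MultiModalPlaceholderMap`):

```python
# Map a placeholder occupying prompt positions [offset, offset + length)
# to (src, dest) ranges for a scheduled window [start, end). `emb_base` is
# the item's starting index within the embeddings passed for this window.
def intersect_placeholder(offset, length, emb_base, start, end):
    lo = max(offset, start)
    hi = min(offset + length, end)
    if lo >= hi:
        return None  # item not scheduled in this window
    src = (emb_base + lo - offset, emb_base + hi - offset)
    dest = (lo - start, hi - start)
    return src, dest

# Second scenario above: A at [0, 4), B at [5, 9), window [2, 7):
print(intersect_placeholder(0, 4, 0, 2, 7))  # ((2, 4), (0, 2))
print(intersect_placeholder(5, 4, 4, 2, 7))  # ((4, 6), (3, 5))
```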

View File

@ -1,12 +1,16 @@
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple,
TypedDict, TypeVar, Union, cast, final)
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, Union, cast,
final)
import numpy as np
import torch
import torch.types
from PIL.Image import Image
from typing_extensions import NotRequired, TypeAlias
from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias, assert_never
from vllm.utils import JSONTree, is_list_of, json_map_leaves
@ -44,7 +48,7 @@ item, which can be passed to a HuggingFace :code:`AudioProcessor`.
"""
# yapf: enable
MultiModalData: TypeAlias = Union[_T, List[_T]]
MultiModalData: TypeAlias = Union[_T, list[_T]]
"""
Either a single data item, or a list of data items.
@ -79,13 +83,135 @@ Note:
"""
class ImageSize(NamedTuple):
width: int
height: int
class MultiModalDataItems(UserDict[str, list[Any]]):
"""
As :class:`MultiModalDataDict`, but normalized such that each entry
corresponds to a list.
"""
@staticmethod
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()
for k, v in data.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if (
isinstance(v, torch.Tensor)
or is_list_of(v, list)
or isinstance(v[0], (np.ndarray, torch.Tensor))
and v[0].ndim == 4
) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
v if isinstance(v, (torch.Tensor, list)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
return multi_data
# NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
# `self.images` doesn't update this dictionary, which may be confusing
# We annotate the getter methods as `Sequence` to prevent others from
# trying to update the list in this way
@property
def images(self) -> Sequence[ImageItem]:
return self.get("image", [])
@property
def videos(self) -> Sequence[VideoItem]:
return self.get("video", [])
@property
def audios(self) -> Sequence[AudioItem]:
return self.get("audio", [])
def get_item_counts(self) -> Mapping[str, int]:
return {m: len(items) for m, items in self.items()}
def has_embedding_inputs(self) -> bool:
return any(
any(isinstance(item, torch.Tensor) for item in items)
for items in self.values())
def get_image_size(self, item_idx: int) -> ImageSize:
image = self.images[item_idx]
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
def get_audio_with_sr(
self,
item_idx: int,
*,
default_sr: float,
) -> tuple[np.ndarray, float]:
audio = self.audios[item_idx]
if isinstance(audio, tuple):
return audio
if isinstance(audio, list):
return np.array(audio), default_sr
if isinstance(audio, np.ndarray):
return audio, default_sr
assert_never(audio)
def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
"""
If :code:`drop_sr=True`, the audio items in this dictionary are updated
to be NumPy arrays which implicitly means that their sampling rate is
the same as the model's expected sampling rate; otherwise, they remain
as :code:`(audio, new_sr)` tuples.
"""
# Avoid circular import
from .audio import resample_audio
if not self.audios:
return
new_audios = []
for item_idx in range(len(self.audios)):
audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr)
audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr)
new_audios.append(audio if drop_sr else (audio, new_sr))
self["audio"] = new_audios
class PlaceholderRange(TypedDict):
"""
Placeholder location information for multi-modal data.
For example:
Prompt: AAAA BBBB What is in these images?
Example:
Prompt: :code:`AAAA BBBB What is in these images?`
Images A and B will have:
.. code-block::
A: { "offset": 0, "length": 4 }
B: { "offset": 5, "length": 4 }
"""
@ -97,25 +223,256 @@ class PlaceholderRange(TypedDict):
"""The length of the placeholder."""
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor,
Tuple[torch.Tensor, ...]]
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
tuple[torch.Tensor, ...]]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
"""Equality check between :data:`NestedTensors` objects."""
if isinstance(a, torch.Tensor):
return isinstance(b, torch.Tensor) and bool((a == b).all().item())
elif isinstance(b, torch.Tensor):
return isinstance(a, torch.Tensor) and bool((b == a).all().item())
if isinstance(a, list):
return (isinstance(b, list)
and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))
if isinstance(b, list):
return (isinstance(a, list)
and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)))
# Both a and b are scalars
return a == b
BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""
@dataclass(frozen=True)
class MultiModalFieldItem:
"""
Contains metadata and data in :class:`MultiModalKwargs`
corresponding to a data item in :class:`MultiModalDataItems`.
"""
field: "BaseMultiModalField"
data: NestedTensors
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
return (self.field == other.field
and nested_tensors_equal(self.data, other.data))
@dataclass(frozen=True)
class BaseMultiModalField(ABC):
"""Abstract base class for a field in :class:`MultiModalKwargs`."""
key: str
modality: str
@abstractmethod
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
raise NotImplementedError
def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
return MultiModalFieldItem(self, data)
def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
"""Merge multiple instances of :class:`MultiModalFieldItem` together."""
fields = [item.field for item in batch]
if len(set(fields)) > 1:
raise ValueError(f"Cannot merge different {fields=}")
data = self._reduce_data([item.data for item in batch])
return self._build_item(data)
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
"""
A :class:`BaseMultiModalField` implementation where an item is obtained by
directly indexing into the first dimension of the underlying data.
"""
def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
return [self._build_item(item) for item in batch]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
first_shape = batch[0].shape
if all(item.shape == first_shape for item in batch):
return torch.stack(batch)
return batch
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
"""
A :class:`BaseMultiModalField` implementation where an item is obtained by
slicing along the first dimension of the underlying data.
"""
def build_items(
self,
batch: NestedTensors,
slices: Sequence[slice],
) -> list[MultiModalFieldItem]:
return [self._build_item(batch[slice_]) for slice_ in slices]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
first_shape = batch[0].shape
if all(item.shape[1:] == first_shape[1:] for item in batch):
return torch.concat(batch)
return [elem for item in batch for elem in item]
class MultiModalFieldConfig:
@staticmethod
def batched(modality: str):
return MultiModalFieldConfig(
field_cls=MultiModalBatchedField,
modality=modality,
)
@staticmethod
def flat(modality: str, slices: Sequence[slice]):
return MultiModalFieldConfig(
field_cls=MultiModalFlatField,
modality=modality,
slices=slices,
)
def __init__(
self,
field_cls: type[BaseMultiModalField],
modality: str,
**field_config: Any,
) -> None:
super().__init__()
self._field_cls = field_cls
self._modality = modality
self._field_config = field_config
def build_items(
self,
key: str,
batch: NestedTensors,
) -> list[MultiModalFieldItem]:
field = self._field_cls(key=key, modality=self._modality)
return field.build_items(batch, **self._field_config) # type: ignore
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`.
The metadata :code:`items_by_key` defines how to split batched keyword
arguments corresponding to each data item in :class:`MultiModalDataItems`:
- For a keyword argument, we can access the :code:`i` th item in the batch
via :code:`items_by_key[key][i]`.
- We can gather the keyword arguments belonging to a modality by finding
the keys with items that belong to that modality, then accessing
the :code:`i` th item in the batch for each such key.
Example:
.. code-block:: python
# All items belong to the "image" modality
items_by_key={
"pixel_values": [a, b, c, d], # "image" modality
"image_grid_thw": [e, f, g, h], # "image" modality
"pixel_values_video": [h, i, j], # "video" modality
"video_grid_thw": [k, l, m], # "video" modality
}
- The keyword arguments belonging to the first image are
:code:`{"pixel_values": a, "image_grid_thw": e}`.
- The keyword arguments belonging to the second video are
:code:`{"pixel_values_video": i, "video_grid_thw": l}`.
"""
@staticmethod
def from_hf_inputs(
hf_inputs: BatchFeature,
config_by_key: Mapping[str, MultiModalFieldConfig],
*,
enable_sanity_checks: bool = False,
):
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
# We assume that those fields are not used in vLLM
items_by_key = {
key: config.build_items(key, batch)
for key, config in config_by_key.items()
if (batch := hf_inputs.get(key)) is not None
}
return MultiModalKwargs.from_items_by_key(
items_by_key,
enable_sanity_checks=enable_sanity_checks,
)
@staticmethod
def from_items_by_key(
items_by_key: Mapping[str, list[MultiModalFieldItem]],
*,
enable_sanity_checks: bool = False,
) -> "MultiModalKwargs":
data = {
key: items[0].field.reduce(items).data
for key, items in items_by_key.items()
}
return MultiModalKwargs(data,
items_by_key=items_by_key,
enable_sanity_checks=enable_sanity_checks)
def __init__(
self,
data: Mapping[str, NestedTensors],
*,
items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
enable_sanity_checks: bool = False,
) -> None:
super().__init__(data)
# Shallow copy to avoid footgun in case a defaultdict is passed in
self._items_by_key = dict(items_by_key)
keys_by_modality = defaultdict[str, set[str]](set)
for key, items in items_by_key.items():
for item in items:
keys_by_modality[item.field.modality].add(key)
self._keys_by_modality = dict(keys_by_modality)
if enable_sanity_checks:
for modality, keys in keys_by_modality.items():
items_in_modality = {k: items_by_key[k] for k in keys}
batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
batch_size = next(iter(batch_sizes.values()), 0)
assert all(bs == batch_size
for bs in batch_sizes.values()), dict(
modality=modality,
batch_sizes=batch_sizes,
items_by_key=items_by_key)
@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
"""
@ -139,7 +496,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
# Only tensors (not lists) can be stacked.
return stacked
tensors_ = cast(List[torch.Tensor], stacked)
tensors_ = cast(list[torch.Tensor], stacked)
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
@ -147,7 +504,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return torch.stack(tensors_)
@staticmethod
def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs:
def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
"""
Batch multiple inputs together into a dictionary.
@ -162,7 +519,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
# We need to consider the case where each item in the batch
# contains different modalities (i.e. different keys).
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
item_lists = defaultdict[str, list[NestedTensors]](list)
for inputs in inputs_list:
for k, v in inputs.items():
@ -188,6 +545,57 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return cast(BatchedTensorInputs, json_mapped)
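Batching follows the same pattern on a larger scale: each input may carry a different subset of keys (modalities), so values are first grouped per key and then stacked. A self-contained sketch under the same simplifications:

```python
from collections import defaultdict

import torch

def batch(inputs_list: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    # Each input may contain a different subset of keys.
    item_lists = defaultdict(list)
    for inputs in inputs_list:
        for k, v in inputs.items():
            item_lists[k].append(v)
    return {k: torch.stack(vs) for k, vs in item_lists.items()}

merged = batch([
    {"pixel_values": torch.rand(3, 224, 224)},
    {"pixel_values": torch.rand(3, 224, 224), "audio_features": torch.rand(80)},
])
assert merged["pixel_values"].shape == (2, 3, 224, 224)
assert merged["audio_features"].shape == (1, 80)
```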
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
if self._items_by_key != other._items_by_key:
return False
ks = self.keys()
return (ks == other.keys()
and all(nested_tensors_equal(self[k], other[k]) for k in ks))
def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
return self._items_by_key[key][item_index]
def get_items_by_modality(
self,
modality: str,
item_index: int,
) -> Mapping[str, MultiModalFieldItem]:
"""
Get the keyword arguments corresponding to an item identified by
its modality and index.
"""
keys_to_gather = self._keys_by_modality[modality]
return {
key: self.get_item(key, item_index)
for key in keys_to_gather if key in self
}
@staticmethod
def from_items_by_modality(
items_by_modality: Mapping[str, list[Mapping[str,
MultiModalFieldItem]]],
*,
enable_sanity_checks: bool = False,
) -> "MultiModalKwargs":
"""
Construct a new :class:`MultiModalKwargs` from multiple items returned
by :meth:`get_items_by_modality`.
"""
items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
for fields in items_by_modality.values():
for field in fields:
for k, v in field.items():
items_by_key[k].append(v)
return MultiModalKwargs.from_items_by_key(
items_by_key,
enable_sanity_checks=enable_sanity_checks,
)
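The regrouping done by :meth:`from_items_by_modality` is a plain inversion: items arrive grouped by modality and are flattened back into per-key lists. A dependency-free sketch:

```python
from collections import defaultdict

# Each modality maps to a list of per-item dicts (key -> item).
items_by_modality = {
    "image": [{"pixel_values": "img0", "image_grid_thw": "thw0"},
              {"pixel_values": "img1", "image_grid_thw": "thw1"}],
    "video": [{"pixel_values_video": "vid0"}],
}

items_by_key = defaultdict(list)
for per_item_dicts in items_by_modality.values():
    for per_item in per_item_dicts:
        for key, item in per_item.items():
            items_by_key[key].append(item)

assert items_by_key["pixel_values"] == ["img0", "img1"]
assert items_by_key["pixel_values_video"] == ["vid0"]
```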
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
@ -207,16 +615,16 @@ class MultiModalInputsV2(TypedDict):
prompt: str
"""The processed prompt text."""
prompt_token_ids: List[int]
prompt_token_ids: list[int]
"""The processed token IDs which includes placeholder tokens."""
token_type_ids: NotRequired[List[int]]
token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes: NotRequired[List[str]]
mm_hashes: NotRequired[list[str]]
"""The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict

View File: vllm/multimodal/processing.py

@ -1,6 +1,6 @@
import pickle
import re
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import lru_cache
@ -8,19 +8,18 @@ from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
import numpy as np
import torch
from blake3 import blake3
from PIL.Image import Image
from transformers import BatchFeature, ProcessorMixin
from typing_extensions import assert_never
from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of
from .audio import resample_audio
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem)
from .inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalFieldItem,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange)
logger = init_logger(__name__)
@ -201,111 +200,6 @@ class _BoundPromptReplacement:
return bound_replacement
class ImageSize(NamedTuple):
width: int
height: int
class MultiModalDataItems(UserDict[str, list[Any]]):
"""
As :class:`MultiModalDataDict`, but normalized such that each entry
corresponds to a list.
"""
@staticmethod
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()
for k, v in data.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if (isinstance(v, torch.Tensor)
or is_list_of(v, list)) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
v if isinstance(v, (torch.Tensor, list)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
return multi_data
# NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
# `self.images` doesn't update this dictionary, which may be confusing
# We annotate the getter methods as `Sequence` to prevent others from
# trying to update the list in this way
@property
def images(self) -> Sequence[ImageItem]:
return self.get("image", [])
@property
def videos(self) -> Sequence[VideoItem]:
return self.get("video", [])
@property
def audios(self) -> Sequence[AudioItem]:
return self.get("audio", [])
def get_item_counts(self) -> Mapping[str, int]:
return {m: len(items) for m, items in self.items()}
def get_image_size(self, item_idx: int) -> ImageSize:
image = self.images[item_idx]
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
def get_audio_with_sr(
self,
item_idx: int,
*,
default_sr: float,
) -> tuple[np.ndarray, float]:
audio = self.audios[item_idx]
if isinstance(audio, tuple):
return audio
if isinstance(audio, list):
return np.array(audio), default_sr
if isinstance(audio, np.ndarray):
return audio, default_sr
assert_never(audio)
def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
"""
If :code:`drop_sr=True`, the audio items in this dictionary are updated
to be NumPy arrays which implicitly means that their sampling rate is
the same as the model's expected sampling rate; otherwise, they remain
as :code:`(audio, new_sr)` tuples.
"""
if not self.audios:
return
new_audios = []
for item_idx in range(len(self.audios)):
audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr)
audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr)
new_audios.append(audio if drop_sr else (audio, new_sr))
self["audio"] = new_audios
class _TokenMatch(NamedTuple):
start_idx: int
end_idx: int
@ -583,11 +477,124 @@ def iter_placeholders(
)
class ProcessorInputs(NamedTuple):
"""Keyword arguments to :meth:`BaseMultiModalProcessor`"""
@dataclass
class ProcessorInputs:
"""Keyword arguments to :meth:`BaseMultiModalProcessor`."""
prompt_text: str
mm_data: MultiModalDataDict
mm_processor_kwargs: Mapping[str, object]
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
class ProcessingCache:
def __init__(self, capacity: int) -> None:
super().__init__()
# DEBUG: Set to None to disable
self.debug_cache_hit_ratio_steps: Optional[int] = None
self._cache = LRUCache[str, Mapping[str,
MultiModalFieldItem]](capacity)
def _maybe_log_cache_stats(self) -> None:
steps = self.debug_cache_hit_ratio_steps
if not steps:
return
cache_stats = self._cache.stat()
if cache_stats.total % steps == 0:
logger.debug("ProcessingCache: hit_ratio = %.2f",
cache_stats.hit_ratio)
def _serialize_item(self, obj: object) -> bytes:
# Simple cases
if isinstance(obj, str):
return obj.encode("utf-8")
if isinstance(obj, bytes):
return obj
if isinstance(obj, Image):
return obj.tobytes()
# Convertible to NumPy arrays
if isinstance(obj, torch.Tensor):
obj = obj.numpy()
if isinstance(obj, (int, float)):
obj = np.array(obj)
if isinstance(obj, np.ndarray):
return obj.tobytes()
logger.warning(
"No serialization method found for %s. "
"Falling back to pickle.", type(obj))
return pickle.dumps(obj)
def _item_to_bytes(
self,
key: str,
obj: object,
) -> Iterable[tuple[bytes, bytes]]:
# Recursive cases
if isinstance(obj, (list, tuple)):
for i, elem in enumerate(obj):
yield from self._item_to_bytes(f"{key}.{i}", elem)
elif isinstance(obj, dict):
for k, v in obj.items():
yield from self._item_to_bytes(f"{key}.{k}", v)
else:
key_bytes = self._serialize_item(key)
value_bytes = self._serialize_item(obj)
yield key_bytes, value_bytes
def _hash_kwargs(self, **kwargs: object) -> str:
hasher = blake3()
for k, v in kwargs.items():
for k_bytes, v_bytes in self._item_to_bytes(k, v):
hasher.update(k_bytes)
hasher.update(v_bytes)
return hasher.hexdigest()
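The cache key derivation can be reproduced standalone: every kwarg is flattened into key/value byte pairs and fed into a single BLAKE3 hasher, so equal inputs always map to equal keys. A simplified sketch (the string fallback below is this sketch's shortcut, not the real serializer):

```python
import numpy as np
from blake3 import blake3

def hash_kwargs(**kwargs: object) -> str:
    hasher = blake3()
    for key, value in kwargs.items():
        hasher.update(key.encode("utf-8"))
        if isinstance(value, np.ndarray):
            hasher.update(value.tobytes())
        else:  # shortcut for this sketch only
            hasher.update(str(value).encode("utf-8"))
    return hasher.hexdigest()

arr = np.arange(4)
assert hash_kwargs(model_id="m", image=arr) == hash_kwargs(model_id="m", image=arr)
assert hash_kwargs(model_id="m", image=arr) != hash_kwargs(model_id="n", image=arr)
```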
def get(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
) -> Optional[Mapping[str, MultiModalFieldItem]]:
"""
Get a processed multi-modal item from the cache
according to its dependencies, including:
- The model ID
- The modality of the item
- The original data item passed to the HF processor
- The configuration options of the HF processor
"""
self._maybe_log_cache_stats()
cache_key = self._hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
return self._cache.get(cache_key)
def put(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
output_kwargs: Mapping[str, MultiModalFieldItem],
) -> None:
"""
Put a processed multi-modal item into the cache
according to its dependencies (see :meth:`get`).
"""
cache_key = self._hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
self._cache.put(cache_key, output_kwargs)
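A typical call pattern for the cache then looks as follows; the image object and kwargs are hypothetical stand-ins for a real PIL image and HF processor options:

```python
from vllm.multimodal.processing import ProcessingCache

cache = ProcessingCache(capacity=256)

model_id = "llava-hf/llava-1.5-7b-hf"
image = object()           # stand-in for a PIL image
kwargs = {"num_crops": 4}  # stand-in for HF processor options

processed = cache.get(model_id, "image", image, kwargs)
if processed is None:
    processed = {}  # ... run the HF processor on this item instead ...
    cache.put(model_id, "image", image, kwargs, processed)
```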
class BaseMultiModalProcessor(ABC):
@ -595,18 +602,24 @@ class BaseMultiModalProcessor(ABC):
Abstract base class to process multi-modal inputs to be used in vLLM.
"""
def __init__(self, ctx: InputProcessingContext) -> None:
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__()
self.ctx = ctx
self.cache = cache
self.enable_sanity_checks = enable_sanity_checks
def __call__(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, mm_processor_kwargs)
return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
def _get_hf_processor(self) -> ProcessorMixin:
"""
@ -624,12 +637,21 @@ class BaseMultiModalProcessor(ABC):
) -> MultiModalDataItems:
return MultiModalDataItems.from_dict(mm_data)
@abstractmethod
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
"""Given the HF-processed data, output the metadata of each field."""
raise NotImplementedError
@abstractmethod
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
"""
Given the original multi-modal items for this modality
@ -651,7 +673,7 @@ class BaseMultiModalProcessor(ABC):
return list(
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
def _get_processor_data(
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
@ -669,7 +691,7 @@ class BaseMultiModalProcessor(ABC):
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
else:
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
@ -679,39 +701,181 @@ class BaseMultiModalProcessor(ABC):
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
# Not to be confused with `mm_data` in `self.apply`.
# This refers to the data to be passed to HF processor.
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
return self.ctx.call_hf_processor(
hf_processor,
prompt,
processor_data,
mm_processor_kwargs,
self._get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
mm_kwargs,
)
def _apply_hf_processor(
self,
prompt: str,
prompt_text: str,
mm_items: MultiModalDataItems,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# some mm_processor_kwargs may be used in processor initialization
# instead of processor call
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
"""
Apply the HF processor on the full prompt text and multi-modal data.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
processor_data, passthrough_data = self._get_processor_data(mm_items)
hf_inputs = self._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
processed_data = self._call_hf_processor(
prompt=prompt_text,
mm_data=processor_data,
mm_kwargs=hf_processor_mm_kwargs,
)
hf_inputs.update(passthrough_data)
processed_data.update(passthrough_data)
return hf_inputs
prompt_ids, = processed_data.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
enable_sanity_checks=self.enable_sanity_checks,
)
return prompt_ids, mm_kwargs
def _apply_hf_processor_missing(
self,
prompt_text: str,
mm_missing_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
):
"""
Apply the HF processor on the full prompt text, but only on the
multi-modal data that are missing from the cache.
Note: We pass prompt text and multi-modal data into the HF processor
in separate calls to avoid HF prompt replacement being done for
cached items; instead, we rely on our own prompt replacement logic
for the full text.
"""
mm_missing_counts = mm_missing_data_items.get_item_counts()
prompt_ids, _ = self._apply_hf_processor(
prompt_text=prompt_text,
mm_items=MultiModalDataItems({}),
hf_processor_mm_kwargs={},
)
# Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text
dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts)
_, mm_missing_kwargs = self._apply_hf_processor(
prompt_text=dummy_inputs.prompt_text,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
return prompt_ids, mm_missing_kwargs
def _cached_apply_hf_processor(
self,
prompt_text: str,
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
"""
cache = self.cache
model_id = self.ctx.model_config.model
if cache is None or mm_data_items.has_embedding_inputs():
return self._apply_hf_processor(
prompt_text=prompt_text,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
mm_maybe_cached_field_items = {
modality: [
cache.get(model_id, modality, item, hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_data_items.items()
}
mm_missing_idxs = {
modality: [idx for idx, out in enumerate(fields) if out is None]
for modality, fields in mm_maybe_cached_field_items.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
mm_missing_data_items = self._get_mm_items(mm_missing_data)
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
prompt_text=prompt_text,
mm_missing_data_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
mm_missing_next_idx = {
modality: 0
for modality in mm_missing_data_items
}
mm_merged_field_items = dict[str, list[Mapping[str,
MultiModalFieldItem]]]()
for modality, modal_items_lst in mm_maybe_cached_field_items.items():
merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]()
for idx, modal_items in enumerate(modal_items_lst):
if modal_items is None:
modal_items = mm_missing_kwargs.get_items_by_modality(
modality,
mm_missing_next_idx[modality],
)
cache.put(
model_id,
modality,
mm_data_items[modality][idx],
hf_processor_mm_kwargs,
modal_items,
)
mm_missing_next_idx[modality] += 1
merged_modal_items_lst.append(modal_items)
mm_merged_field_items[modality] = merged_modal_items_lst
if self.enable_sanity_checks:
mm_missing_counts = mm_missing_data_items.get_item_counts()
assert all(
item_count == mm_missing_counts[modality]
for modality, item_count in mm_missing_next_idx.items()), dict(
mm_missing_next_idx=mm_missing_next_idx,
mm_missing_counts=mm_missing_counts)
mm_kwargs = MultiModalKwargs.from_items_by_modality(
mm_merged_field_items,
enable_sanity_checks=self.enable_sanity_checks,
)
if self.enable_sanity_checks:
mm_item_counts = mm_data_items.get_item_counts()
for modality, item_count in mm_item_counts.items():
for item_idx in range(item_count):
try:
mm_kwargs.get_items_by_modality(modality, item_idx)
except Exception as e:
# Make it easy to set a breakpoint in the debugger
raise e
return prompt_ids, mm_kwargs
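Stripped of the vLLM types, the merge above works as follows: record which lookups missed, process only those items, then walk the original order and fill each hole from the fresh results. A dependency-free sketch:

```python
cached = ["a0", None, "a2", None]  # per-item cache lookups (None = miss)
freshly_processed = ["b1", "b3"]   # outputs for the missing items, in order

next_missing = 0
merged = []
for item in cached:
    if item is None:
        item = freshly_processed[next_missing]
        next_missing += 1
    merged.append(item)

assert merged == ["a0", "b1", "a2", "b3"]
assert next_missing == len(freshly_processed)  # sanity check, as in the code above
```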
def _bind_prompt_replacements(
self,
@ -730,6 +894,10 @@ class BaseMultiModalProcessor(ABC):
tokenizer = self._get_tokenizer()
token_matches = find_token_matches(token_ids, prompt_repls)
mm_match_counts = {
modality: len(matches)
for modality, matches in full_groupby_modality(token_matches)
}
# If the search text does not represent a special token,
# it may have different token IDs in the prompt, because
@ -742,8 +910,8 @@ class BaseMultiModalProcessor(ABC):
# of the search text in the prompt, we instead perform string
# replacement on the decoded token IDs, then encode them back.
if all(
len(matches) >= mm_item_counts[modality]
for modality, matches in full_groupby_modality(token_matches)
mm_match_counts.get(modality, 0) >= item_count
for modality, item_count in mm_item_counts.items()
): # yapf: disable
token_ids = replace_token_matches(
token_ids,
@ -775,7 +943,7 @@ class BaseMultiModalProcessor(ABC):
self,
prompt_text: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
"""
Process multi-modal inputs to be used in vLLM.
@ -792,20 +960,24 @@ class BaseMultiModalProcessor(ABC):
"""
mm_items = self._get_mm_items(mm_data)
hf_inputs = self._apply_hf_processor(prompt_text, mm_items,
mm_processor_kwargs)
prompt_ids, = hf_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(hf_inputs)
prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
prompt_text,
mm_items,
hf_processor_mm_kwargs,
)
prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
mm_processor_kwargs)
all_prompt_repls = self._bind_prompt_replacements(prompt_repls)
unbound_prompt_repls = self._get_prompt_replacements(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls)
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
mm_item_counts = mm_items.get_item_counts()
all_placeholders = self._find_placeholders(all_prompt_repls,
prompt_ids, mm_item_counts)
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
mm_item_counts)
if all_placeholders:
tokenizer = self._get_tokenizer()
@ -817,7 +989,7 @@ class BaseMultiModalProcessor(ABC):
all_placeholders,
) = self._apply_prompt_replacements(
prompt_ids,
all_prompt_repls,
prompt_repls,
mm_item_counts,
)
@ -855,23 +1027,29 @@ class BaseMultiModalProcessor(ABC):
from vllm.sequence import SequenceData
processor_inputs = self._get_dummy_mm_inputs(mm_counts)
mm_inputs = self.apply(*processor_inputs)
mm_inputs = self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = dict[str, int]()
for modality, placeholders in placeholders_by_modality.items():
num_placeholders = sum(item["length"] for item in placeholders)
max_tokens = mm_max_tokens[modality]
if num_placeholders != max_tokens:
logger.warning(
"The processed dummy data has a total of %d placeholder "
"tokens for the '%s' modality, which is not the expected "
"%d tokens.", num_placeholders, modality, max_tokens)
total_placeholders_by_modality[modality] = num_placeholders
total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
total_len = len(prompt_token_ids)
if total_len > seq_len:

View File: vllm/multimodal/registry.py

@ -1,10 +1,9 @@
import functools
from collections import UserDict
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol,
Sequence, Type, TypeVar)
import torch.nn as nn
from typing_extensions import TypeAlias
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
@ -15,7 +14,7 @@ from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import BaseMultiModalProcessor
from .processing import BaseMultiModalProcessor, ProcessingCache
from .video import VideoPlugin
if TYPE_CHECKING:
@ -23,15 +22,22 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
# TODO: Tune the MM cache size
MM_CACHE_SIZE = 256
N = TypeVar("N", bound=Type[nn.Module])
MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext],
BaseMultiModalProcessor]
"""
Constructs a :class:`MultiModalProcessor` instance from the context.
The processing metadata should be derived from the context.
"""
class MultiModalProcessorFactory(Protocol):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def __call__(
self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor:
...
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
@ -71,6 +77,8 @@ class MultiModalRegistry:
self._limits_by_model = _MultiModalLimits()
self._processing_cache = ProcessingCache(MM_CACHE_SIZE)
def register_plugin(self, plugin: MultiModalPlugin) -> None:
"""
Register a multi-modal plugin so it can be recognized by vLLM.
@ -328,15 +336,18 @@ class MultiModalRegistry:
return wrapper
def has_processor(self, model_config: "ModelConfig") -> bool:
"""
Test whether a multi-modal processor is defined for a specific model.
"""
def _get_model_cls(self, model_config: "ModelConfig"):
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
return model_cls in self._processor_factories
return model_cls
def has_processor(self, model_config: "ModelConfig") -> bool:
"""
Test whether a multi-modal processor is defined for a specific model.
"""
return self._get_model_cls(model_config) in self._processor_factories
def create_processor(
self,
@ -346,12 +357,11 @@ class MultiModalRegistry:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
model_cls = self._get_model_cls(model_config)
processor_factory = self._processor_factories[model_cls]
ctx = InputProcessingContext(model_config, tokenizer)
return processor_factory(ctx)
cache = (None if model_config.disable_mm_preprocessor_cache else
self._processing_cache)
return processor_factory(ctx, cache=cache)

View File

@ -1,25 +1,31 @@
from functools import lru_cache
from typing import Any, cast
from transformers.processing_utils import ProcessorMixin
def get_processor(
processor_name: str,
*args: Any,
trust_remote_code: bool = False,
processor_cls: type[ProcessorMixin] = ProcessorMixin,
**kwargs: Any,
):
"""Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor
from transformers.processing_utils import ProcessorMixin
processor_factory = (AutoProcessor
if processor_cls == ProcessorMixin else processor_cls)
try:
processor = AutoProcessor.from_pretrained(
processor = processor_factory.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
**kwargs,
)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
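The :code:`processor_cls` escape hatch lets callers bypass :code:`AutoProcessor` dispatch and resolve a concrete processor class directly. A hypothetical call (model name illustrative):

```python
from transformers import Qwen2VLProcessor

from vllm.transformers_utils.processor import get_processor

# Resolves via Qwen2VLProcessor.from_pretrained instead of AutoProcessor.
processor = get_processor(
    "Qwen/Qwen2-VL-2B-Instruct",
    processor_cls=Qwen2VLProcessor,
)
```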

View File: vllm/utils.py

@ -25,11 +25,11 @@ import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Iterable, Mapping
from collections.abc import Hashable, Iterable, Mapping
from dataclasses import dataclass, field
from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Hashable, List, Literal,
Dict, Generator, Generic, List, Literal, NamedTuple,
Optional, Tuple, Type, TypeVar, Union, overload)
from uuid import uuid4
@ -194,13 +194,29 @@ class Counter:
self.counter = 0
class CacheInfo(NamedTuple):
hits: int
total: int
@property
def hit_ratio(self) -> float:
if self.total == 0:
return 0.0
return self.hits / self.total
class LRUCache(Generic[_K, _V]):
"""Note: This class is not thread safe!"""
def __init__(self, capacity: int) -> None:
self.cache = OrderedDict[_K, _V]()
self.pinned_items = set[_K]()
self.capacity = capacity
self._hits = 0
self._total = 0
def __contains__(self, key: _K) -> bool:
return key in self.cache
@ -218,6 +234,9 @@ class LRUCache(Generic[_K, _V]):
def __delitem__(self, key: _K) -> None:
self.pop(key)
def stat(self) -> CacheInfo:
return CacheInfo(hits=self._hits, total=self._total)
def touch(self, key: _K) -> None:
self.cache.move_to_end(key)
@ -226,8 +245,12 @@ class LRUCache(Generic[_K, _V]):
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
self._hits += 1
else:
value = default
self._total += 1
return value
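With these counters in place, the hit ratio can be observed directly; a small usage example (keys and values illustrative):

```python
from vllm.utils import LRUCache

cache = LRUCache[str, int](capacity=2)
cache.put("a", 1)

cache.get("a")  # hit
cache.get("b")  # miss
info = cache.stat()
assert (info.hits, info.total) == (1, 2)
assert info.hit_ratio == 0.5
```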
def put(self, key: _K, value: _V) -> None: