[3/N] Support and implement merged input processor for LLaVA model (#10676)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Cyrus Leung 2024-12-07 16:50:58 +08:00 committed by GitHub
parent acf092d348
commit 955fa9533a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 631 additions and 426 deletions

View File

@ -2,7 +2,7 @@ from contextlib import nullcontext
import numpy as np
import pytest
from transformers import CLIPImageProcessor, LlavaNextImageProcessor
from transformers import LlavaNextImageProcessor
from vllm.config import ModelConfig
from vllm.multimodal import MultiModalRegistry
@ -14,49 +14,6 @@ def mm_registry():
return MultiModalRegistry()
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
    """Compare vLLM's image input mapper against the HF CLIP preprocessor.

    Both are run on the same rescaled images; every output tensor must
    match in both shape and value.
    """
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

    hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
    assert isinstance(hf_processor, CLIPImageProcessor)

    model_config = ModelConfig(
        model=MODEL_NAME,
        task="auto",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype=dtype,
        revision=None,
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry.init_mm_limits_per_prompt(model_config)

    for asset in image_assets:
        image = rescale_image_size(asset.pil_image, size_factor)

        hf_result = hf_processor.preprocess(image, return_tensors="pt")
        vllm_result = mm_registry.map_input(model_config, {"image": image})

        assert hf_result.keys() == vllm_result.keys()
        for key in hf_result:
            hf_arr: np.ndarray = hf_result[key].numpy()
            vllm_arr: np.ndarray = vllm_result[key].numpy()

            assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
            assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
@ -107,7 +64,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype,
(2, 1, False), (2, 2, True)],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,
@ -138,7 +95,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,

View File

@ -3,50 +3,15 @@ from typing import cast
import pytest
from transformers import BatchFeature
from vllm.multimodal.processing import (PromptReplacement, find_text_matches,
find_token_matches, iter_token_matches,
iter_token_runs, replace_text_matches)
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
find_text_matches, find_token_matches,
iter_placeholders, iter_token_matches,
replace_text_matches,
replace_token_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby
# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "expected"),
    [
        ([], []),
        (
            [32000, 32000, 32000],
            [{ "token_id": 32000, "start_idx": 0, "length": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [
                { "token_id": 9833, "start_idx": 0, "length": 1 },
                { "token_id": 28747, "start_idx": 1, "length": 1 },
                { "token_id": 32000, "start_idx": 2, "length": 3 },
                { "token_id": 9833, "start_idx": 5, "length": 1 },
                { "token_id": 28747, "start_idx": 6, "length": 1 },
                { "token_id": 32000, "start_idx": 7, "length": 2 },
                { "token_id": 918, "start_idx": 9, "length": 1 },
            ],
        ),
    ],
)
# yapf: enable
def test_iter_token_runs(token_ids, expected):
    """Check that iter_token_runs reports every run of identical tokens."""
    result = list(iter_token_runs(token_ids))

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert [run._asdict() for run in result] == expected

    # Invariant: the runs exactly partition the input sequence
    assert sum(run.length for run in result) == len(token_ids)
# yapf: disable
@pytest.mark.parametrize(
("token_ids", "match_ids", "expected"),
@ -170,13 +135,11 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())
result = find_token_matches(
prompt,
[
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
for key, target in target_by_key.items()
],
)
prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_repls)
# Only displayed on error
print("result:", result)
@ -279,13 +242,11 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
result = find_text_matches(
prompt,
[
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
for key, target in target_by_key.items()
],
)
prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_repls)
# Only displayed on error
print("result:", result)
@ -303,7 +264,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"),
("prompt", "target_by_key", "repl_by_key"),
[
(
"Image:<image>Image:<image><image>!",
@ -322,49 +283,201 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# Test multiple repl_count
"pattern_3": ("?", 2),
},
{
# Test no replacement
0: "Image:<image>Image:<image><image>!",
# Test single replacement
1: "<image><image>Image:<image><image>??",
# Test repeated replacement
2: "<image><image><image><image><image>??",
},
),
]
)
@pytest.mark.parametrize(
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>??"),
(2, "<image><image><image><image><image>??"),
]
)
# yapf: enable
def test_find_replace_text(
prompt,
target_by_key,
repl_by_key,
expected_by_mm_count,
mm_count,
expected,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
matches = find_text_matches(
prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
for key, target in target_by_key.items()
]
matches = find_text_matches(prompt, prompt_repls)
result = replace_text_matches(
prompt,
[
PromptReplacement(target, *repl_by_key[key]) \
.bind(key, mock_tokenizer)
for key, target in target_by_key.items()
],
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
)
result_by_mm_count = {
mm_count: replace_text_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
)
for mm_count in expected_by_mm_count
}
# Only displayed on error
print("matches:", matches)
print("result_by_mm_count:", result_by_mm_count)
print("result:", result)
# Manually constructed results
assert result_by_mm_count == expected_by_mm_count
assert result == expected
# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key"),
    [
        # Tokenized test cases of `test_find_replace_text`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": [32000],
                "pattern_2": [9833, 28747],
                "pattern_3": [918],
            },
            {
                # Test whether target is confused with repl_unit
                "pattern_1": ([32000, 32000], 1),
                # Test empty repl_unit
                "pattern_2": ([], 1),
                # Test multiple repl_count
                "pattern_3": ([1550], 2),
            },
        ),
    ]
)
@pytest.mark.parametrize(
    ("mm_count", "expected"),
    [
        (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
        (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]),
        (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]),
    ]
)
# yapf: enable
def test_find_replace_tokens(
    prompt,
    target_by_key,
    repl_by_key,
    mm_count,
    expected,
):
    """Token-level counterpart of test_find_replace_text."""
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = []
    for key, target in target_by_key.items():
        bound = PromptReplacement(target, *repl_by_key[key])
        prompt_repls.append(bound.bind(key, mock_tokenizer))

    matches = find_token_matches(prompt, prompt_repls)
    mm_items_by_key = {key: list(range(mm_count)) for key in repl_by_key}
    result = replace_token_matches(
        prompt,
        matches,
        mm_items_by_key,
        BatchFeature(),
    )

    # Only displayed on error
    print("matches:", matches)
    print("result:", result)

    # Manually constructed results
    assert result == expected
# yapf: disable
@pytest.mark.parametrize(
"repl_by_key",
[
{
"pattern_1": ([32000, 32000], 1),
"pattern_2": ([], 1),
"pattern_3": ([1550], 2),
},
],
)
# NOTE: The prompts below are the before/after token sequences
# from the cases in `test_find_replace_tokens` above.
@pytest.mark.parametrize(
("prompt", "expected"),
[
(
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=6,
unit=[32000, 32000],
unit_count=1,
),
],
),
(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=1,
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=5,
unit=[32000, 32000],
unit_count=1,
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=7,
unit=[1550],
unit_count=2,
),
],
),
(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=2,
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=6,
unit=[1550],
unit_count=2,
),
],
),
]
)
def test_iter_placeholders(
repl_by_key,
prompt,
expected,
):
"""Check that iter_placeholders locates every placeholder run in a prompt."""
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
# Empty targets: only the replacement units matter for placeholder search
prompt_repls = [
PromptReplacement([], *repl).bind(key, mock_tokenizer)
for key, repl in repl_by_key.items()
]
result = list(iter_placeholders(prompt_repls, prompt))
# Only displayed on error
print("result:", result)
# Manually constructed results
assert result == expected

View File

@ -2,19 +2,17 @@ from typing import Optional
import torch
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
dummy_data_for_llava,
get_max_llava_image_tokens,
input_processor_for_llava)
create_metadata_for_llava,
dummy_mm_kwargs_for_llava,
get_max_llava_image_tokens)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava,
dummy_mm_kwargs_for_llava)
class MyLlava(LlavaForConditionalGeneration):
def compute_logits(

View File

@ -232,19 +232,35 @@ class InputRegistry:
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
if mm_registry.has_processor(model_config):
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
processor = mm_registry.create_processor(model_config, tokenizer)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_max_tokens = mm_registry.get_max_tokens_by_modality(
model_config)
dummy_data = processor.get_dummy_data(seq_len, mm_counts,
mm_max_tokens)
else:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory, overrides=model_config.mm_processor_kwargs)
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
else:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory, overrides=model_config.mm_processor_kwargs)
dummy_data = dummy_factory(InputContext(model_config), seq_len,
_MultiModalCounts(mm_counts),
**mm_processor_kwargs)
dummy_data = dummy_factory(InputContext(model_config), seq_len,
_MultiModalCounts(mm_counts),
**mm_processor_kwargs)
# Having more tokens is over-conservative but otherwise fine
num_tokens = dummy_data.seq_data.prompt_token_ids
@ -257,7 +273,9 @@ class InputRegistry:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if dummy_data.multi_modal_data is not None:
if (dummy_data.multi_modal_data is not None and
not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)):
for k, v in dummy_data.multi_modal_data.items():
num_items = len(v) if isinstance(v, list) else 1
num_expected = mm_counts[k]

View File

@ -1,17 +1,19 @@
from functools import cached_property
from types import MethodType
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
Tuple, TypedDict, Union)
import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
PretrainedConfig, SiglipVisionConfig)
from PIL.Image import Image
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
ProcessorMixin, SiglipVisionConfig)
from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@ -19,21 +21,20 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.processing import (InputProcessingContext,
ModalityProcessingMetadata,
MultiModalProcessingMetadata,
MultiModalProcessor, PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
get_max_clip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
dummy_seq_data_for_pixtral_hf,
get_max_pixtral_hf_image_tokens,
input_processor_for_pixtral_hf)
get_max_pixtral_hf_image_tokens)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@ -113,102 +114,86 @@ def get_max_llava_image_tokens(ctx: InputContext):
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
data = dummy_image_for_clip(vision_config, num_images)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
data = dummy_image_for_siglip(vision_config, num_images)
elif isinstance(vision_config, PixtralVisionConfig):
seq_data, ranges = dummy_seq_data_for_pixtral_hf(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_max_llava_image_tokens(ctx)
elif is_list_of(image_data, Image.Image):
image_feature_size = [get_max_llava_image_tokens(ctx)
] * len(image_data)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
data = dummy_image_for_pixtral_hf(vision_config, num_images)
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, PixtralVisionConfig):
# We ignore image_feature_size_override since we have non-uniform
# image sizes for Pixtral
return input_processor_for_pixtral_hf(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
is_pixtral = isinstance(hf_processor, PixtralProcessor)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
return MultiModalKwargs(
**hf_inputs,
is_pixtral=torch.tensor(is_pixtral),
)
def create_metadata_for_llava(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
hf_config = ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
def get_repl_count(
mm_items: list[Image],
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
return get_max_llava_image_tokens(ctx)
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[image_token_id],
repl_unit=[image_token_id],
repl_count=get_repl_count),
]),
}
class LlavaProcessor(MultiModalProcessor):
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False):
return # Already patched
image_processor = hf_processor.image_processor # type: ignore
orig_preprocess = image_processor.preprocess
def preprocess(__self, *args, **kwargs):
hf_inputs = orig_preprocess(*args, **kwargs)
hf_inputs["is_pixtral"] = torch.tensor(True)
return hf_inputs
image_processor.preprocess = MethodType(preprocess, image_processor)
hf_processor.__is_patched__ = True # type: ignore
def _get_hf_processor(self) -> ProcessorMixin:
hf_processor = self.ctx.get_hf_processor()
if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)
return hf_processor
def _get_dummy_mm_kwargs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)
class LlavaLikeConfig(Protocol):
@ -291,10 +276,11 @@ def init_vision_tower_for_llava(
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
))
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
@ -367,38 +353,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return data
def _validate_image_sizes(self, images: List[torch.Tensor],
sizes: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(sizes, list):
sizes = [sizes]
total_images = sum(size.numel() // 2 for size in sizes)
if total_images != len(images):
raise ValueError("Mismatch in number of images. "
f"Expected {total_images}, got {len(images)}")
img_idx = 0
for size in sizes:
# Flatten the size tensor to a list of (height, width) pairs
size = size.view(-1, 2).tolist()
for expected_h, expected_w in size:
if img_idx >= len(images):
raise ValueError("Ran out of images before sizes. "
f"{img_idx} >= {len(images)}")
img = images[img_idx]
if img.shape[-2:] != (expected_h, expected_w):
raise ValueError(
"Image size mismatch. Expected "
f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
if img.shape[-3] != 3:
raise ValueError("Image channel mismatch. Expected 3, "
f"got {img.shape[-3]}")
img_idx += 1
return images
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
@ -409,9 +367,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Case for models like PixtralHF that have dynamic image sizes
# so we need to produce a list of tensors
if image_sizes is not None:
assert isinstance(is_pixtral, torch.Tensor)
if is_pixtral.any():
images = pixel_values
def flatten_to_3d_tensors(item):
@ -434,7 +391,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_sizes(images, image_sizes),
data=images,
)
return LlavaImagePixelInputs(

View File

@ -226,16 +226,16 @@ class MultiModalPlugin(ABC):
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
from vllm.model_executor.models import supports_multimodal
model_cls, _ = get_model_architecture(model_config)
if model_cls not in self._input_mappers:
if not supports_multimodal(model_cls):
return 0
max_mm_tokens = self._max_mm_tokens.get(model_cls)
if max_mm_tokens is None:
raise KeyError(f"No maximum number of multi-modal tokens is given "
f"for model class {model_cls.__name__} in {self}.")
return 0
if callable(max_mm_tokens):
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
@ -326,26 +326,47 @@ class MultiModalPlaceholderMap:
src_ranges = []
dest_ranges = []
"""
if (not seq_group.multi_modal_data
or not seq_group.multi_modal_placeholders):
return seq_group.multi_modal_data, {}
seq_mm_data = seq_group.multi_modal_data
seq_mm_placeholders = seq_group.multi_modal_placeholders
mm_data = {**seq_group.multi_modal_data}
placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
if not seq_mm_data or not seq_mm_placeholders:
return seq_mm_data, {}
# For merged processor, we directly use mm_kwargs as mm_data
if isinstance(seq_mm_data, MultiModalKwargs):
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
for modality, placeholders in seq_mm_placeholders.items():
placeholder_map = MultiModalPlaceholderMap()
if positions:
placeholder_map.append_items_from_seq_group(
positions,
# Dummy, since we don't care about intersecting items
[None] * len(placeholders),
placeholders,
)
placeholder_maps[modality] = placeholder_map
return seq_mm_data, placeholder_maps
mm_data = {**seq_mm_data}
placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
MultiModalPlaceholderMap)
for (
modality,
placeholders,
) in seq_group.multi_modal_placeholders.items():
for modality, placeholders in seq_mm_placeholders.items():
mm_items = mm_data.pop(modality)
if not isinstance(mm_items, list):
mm_items = [mm_items]
if positions:
intersecting_items = placeholder_maps[
modality].append_items_from_seq_group(
positions, mm_items, placeholders)
intersecting_items = placeholder_maps[modality] \
.append_items_from_seq_group(
positions,
mm_items,
placeholders,
)
if intersecting_items:
mm_data[modality] = intersecting_items

View File

@ -3,14 +3,13 @@ from abc import ABC, abstractmethod
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import lru_cache
from itertools import groupby
from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
import numpy as np
from transformers import BatchFeature
import torch
from transformers import BatchFeature, ProcessorMixin
from typing_extensions import TypeAlias, TypedDict
from vllm.inputs import InputProcessingContext
from vllm.inputs import DummyData, InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
@ -256,63 +255,6 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
return multi_data
class _TokenRun(NamedTuple):
    # A maximal run of consecutive identical token IDs within a prompt.
    token_id: int
    start_idx: int
    length: int


def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
    """
    Yield the starting index and length of each run of tokens that are the same.
    """
    offset = 0
    for token_id, group in groupby(token_ids):
        run_len = len(list(group))
        yield _TokenRun(token_id=token_id, start_idx=offset, length=run_len)
        offset += run_len
class _PlaceholderInfo(NamedTuple):
"""Location of one run of placeholder tokens within a prompt."""
# modality: multimodal modality key (e.g. "image") the run belongs to
# offset: index of the first placeholder token in the prompt
# length: number of consecutive placeholder tokens in the run
modality: str
offset: int
length: int
# Convert to the public PlaceholderRange form used by multimodal inputs.
def to_range(self) -> PlaceholderRange:
return PlaceholderRange(offset=self.offset, length=self.length)
def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    token_ids: list[int],
    *,
    min_placeholder_count: int,
) -> Iterable[_PlaceholderInfo]:
    """Yield each set of placeholder tokens found in :code:`token_ids`."""
    # Collect, per modality, every token ID that can occur in a replacement
    placeholder_ids_by_modality: dict[str, set[int]] = {}
    for modality, repls in full_groupby_modality(prompt_repls):
        ids: set[int] = set()
        for prompt_repl in repls:
            ids.update(prompt_repl.repl_unit.token_ids)
        placeholder_ids_by_modality[modality] = ids

    for run_info in iter_token_runs(token_ids):
        # Runs that are too short are skipped to avoid false positives
        if run_info.length <= min_placeholder_count:
            continue

        # A run may be reported for every modality whose IDs contain it
        for modality, placeholder_ids in placeholder_ids_by_modality.items():
            if run_info.token_id in placeholder_ids:
                yield _PlaceholderInfo(
                    modality=modality,
                    offset=run_info.start_idx,
                    length=run_info.length,
                )
class _TokenMatch(NamedTuple):
"""Index range of a token-level match; end_idx appears to be exclusive
(adjacent matches satisfy ``a.end_idx == b.start_idx`` when merged)."""
start_idx: int
end_idx: int
@ -353,13 +295,9 @@ class _PromptReplacementMatch(ABC, Generic[_T, _S]):
def end_idx(self) -> int:
raise NotImplementedError
@property
@abstractmethod
def get_repl(
self,
mm_items: list[_T],
hf_inputs: BatchFeature,
item_idx: int,
) -> _S:
def repl_unit(self) -> _S:
raise NotImplementedError
def __repr__(self) -> str:
@ -380,15 +318,9 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
def end_idx(self) -> int:
return self.match.end_idx
def get_repl(
self,
mm_items: list[_T],
hf_inputs: BatchFeature,
item_idx: int,
) -> list[int]:
prompt_repl = self.prompt_repl
count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
return prompt_repl.repl_unit.token_ids * count
@property
def repl_unit(self) -> list[int]:
return self.prompt_repl.repl_unit.token_ids
@dataclass(repr=False)
@ -404,15 +336,26 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
def end_idx(self) -> int:
return self.match.end()
def get_repl(
self,
mm_items: list[_T],
hf_inputs: BatchFeature,
item_idx: int,
) -> str:
prompt_repl = self.prompt_repl
count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
return prompt_repl.repl_unit.text * count
@property
def repl_unit(self) -> str:
return self.prompt_repl.repl_unit.text
class _PlaceholderInfo(NamedTuple):
modality: str
start_idx: int
unit: list[int]
unit_count: int
@property
def length(self) -> int:
return len(self.unit) * self.unit_count
def to_range(self) -> PlaceholderRange:
return PlaceholderRange(
offset=self.start_idx,
length=self.length,
)
def find_token_matches(
@ -447,15 +390,17 @@ def _resolve_matches(
Resolve :code:`matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
num_matches_by_idx = np.zeros(len(prompt), dtype=int)
for match in matches:
num_matches_by_idx[match.start_idx:match.end_idx] += 1
seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \
= [None] * len(prompt)
duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1)
if len(duplicate_matches_idxs) > 0:
raise ValueError("Unable to find a unique replacement "
f"at indices={duplicate_matches_idxs} "
f"of prompt={prompt}")
for match in matches:
for idx in range(match.start_idx, match.end_idx):
if seen_matches[idx] is not None:
raise ValueError("Found overlapping matches "
f"({seen_matches[idx]} and {match}) "
f"at index={idx} of prompt={prompt}")
seen_matches[idx] = match
return sorted(matches, key=lambda x: x.start_idx)
@ -480,9 +425,12 @@ def _replace_matches(
start_idx = match.start_idx
end_idx = match.end_idx
repl_ids = match.get_repl(mm_items, hf_inputs, item_idx)
repl_unit = match.repl_unit
repl_info = match.prompt_repl
repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx)
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids)
out_seqs.append(prompt[prev_end_idx:start_idx] +
repl_unit * repl_count)
prev_end_idx = end_idx
next_idx_by_modality[modality] += 1
@ -531,7 +479,57 @@ def replace_text_matches(
return "".join(texts)
class MultiModalProcessor:
def _merge_placeholder_matches(
matches: Iterable[_PromptReplacementTokenMatch],
) -> Iterable[_PromptReplacementTokenMatch]:
current_match = None
for match in sorted(matches, key=lambda x: x.start_idx):
if current_match is None:
current_match = match
elif (current_match.prompt_repl == match.prompt_repl
and current_match.end_idx == match.start_idx):
current_match = _PromptReplacementTokenMatch(
current_match.prompt_repl,
match=_TokenMatch(current_match.start_idx, match.end_idx),
)
else:
yield current_match
current_match = match
if current_match is not None:
yield current_match
def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    prompt: list[int],
    *,
    min_unit_count: int = 1,
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Args:
        prompt_repls: The bound replacements whose repeated replacement
            units are searched for in the prompt.
        prompt: The token IDs to scan.
        min_unit_count: Only yield placeholders made up of at least this
            many repetitions of the replacement unit.

    Raises:
        ValueError: If :code:`min_unit_count` is not a positive integer.
    """
    if min_unit_count <= 0:
        raise ValueError("`min_unit_count` must be a positive integer")

    # Replacements with an empty token unit can never match; skip them
    # up front to avoid a zero-length search (and later division).
    matches = (_PromptReplacementTokenMatch(prompt_repl, match)
               for prompt_repl in prompt_repls
               if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0
               for match in iter_token_matches(prompt, repl_unit))

    for match in _merge_placeholder_matches(matches):
        unit = match.repl_unit
        placeholder = _PlaceholderInfo(
            modality=match.modality,
            start_idx=match.start_idx,
            unit=unit,
            # Merged matches span whole repetitions of the unit.
            unit_count=(match.end_idx - match.start_idx) // len(unit),
        )

        if placeholder.unit_count >= min_unit_count:
            yield placeholder
class MultiModalProcessor(ABC):
"""
Helper class to process multi-modal inputs to be used in vLLM.
"""
@ -546,6 +544,12 @@ class MultiModalProcessor:
self.ctx = ctx
self.metadata = metadata
def _get_hf_processor(self) -> ProcessorMixin:
    # Hook for obtaining the HF processor; subclasses may presumably
    # override this to customize construction. Defaults to the
    # processor provided by the processing context.
    return self.ctx.get_hf_processor()
def _get_tokenizer(self) -> AnyTokenizer:
    # Hook for obtaining the tokenizer; defaults to the one held by
    # the processing context.
    return self.ctx.tokenizer
def __call__(
self,
prompt: str,
@ -562,13 +566,13 @@ class MultiModalProcessor:
# To avoid false positives from multi-input when detecting
# whether placeholder tokens have been inserted, in case
# the target sequence is a subset of the replacement tokens
min_placeholder_count: int = 16,
min_unit_count: int = 16,
) -> list[_PlaceholderInfo]:
return list(
iter_placeholders(
all_prompt_repls,
new_token_ids,
min_placeholder_count=min_placeholder_count,
min_unit_count=min_unit_count,
))
def _apply_hf_processor(
@ -577,19 +581,49 @@ class MultiModalProcessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
hf_processor = self.ctx.get_hf_processor()
hf_processor = self._get_hf_processor()
return hf_processor(
text=prompt, # type: ignore
**mm_data,
**mm_processor_kwargs,
)
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
for k, v in mm_data.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
if k in ("image", "video", "audio"):
if isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif is_list_of(v, torch.Tensor) and v[0].ndim == 2:
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
else:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
try:
hf_inputs = hf_processor(
text=prompt, # type: ignore
**processor_data,
**mm_processor_kwargs,
return_tensors="pt",
)
except Exception as exc:
data = dict(text=prompt, **processor_data)
raise RuntimeError(
f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={mm_processor_kwargs}") from exc
hf_inputs.update(passthrough_data)
return hf_inputs
def _bind_prompt_replacements(
self,
mm_data: MultiModalDataDict,
) -> list[_BoundPromptReplacement[Any]]:
tokenizer = self.ctx.tokenizer
tokenizer = self._get_tokenizer()
return [
prompt_repl.bind(modality, tokenizer)
@ -604,7 +638,7 @@ class MultiModalProcessor:
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
tokenizer = self.ctx.tokenizer
tokenizer = self._get_tokenizer()
mm_items = to_multi_format(mm_data)
token_matches = find_token_matches(token_ids, prompt_repls)
@ -620,7 +654,7 @@ class MultiModalProcessor:
# of the search text in the prompt, we instead perform string
# replacement on the decoded token IDs, then encode them back.
if all(
len(matches) >= len(mm_data[modality])
len(matches) >= len(mm_items[modality])
for modality, matches in full_groupby_modality(token_matches)
): # yapf: disable
token_ids = replace_token_matches(
@ -648,15 +682,6 @@ class MultiModalProcessor:
placeholders = self._find_placeholders(matched_repls, token_ids)
# Sanity check
assert len(placeholders) == len(matched_repls), dict(
# Log this information for easier debugging
text=text,
token_ids=token_ids,
placeholders=placeholders,
matched_repls=matched_repls,
)
return token_ids, text, placeholders
def apply(
@ -678,7 +703,7 @@ class MultiModalProcessor:
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
tokenizer = self.ctx.tokenizer
tokenizer = self._get_tokenizer()
hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
mm_processor_kwargs)
@ -717,3 +742,59 @@ class MultiModalProcessor:
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders,
)
@abstractmethod
def _get_dummy_mm_kwargs(
    self,
    mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
    """
    Build the multi-modal keyword inputs that correspond to
    ``mm_max_tokens`` in :meth:`get_dummy_data`.

    Args:
        mm_counts: The number of items per modality to build dummy
            inputs for.
    """
    raise NotImplementedError
def get_dummy_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_max_tokens: Mapping[str, int],
) -> DummyData:
    """
    Build dummy data for memory profiling: a prompt made of the
    placeholder token units for each modality (padded with zeros up to
    ``seq_len``) plus the corresponding dummy multi-modal inputs.
    """
    # Avoid circular import
    from vllm.sequence import SequenceData

    tokenizer = self._get_tokenizer()

    mm_placeholders = dict[str, _PlaceholderInfo]()
    offset = 0

    for modality, max_tokens in mm_max_tokens.items():
        # Modalities with no token budget contribute no placeholders.
        if max_tokens == 0:
            continue

        # Use the first prompt replacement for this modality as the
        # representative placeholder unit.
        metadata = self.metadata[modality]
        repl = metadata.prompt_repls[0].bind(modality, tokenizer)
        repl_token_ids = repl.repl_unit.token_ids

        placeholders = _PlaceholderInfo(
            modality=modality,
            start_idx=offset,
            unit=repl_token_ids,
            # Fit as many whole units as the token budget allows;
            # NOTE(review): if max_tokens < len(repl_token_ids) this
            # yields unit_count == 0 — confirm callers never pass that.
            unit_count=max_tokens // len(repl_token_ids),
        )

        mm_placeholders[modality] = placeholders
        # `length` is presumably len(unit) * unit_count — placeholders
        # for different modalities are laid out back to back.
        offset += placeholders.length

    # Concatenate all placeholder tokens, then zero-pad to seq_len.
    # NOTE(review): if the placeholders exceed seq_len, the multiplier
    # is negative and no padding is added (prompt stays longer than
    # seq_len) — confirm this is intended.
    prompt_token_ids = flatten_2d_lists(
        [p.unit * p.unit_count for p in mm_placeholders.values()])
    prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

    return DummyData(
        seq_data=SequenceData.from_seqs(prompt_token_ids),
        multi_modal_data=self._get_dummy_mm_kwargs(mm_counts),
        multi_modal_placeholders={
            modality: [p.to_range()]
            for modality, p in mm_placeholders.items()
        },
    )

View File

@ -15,7 +15,7 @@ from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import MultiModalProcessor
from .processing import MultiModalProcessingMetadata, MultiModalProcessor
from .video import VideoPlugin
if TYPE_CHECKING:
@ -200,6 +200,27 @@ class MultiModalRegistry:
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)
def get_max_tokens_by_modality(
    self,
    model_config: "ModelConfig",
) -> Mapping[str, int]:
    """
    Get the maximum number of tokens from each modality
    for profiling the memory usage of a model.

    See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

    Note:
        This should be called after :meth:`init_mm_limits_per_prompt`.
    """
    limits = self._limits_by_model[model_config]

    max_tokens_by_modality = {}
    for modality, plugin in self._plugins.items():
        item_limit = limits[modality]
        per_item = plugin.get_max_multimodal_tokens(model_config)
        max_tokens_by_modality[modality] = item_limit * per_item

    return max_tokens_by_modality
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
"""
Get the maximum number of multi-modal tokens
@ -210,11 +231,7 @@ class MultiModalRegistry:
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
limits_per_plugin = self._limits_by_model[model_config]
return sum((limits_per_plugin[key] *
plugin.get_max_multimodal_tokens(model_config))
for key, plugin in self._plugins.items())
return sum(self.get_max_tokens_by_modality(model_config).values())
def init_mm_limits_per_prompt(
self,
@ -270,7 +287,8 @@ class MultiModalRegistry:
factory: MultiModalProcessorFactory,
):
"""
Register a multi-modal processor to a model class.
Register a multi-modal processor to a model class. The processor
is constructed lazily, hence a factory method should be passed.
When the model receives multi-modal data, the provided function is
invoked to transform the data into a dictionary of model inputs.
@ -293,6 +311,41 @@ class MultiModalRegistry:
return wrapper
def register_processor_by_metadata(
    self,
    metadata_factory: Callable[[InputProcessingContext],
                               MultiModalProcessingMetadata],
    get_dummy_mm_kwargs: Callable[
        [InputProcessingContext, Mapping[str, int]], MultiModalKwargs],
):
    """
    Convenience method to register a multi-modal processor to a model class
    according to a function that constructs its metadata.

    When the model receives multi-modal data, the provided function is
    invoked to transform the data into a dictionary of model inputs.

    See also:
        - :ref:`input_processing_pipeline`
        - :ref:`enabling_multimodal_inputs`
    """

    # Concrete subclass whose dummy-input hook delegates to the
    # caller-supplied callable.
    class _MetadataBackedProcessor(MultiModalProcessor):

        def _get_dummy_mm_kwargs(
            self,
            mm_counts: Mapping[str, int],
        ) -> MultiModalKwargs:
            return get_dummy_mm_kwargs(self.ctx, mm_counts)

    def _build_processor(ctx: InputProcessingContext):
        return _MetadataBackedProcessor(ctx=ctx,
                                        metadata=metadata_factory(ctx))

    return self.register_processor(_build_processor)
def has_processor(self, model_config: "ModelConfig") -> bool:
"""
Test whether a multi-modal processor is defined for a specific model.

View File

@ -12,6 +12,7 @@ class MMInputMapper:
model_config: ModelConfig,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
model_config)

View File

@ -7,7 +7,8 @@ from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@ -101,10 +102,15 @@ class Processor:
self.generation_config_fields, eos_token_id)
# Preprocess multi-modal data
mm_inputs = self.mm_input_mapper.process_inputs(
decoder_inputs.multi_modal_data,
decoder_inputs.mm_processor_kwargs) if len(
decoder_inputs.multi_modal_data) > 0 else None
if len(decoder_inputs.multi_modal_data) == 0:
mm_inputs = None
elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
mm_inputs = [decoder_inputs.multi_modal_data]
else:
mm_inputs = self.mm_input_mapper.process_inputs(
decoder_inputs.multi_modal_data,
decoder_inputs.mm_processor_kwargs,
)
# Make Request for Detokenizer.
detokenizer_request = DetokenizerRequest(