mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 11:37:12 +08:00
[3/N] Support and implement merged input processor for LLaVA model (#10676)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
acf092d348
commit
955fa9533a
@ -2,7 +2,7 @@ from contextlib import nullcontext
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from transformers import CLIPImageProcessor, LlavaNextImageProcessor
|
||||
from transformers import LlavaNextImageProcessor
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
@ -14,49 +14,6 @@ def mm_registry():
|
||||
return MultiModalRegistry()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half", "float"])
|
||||
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
|
||||
def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
|
||||
hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
|
||||
assert isinstance(hf_processor, CLIPImageProcessor)
|
||||
|
||||
model_config = ModelConfig(
|
||||
model=MODEL_NAME,
|
||||
task="auto",
|
||||
tokenizer=MODEL_NAME,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype=dtype,
|
||||
revision=None,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
mm_registry.init_mm_limits_per_prompt(model_config)
|
||||
|
||||
for asset in image_assets:
|
||||
image = rescale_image_size(asset.pil_image, size_factor)
|
||||
|
||||
hf_result = hf_processor.preprocess(
|
||||
image,
|
||||
return_tensors="pt",
|
||||
)
|
||||
vllm_result = mm_registry.map_input(
|
||||
model_config,
|
||||
{"image": image},
|
||||
)
|
||||
|
||||
assert hf_result.keys() == vllm_result.keys()
|
||||
for key, hf_tensor in hf_result.items():
|
||||
hf_arr: np.ndarray = hf_tensor.numpy()
|
||||
vllm_arr: np.ndarray = vllm_result[key].numpy()
|
||||
|
||||
assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
|
||||
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half", "float"])
|
||||
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
|
||||
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
|
||||
@ -107,7 +64,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype,
|
||||
(2, 1, False), (2, 2, True)],
|
||||
)
|
||||
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
|
||||
|
||||
model_config = ModelConfig(
|
||||
model=MODEL_NAME,
|
||||
@ -138,7 +95,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
|
||||
# NOTE: We don't test zero images since the HF processor doesn't support it
|
||||
@pytest.mark.parametrize("num_images", [1, 2])
|
||||
def test_image_mapper_multi(image_assets, mm_registry, num_images):
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
|
||||
|
||||
model_config = ModelConfig(
|
||||
model=MODEL_NAME,
|
||||
|
||||
@ -3,50 +3,15 @@ from typing import cast
|
||||
import pytest
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.multimodal.processing import (PromptReplacement, find_text_matches,
|
||||
find_token_matches, iter_token_matches,
|
||||
iter_token_runs, replace_text_matches)
|
||||
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
|
||||
find_text_matches, find_token_matches,
|
||||
iter_placeholders, iter_token_matches,
|
||||
replace_text_matches,
|
||||
replace_token_matches)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import full_groupby
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("token_ids", "expected"),
|
||||
[
|
||||
([], []),
|
||||
(
|
||||
[32000, 32000, 32000],
|
||||
[{ "token_id": 32000, "start_idx": 0, "length": 3 }],
|
||||
),
|
||||
(
|
||||
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
[
|
||||
{ "token_id": 9833, "start_idx": 0, "length": 1 },
|
||||
{ "token_id": 28747, "start_idx": 1, "length": 1 },
|
||||
{ "token_id": 32000, "start_idx": 2, "length": 3 },
|
||||
{ "token_id": 9833, "start_idx": 5, "length": 1 },
|
||||
{ "token_id": 28747, "start_idx": 6, "length": 1 },
|
||||
{ "token_id": 32000, "start_idx": 7, "length": 2 },
|
||||
{ "token_id": 918, "start_idx": 9, "length": 1 },
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
# yapf: enable
|
||||
def test_iter_token_runs(token_ids, expected):
|
||||
result = list(iter_token_runs(token_ids))
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
assert [item._asdict() for item in result] == expected
|
||||
|
||||
# Invariants
|
||||
assert sum(run_info.length for run_info in result) == len(token_ids)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("token_ids", "match_ids", "expected"),
|
||||
@ -170,13 +135,11 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
|
||||
# Should not be used since there is nothing to convert to token IDs
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
result = find_token_matches(
|
||||
prompt,
|
||||
[
|
||||
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
|
||||
for key, target in target_by_key.items()
|
||||
],
|
||||
)
|
||||
prompt_repls = [
|
||||
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
result = find_token_matches(prompt, prompt_repls)
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
@ -279,13 +242,11 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
|
||||
# Should not be used since there is nothing to convert to text
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
result = find_text_matches(
|
||||
prompt,
|
||||
[
|
||||
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
|
||||
for key, target in target_by_key.items()
|
||||
],
|
||||
)
|
||||
prompt_repls = [
|
||||
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
result = find_text_matches(prompt, prompt_repls)
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
@ -303,7 +264,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"),
|
||||
("prompt", "target_by_key", "repl_by_key"),
|
||||
[
|
||||
(
|
||||
"Image:<image>Image:<image><image>!",
|
||||
@ -322,49 +283,201 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
|
||||
# Test multiple repl_count
|
||||
"pattern_3": ("?", 2),
|
||||
},
|
||||
{
|
||||
# Test no replacement
|
||||
0: "Image:<image>Image:<image><image>!",
|
||||
# Test single replacement
|
||||
1: "<image><image>Image:<image><image>??",
|
||||
# Test repeated replacement
|
||||
2: "<image><image><image><image><image>??",
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("mm_count", "expected"),
|
||||
[
|
||||
(0, "Image:<image>Image:<image><image>!"),
|
||||
(1, "<image><image>Image:<image><image>??"),
|
||||
(2, "<image><image><image><image><image>??"),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
def test_find_replace_text(
|
||||
prompt,
|
||||
target_by_key,
|
||||
repl_by_key,
|
||||
expected_by_mm_count,
|
||||
mm_count,
|
||||
expected,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to text
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
matches = find_text_matches(
|
||||
prompt_repls = [
|
||||
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
matches = find_text_matches(prompt, prompt_repls)
|
||||
|
||||
result = replace_text_matches(
|
||||
prompt,
|
||||
[
|
||||
PromptReplacement(target, *repl_by_key[key]) \
|
||||
.bind(key, mock_tokenizer)
|
||||
for key, target in target_by_key.items()
|
||||
],
|
||||
matches,
|
||||
{key: list(range(mm_count))
|
||||
for key in repl_by_key},
|
||||
BatchFeature(),
|
||||
)
|
||||
result_by_mm_count = {
|
||||
mm_count: replace_text_matches(
|
||||
prompt,
|
||||
matches,
|
||||
{key: list(range(mm_count))
|
||||
for key in repl_by_key},
|
||||
BatchFeature(),
|
||||
)
|
||||
for mm_count in expected_by_mm_count
|
||||
}
|
||||
|
||||
# Only displayed on error
|
||||
print("matches:", matches)
|
||||
print("result_by_mm_count:", result_by_mm_count)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
assert result_by_mm_count == expected_by_mm_count
|
||||
assert result == expected
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("prompt", "target_by_key", "repl_by_key"),
|
||||
[
|
||||
# Tokenized test cases of `test_find_replace_text`
|
||||
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
|
||||
(
|
||||
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
{
|
||||
# We use `<image>` before `Image:` to test matches that
|
||||
# occur out of order
|
||||
"pattern_1": [32000],
|
||||
"pattern_2": [9833, 28747],
|
||||
"pattern_3": [918],
|
||||
},
|
||||
{
|
||||
# Test whether target is confused with repl_unit
|
||||
"pattern_1": ([32000, 32000], 1),
|
||||
# Test empty repl_unit
|
||||
"pattern_2": ([], 1),
|
||||
# Test multiple repl_count
|
||||
"pattern_3": ([1550], 2),
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("mm_count", "expected"),
|
||||
[
|
||||
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
|
||||
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]),
|
||||
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
def test_find_replace_tokens(
|
||||
prompt,
|
||||
target_by_key,
|
||||
repl_by_key,
|
||||
mm_count,
|
||||
expected,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
matches = find_token_matches(prompt, prompt_repls)
|
||||
|
||||
result = replace_token_matches(
|
||||
prompt,
|
||||
matches,
|
||||
{key: list(range(mm_count))
|
||||
for key in repl_by_key},
|
||||
BatchFeature(),
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("matches:", matches)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
"repl_by_key",
|
||||
[
|
||||
{
|
||||
"pattern_1": ([32000, 32000], 1),
|
||||
"pattern_2": ([], 1),
|
||||
"pattern_3": ([1550], 2),
|
||||
},
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("prompt", "expected"),
|
||||
[
|
||||
(
|
||||
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
[
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=6,
|
||||
unit=[32000, 32000],
|
||||
unit_count=1,
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550],
|
||||
[
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=1,
|
||||
unit=[32000, 32000],
|
||||
unit_count=1,
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=5,
|
||||
unit=[32000, 32000],
|
||||
unit_count=1,
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
start_idx=7,
|
||||
unit=[1550],
|
||||
unit_count=2,
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
[1, 32000, 32000, 32000, 32000, 32000, 1550, 1550],
|
||||
[
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=1,
|
||||
unit=[32000, 32000],
|
||||
unit_count=2,
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
start_idx=6,
|
||||
unit=[1550],
|
||||
unit_count=2,
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
def test_iter_placeholders(
|
||||
repl_by_key,
|
||||
prompt,
|
||||
expected,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement([], *repl).bind(key, mock_tokenizer)
|
||||
for key, repl in repl_by_key.items()
|
||||
]
|
||||
|
||||
result = list(iter_placeholders(prompt_repls, prompt))
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
|
||||
@ -2,19 +2,17 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
|
||||
dummy_data_for_llava,
|
||||
get_max_llava_image_tokens,
|
||||
input_processor_for_llava)
|
||||
create_metadata_for_llava,
|
||||
dummy_mm_kwargs_for_llava,
|
||||
get_max_llava_image_tokens)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
|
||||
@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava,
|
||||
dummy_mm_kwargs_for_llava)
|
||||
class MyLlava(LlavaForConditionalGeneration):
|
||||
|
||||
def compute_logits(
|
||||
|
||||
@ -232,19 +232,35 @@ class InputRegistry:
|
||||
"""
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
if is_encoder_data:
|
||||
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
|
||||
if mm_registry.has_processor(model_config):
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
processor = mm_registry.create_processor(model_config, tokenizer)
|
||||
|
||||
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
|
||||
mm_max_tokens = mm_registry.get_max_tokens_by_modality(
|
||||
model_config)
|
||||
|
||||
dummy_data = processor.get_dummy_data(seq_len, mm_counts,
|
||||
mm_max_tokens)
|
||||
else:
|
||||
dummy_factory = self._get_dummy_data_factory(model_cls)
|
||||
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
|
||||
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
|
||||
dummy_factory, overrides=model_config.mm_processor_kwargs)
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
if is_encoder_data:
|
||||
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
|
||||
else:
|
||||
dummy_factory = self._get_dummy_data_factory(model_cls)
|
||||
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
|
||||
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
|
||||
dummy_factory, overrides=model_config.mm_processor_kwargs)
|
||||
|
||||
dummy_data = dummy_factory(InputContext(model_config), seq_len,
|
||||
_MultiModalCounts(mm_counts),
|
||||
**mm_processor_kwargs)
|
||||
dummy_data = dummy_factory(InputContext(model_config), seq_len,
|
||||
_MultiModalCounts(mm_counts),
|
||||
**mm_processor_kwargs)
|
||||
|
||||
# Having more tokens is over-conservative but otherwise fine
|
||||
num_tokens = dummy_data.seq_data.prompt_token_ids
|
||||
@ -257,7 +273,9 @@ class InputRegistry:
|
||||
raise AssertionError(
|
||||
f"Expected at least {seq_len} dummy tokens for profiling, "
|
||||
f"but found {len(num_tokens)} tokens instead.")
|
||||
if dummy_data.multi_modal_data is not None:
|
||||
|
||||
if (dummy_data.multi_modal_data is not None and
|
||||
not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)):
|
||||
for k, v in dummy_data.multi_modal_data.items():
|
||||
num_items = len(v) if isinstance(v, list) else 1
|
||||
num_expected = mm_counts[k]
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
from functools import cached_property
|
||||
from types import MethodType
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
|
||||
PretrainedConfig, SiglipVisionConfig)
|
||||
from PIL.Image import Image
|
||||
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
|
||||
PixtralVisionConfig, PretrainedConfig,
|
||||
ProcessorMixin, SiglipVisionConfig)
|
||||
from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext)
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -19,21 +21,20 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||
from vllm.multimodal.processing import (InputProcessingContext,
|
||||
ModalityProcessingMetadata,
|
||||
MultiModalProcessingMetadata,
|
||||
MultiModalProcessor, PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
||||
dummy_seq_data_for_clip, get_max_clip_image_tokens,
|
||||
input_processor_for_clip)
|
||||
get_max_clip_image_tokens)
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
|
||||
dummy_seq_data_for_pixtral_hf,
|
||||
get_max_pixtral_hf_image_tokens,
|
||||
input_processor_for_pixtral_hf)
|
||||
get_max_pixtral_hf_image_tokens)
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
|
||||
input_processor_for_siglip)
|
||||
get_max_siglip_image_tokens)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
@ -113,102 +114,86 @@ def get_max_llava_image_tokens(ctx: InputContext):
|
||||
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||
|
||||
|
||||
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
|
||||
mm_counts: Mapping[str, int]):
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
image_feature_size = get_max_llava_image_tokens(ctx)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_clip(vision_config, num_images)
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
data = dummy_image_for_clip(vision_config, num_images)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_siglip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_siglip(vision_config, num_images)
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
data = dummy_image_for_siglip(vision_config, num_images)
|
||||
elif isinstance(vision_config, PixtralVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_pixtral_hf(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
image_feature_size = get_max_llava_image_tokens(ctx)
|
||||
elif is_list_of(image_data, Image.Image):
|
||||
image_feature_size = [get_max_llava_image_tokens(ctx)
|
||||
] * len(image_data)
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
num_images, image_feature_size, hidden_size = image_data.shape
|
||||
elif is_list_of(image_data, torch.Tensor):
|
||||
image_feature_size = [item.shape[1] for item in image_data]
|
||||
data = dummy_image_for_pixtral_hf(vision_config, num_images)
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return input_processor_for_clip(
|
||||
model_config,
|
||||
vision_config,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
return input_processor_for_siglip(
|
||||
model_config,
|
||||
vision_config,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
elif isinstance(vision_config, PixtralVisionConfig):
|
||||
# We ignore image_feature_size_override since we have non-uniform
|
||||
# image sizes for Pixtral
|
||||
return input_processor_for_pixtral_hf(
|
||||
model_config,
|
||||
vision_config,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
)
|
||||
hf_processor = ctx.get_hf_processor()
|
||||
image_processor = hf_processor.image_processor # type: ignore
|
||||
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
|
||||
is_pixtral = isinstance(hf_processor, PixtralProcessor)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
return MultiModalKwargs(
|
||||
**hf_inputs,
|
||||
is_pixtral=torch.tensor(is_pixtral),
|
||||
)
|
||||
|
||||
|
||||
def create_metadata_for_llava(
|
||||
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
def get_repl_count(
|
||||
mm_items: list[Image],
|
||||
hf_inputs: BatchFeature,
|
||||
item_idx: int,
|
||||
) -> int:
|
||||
return get_max_llava_image_tokens(ctx)
|
||||
|
||||
return {
|
||||
"image":
|
||||
ModalityProcessingMetadata(prompt_repls=[
|
||||
PromptReplacement(target=[image_token_id],
|
||||
repl_unit=[image_token_id],
|
||||
repl_count=get_repl_count),
|
||||
]),
|
||||
}
|
||||
|
||||
|
||||
class LlavaProcessor(MultiModalProcessor):
|
||||
|
||||
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
|
||||
if getattr(hf_processor, "__is_patched__", False):
|
||||
return # Already patched
|
||||
|
||||
image_processor = hf_processor.image_processor # type: ignore
|
||||
orig_preprocess = image_processor.preprocess
|
||||
|
||||
def preprocess(__self, *args, **kwargs):
|
||||
hf_inputs = orig_preprocess(*args, **kwargs)
|
||||
hf_inputs["is_pixtral"] = torch.tensor(True)
|
||||
return hf_inputs
|
||||
|
||||
image_processor.preprocess = MethodType(preprocess, image_processor)
|
||||
|
||||
hf_processor.__is_patched__ = True # type: ignore
|
||||
|
||||
def _get_hf_processor(self) -> ProcessorMixin:
|
||||
hf_processor = self.ctx.get_hf_processor()
|
||||
|
||||
if isinstance(hf_processor, PixtralProcessor):
|
||||
self._patch_pixtral_processor(hf_processor)
|
||||
|
||||
return hf_processor
|
||||
|
||||
def _get_dummy_mm_kwargs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalKwargs:
|
||||
return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)
|
||||
|
||||
|
||||
class LlavaLikeConfig(Protocol):
|
||||
@ -291,10 +276,11 @@ def init_vision_tower_for_llava(
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
|
||||
@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
|
||||
ctx=ctx,
|
||||
metadata=create_metadata_for_llava(ctx),
|
||||
))
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
@ -367,38 +353,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return data
|
||||
|
||||
def _validate_image_sizes(self, images: List[torch.Tensor],
|
||||
sizes: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
if not isinstance(sizes, list):
|
||||
sizes = [sizes]
|
||||
|
||||
total_images = sum(size.numel() // 2 for size in sizes)
|
||||
if total_images != len(images):
|
||||
raise ValueError("Mismatch in number of images. "
|
||||
f"Expected {total_images}, got {len(images)}")
|
||||
img_idx = 0
|
||||
for size in sizes:
|
||||
# Flatten the size tensor to a list of (height, width) pairs
|
||||
size = size.view(-1, 2).tolist()
|
||||
for expected_h, expected_w in size:
|
||||
if img_idx >= len(images):
|
||||
raise ValueError("Ran out of images before sizes. "
|
||||
f"{img_idx} >= {len(images)}")
|
||||
img = images[img_idx]
|
||||
if img.shape[-2:] != (expected_h, expected_w):
|
||||
raise ValueError(
|
||||
"Image size mismatch. Expected "
|
||||
f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
|
||||
if img.shape[-3] != 3:
|
||||
raise ValueError("Image channel mismatch. Expected 3, "
|
||||
f"got {img.shape[-3]}")
|
||||
img_idx += 1
|
||||
return images
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[LlavaImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_sizes = kwargs.pop("image_sizes", None)
|
||||
is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
@ -409,9 +367,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Case for models like PixtralHF that have dynamic image sizes
|
||||
# so we need to produce a list of tensors
|
||||
if image_sizes is not None:
|
||||
assert isinstance(is_pixtral, torch.Tensor)
|
||||
if is_pixtral.any():
|
||||
images = pixel_values
|
||||
|
||||
def flatten_to_3d_tensors(item):
|
||||
@ -434,7 +391,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_image_sizes(images, image_sizes),
|
||||
data=images,
|
||||
)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
|
||||
@ -226,16 +226,16 @@ class MultiModalPlugin(ABC):
|
||||
"""
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
from vllm.model_executor.models import supports_multimodal
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
|
||||
if model_cls not in self._input_mappers:
|
||||
if not supports_multimodal(model_cls):
|
||||
return 0
|
||||
|
||||
max_mm_tokens = self._max_mm_tokens.get(model_cls)
|
||||
if max_mm_tokens is None:
|
||||
raise KeyError(f"No maximum number of multi-modal tokens is given "
|
||||
f"for model class {model_cls.__name__} in {self}.")
|
||||
return 0
|
||||
|
||||
if callable(max_mm_tokens):
|
||||
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
|
||||
@ -326,26 +326,47 @@ class MultiModalPlaceholderMap:
|
||||
src_ranges = []
|
||||
dest_ranges = []
|
||||
"""
|
||||
if (not seq_group.multi_modal_data
|
||||
or not seq_group.multi_modal_placeholders):
|
||||
return seq_group.multi_modal_data, {}
|
||||
seq_mm_data = seq_group.multi_modal_data
|
||||
seq_mm_placeholders = seq_group.multi_modal_placeholders
|
||||
|
||||
mm_data = {**seq_group.multi_modal_data}
|
||||
placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
|
||||
if not seq_mm_data or not seq_mm_placeholders:
|
||||
return seq_mm_data, {}
|
||||
|
||||
# For merged processor, we directly use mm_kwargs as mm_data
|
||||
if isinstance(seq_mm_data, MultiModalKwargs):
|
||||
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
|
||||
|
||||
for modality, placeholders in seq_mm_placeholders.items():
|
||||
placeholder_map = MultiModalPlaceholderMap()
|
||||
|
||||
if positions:
|
||||
placeholder_map.append_items_from_seq_group(
|
||||
positions,
|
||||
# Dummy, since we don't care about intersecting items
|
||||
[None] * len(placeholders),
|
||||
placeholders,
|
||||
)
|
||||
|
||||
placeholder_maps[modality] = placeholder_map
|
||||
|
||||
return seq_mm_data, placeholder_maps
|
||||
|
||||
mm_data = {**seq_mm_data}
|
||||
placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
|
||||
MultiModalPlaceholderMap)
|
||||
|
||||
for (
|
||||
modality,
|
||||
placeholders,
|
||||
) in seq_group.multi_modal_placeholders.items():
|
||||
for modality, placeholders in seq_mm_placeholders.items():
|
||||
mm_items = mm_data.pop(modality)
|
||||
if not isinstance(mm_items, list):
|
||||
mm_items = [mm_items]
|
||||
|
||||
if positions:
|
||||
intersecting_items = placeholder_maps[
|
||||
modality].append_items_from_seq_group(
|
||||
positions, mm_items, placeholders)
|
||||
intersecting_items = placeholder_maps[modality] \
|
||||
.append_items_from_seq_group(
|
||||
positions,
|
||||
mm_items,
|
||||
placeholders,
|
||||
)
|
||||
|
||||
if intersecting_items:
|
||||
mm_data[modality] = intersecting_items
|
||||
|
||||
@ -3,14 +3,13 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from itertools import groupby
|
||||
from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers import BatchFeature
|
||||
import torch
|
||||
from transformers import BatchFeature, ProcessorMixin
|
||||
from typing_extensions import TypeAlias, TypedDict
|
||||
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.inputs import DummyData, InputProcessingContext
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
|
||||
|
||||
@ -256,63 +255,6 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
|
||||
return multi_data
|
||||
|
||||
|
||||
class _TokenRun(NamedTuple):
    """A run of one token ID repeated consecutively in a token sequence."""

    # Token ID shared by every position of the run.
    token_id: int

    # Index of the first token of the run.
    start_idx: int

    # Number of tokens in the run.
    length: int
|
||||
|
||||
|
||||
def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
    """
    Split :code:`token_ids` into runs of identical consecutive tokens.

    One :class:`_TokenRun` is yielded per run, in order of appearance,
    carrying the repeated token ID, the index where the run starts and the
    number of tokens in the run.
    """
    run_start = 0

    for run_token, group in groupby(token_ids):
        run_length = len(list(group))
        yield _TokenRun(
            token_id=run_token,
            start_idx=run_start,
            length=run_length,
        )

        # The next run begins right after the one just reported.
        run_start += run_length
|
||||
|
||||
|
||||
class _PlaceholderInfo(NamedTuple):
    """Position of one run of placeholder tokens for a given modality."""

    # Modality that the placeholder tokens belong to.
    modality: str

    # Index of the first placeholder token.
    offset: int

    # Number of placeholder tokens.
    length: int

    def to_range(self) -> PlaceholderRange:
        """Build a :class:`PlaceholderRange` with the same offset and length."""
        offset, length = self.offset, self.length
        return PlaceholderRange(offset=offset, length=length)
|
||||
|
||||
|
||||
def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    token_ids: list[int],
    *,
    min_placeholder_count: int,
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`token_ids`.

    A run of identical tokens is reported as a placeholder when its token ID
    appears in a replacement unit of some modality and the run contains at
    least :code:`min_placeholder_count` tokens. Runs are reported in order of
    appearance; a run is reported once per modality whose replacement units
    contain its token ID.
    """
    # Every token ID that may appear in a placeholder, grouped by modality.
    placeholder_ids_by_modality = {
        modality: {
            token_id
            for prompt_repl in repls
            for token_id in prompt_repl.repl_unit.token_ids
        }
        for modality, repls in full_groupby_modality(prompt_repls)
    }

    for run_info in iter_token_runs(token_ids):
        # `min_placeholder_count` is an inclusive lower bound: a run of
        # exactly that many tokens still counts as a placeholder.
        if run_info.length < min_placeholder_count:
            continue

        for (modality,
             placeholder_ids) in placeholder_ids_by_modality.items():
            if run_info.token_id in placeholder_ids:
                yield _PlaceholderInfo(
                    modality=modality,
                    offset=run_info.start_idx,
                    length=run_info.length,
                )
|
||||
|
||||
|
||||
class _TokenMatch(NamedTuple):
|
||||
start_idx: int
|
||||
end_idx: int
|
||||
@ -353,13 +295,9 @@ class _PromptReplacementMatch(ABC, Generic[_T, _S]):
|
||||
def end_idx(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def get_repl(
|
||||
self,
|
||||
mm_items: list[_T],
|
||||
hf_inputs: BatchFeature,
|
||||
item_idx: int,
|
||||
) -> _S:
|
||||
def repl_unit(self) -> _S:
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -380,15 +318,9 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
|
||||
def end_idx(self) -> int:
|
||||
return self.match.end_idx
|
||||
|
||||
def get_repl(
    self,
    mm_items: list[_T],
    hf_inputs: BatchFeature,
    item_idx: int,
) -> list[int]:
    """
    Return the replacement token IDs for the item at :code:`item_idx`.

    The token IDs of the replacement unit are repeated as many times as
    reported by :code:`prompt_repl.get_count` for that item.
    """
    repl = self.prompt_repl
    num_units = repl.get_count(mm_items, hf_inputs, item_idx)
    unit_token_ids = repl.repl_unit.token_ids
    return unit_token_ids * num_units
|
||||
@property
def repl_unit(self) -> list[int]:
    """Token IDs of the replacement unit bound to this match."""
    bound_unit = self.prompt_repl.repl_unit
    return bound_unit.token_ids
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
@ -404,15 +336,26 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
|
||||
def end_idx(self) -> int:
|
||||
return self.match.end()
|
||||
|
||||
def get_repl(
    self,
    mm_items: list[_T],
    hf_inputs: BatchFeature,
    item_idx: int,
) -> str:
    """
    Return the replacement text for the item at :code:`item_idx`.

    The text of the replacement unit is repeated as many times as reported
    by :code:`prompt_repl.get_count` for that item.
    """
    repl = self.prompt_repl
    num_units = repl.get_count(mm_items, hf_inputs, item_idx)
    unit_text = repl.repl_unit.text
    return unit_text * num_units
|
||||
@property
def repl_unit(self) -> str:
    """Text of the replacement unit bound to this match."""
    bound_unit = self.prompt_repl.repl_unit
    return bound_unit.text
|
||||
|
||||
|
||||
class _PlaceholderInfo(NamedTuple):
    """A placeholder span made of one token unit repeated back-to-back."""

    # Modality that the placeholder belongs to.
    modality: str

    # Index of the first token of the span.
    start_idx: int

    # Token IDs of a single replacement unit.
    unit: list[int]

    # Number of times the unit is repeated.
    unit_count: int

    @property
    def length(self) -> int:
        """Total number of tokens: unit length times the repeat count."""
        tokens_per_unit = len(self.unit)
        return tokens_per_unit * self.unit_count

    def to_range(self) -> PlaceholderRange:
        """Build a :class:`PlaceholderRange` covering the whole span."""
        return PlaceholderRange(offset=self.start_idx, length=self.length)
|
||||
|
||||
|
||||
def find_token_matches(
|
||||
@ -447,15 +390,17 @@ def _resolve_matches(
|
||||
Resolve :code:`matches` to ensure that there are no overlapping matches,
|
||||
and sort them such that earlier matches take priority over later ones.
|
||||
"""
|
||||
num_matches_by_idx = np.zeros(len(prompt), dtype=int)
|
||||
for match in matches:
|
||||
num_matches_by_idx[match.start_idx:match.end_idx] += 1
|
||||
seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \
|
||||
= [None] * len(prompt)
|
||||
|
||||
duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1)
|
||||
if len(duplicate_matches_idxs) > 0:
|
||||
raise ValueError("Unable to find a unique replacement "
|
||||
f"at indices={duplicate_matches_idxs} "
|
||||
f"of prompt={prompt}")
|
||||
for match in matches:
|
||||
for idx in range(match.start_idx, match.end_idx):
|
||||
if seen_matches[idx] is not None:
|
||||
raise ValueError("Found overlapping matches "
|
||||
f"({seen_matches[idx]} and {match}) "
|
||||
f"at index={idx} of prompt={prompt}")
|
||||
|
||||
seen_matches[idx] = match
|
||||
|
||||
return sorted(matches, key=lambda x: x.start_idx)
|
||||
|
||||
@ -480,9 +425,12 @@ def _replace_matches(
|
||||
|
||||
start_idx = match.start_idx
|
||||
end_idx = match.end_idx
|
||||
repl_ids = match.get_repl(mm_items, hf_inputs, item_idx)
|
||||
repl_unit = match.repl_unit
|
||||
repl_info = match.prompt_repl
|
||||
repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx)
|
||||
|
||||
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids)
|
||||
out_seqs.append(prompt[prev_end_idx:start_idx] +
|
||||
repl_unit * repl_count)
|
||||
prev_end_idx = end_idx
|
||||
next_idx_by_modality[modality] += 1
|
||||
|
||||
@ -531,7 +479,57 @@ def replace_text_matches(
|
||||
return "".join(texts)
|
||||
|
||||
|
||||
class MultiModalProcessor:
|
||||
def _merge_placeholder_matches(
    matches: Iterable[_PromptReplacementTokenMatch],
) -> Iterable[_PromptReplacementTokenMatch]:
    """
    Fuse adjacent matches of the same prompt replacement into one match.

    The matches are visited in order of :code:`start_idx`. A match is folded
    into the pending one when both carry an equal :code:`prompt_repl` and it
    starts exactly where the pending one ends; otherwise the pending match
    is yielded and the visited match becomes the new pending one.
    """
    pending = None

    for candidate in sorted(matches, key=lambda m: m.start_idx):
        if pending is None:
            pending = candidate
            continue

        if (pending.prompt_repl == candidate.prompt_repl
                and pending.end_idx == candidate.start_idx):
            # Contiguous repeat of the same unit: grow the pending span.
            merged_span = _TokenMatch(pending.start_idx, candidate.end_idx)
            pending = _PromptReplacementTokenMatch(
                pending.prompt_repl,
                match=merged_span,
            )
            continue

        yield pending
        pending = candidate

    # Flush the last span; there is none when `matches` was empty.
    if pending is not None:
        yield pending
|
||||
|
||||
|
||||
def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    prompt: list[int],
    *,
    min_unit_count: int = 1,
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Every occurrence of a non-empty replacement unit is located in the
    prompt, adjacent occurrences of the same replacement are merged, and the
    merged spans repeating their unit at least :code:`min_unit_count` times
    are yielded.

    Raises:
        ValueError: If :code:`min_unit_count` is not a positive integer.
            As this is a generator, the error surfaces on first iteration.
    """
    if min_unit_count <= 0:
        raise ValueError("`min_unit_count` must be a positive integer")

    def _iter_unit_matches():
        for prompt_repl in prompt_repls:
            unit_ids = prompt_repl.repl_unit.token_ids
            # An empty unit cannot be searched for.
            if len(unit_ids) == 0:
                continue

            for token_match in iter_token_matches(prompt, unit_ids):
                yield _PromptReplacementTokenMatch(prompt_repl, token_match)

    for merged in _merge_placeholder_matches(_iter_unit_matches()):
        unit = merged.repl_unit
        span_length = merged.end_idx - merged.start_idx

        placeholder = _PlaceholderInfo(
            modality=merged.modality,
            start_idx=merged.start_idx,
            unit=unit,
            unit_count=span_length // len(unit),
        )

        if placeholder.unit_count < min_unit_count:
            continue

        yield placeholder
|
||||
|
||||
|
||||
class MultiModalProcessor(ABC):
|
||||
"""
|
||||
Helper class to process multi-modal inputs to be used in vLLM.
|
||||
"""
|
||||
@ -546,6 +544,12 @@ class MultiModalProcessor:
|
||||
self.ctx = ctx
|
||||
self.metadata = metadata
|
||||
|
||||
def _get_hf_processor(self) -> ProcessorMixin:
    """Return the HF processor provided by the input processing context."""
    ctx = self.ctx
    return ctx.get_hf_processor()
|
||||
|
||||
def _get_tokenizer(self) -> AnyTokenizer:
    """Return the tokenizer held by the input processing context."""
    ctx = self.ctx
    return ctx.tokenizer
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -562,13 +566,13 @@ class MultiModalProcessor:
|
||||
# To avoid false positives from multi-input when detecting
|
||||
# whether placeholder tokens have been inserted, in case
|
||||
# the target sequence is a subset of the replacement tokens
|
||||
min_placeholder_count: int = 16,
|
||||
min_unit_count: int = 16,
|
||||
) -> list[_PlaceholderInfo]:
|
||||
return list(
|
||||
iter_placeholders(
|
||||
all_prompt_repls,
|
||||
new_token_ids,
|
||||
min_placeholder_count=min_placeholder_count,
|
||||
min_unit_count=min_unit_count,
|
||||
))
|
||||
|
||||
def _apply_hf_processor(
|
||||
@ -577,19 +581,49 @@ class MultiModalProcessor:
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
hf_processor = self.ctx.get_hf_processor()
|
||||
hf_processor = self._get_hf_processor()
|
||||
|
||||
return hf_processor(
|
||||
text=prompt, # type: ignore
|
||||
**mm_data,
|
||||
**mm_processor_kwargs,
|
||||
)
|
||||
processor_data = dict[str, Any]()
|
||||
passthrough_data = dict[str, Any]()
|
||||
for k, v in mm_data.items():
|
||||
# TODO: Make a separate modality for embedding inputs
|
||||
# to avoid confusion
|
||||
if k in ("image", "video", "audio"):
|
||||
if isinstance(v, torch.Tensor) and v.ndim == 3:
|
||||
# Pass through embedding inputs (single)
|
||||
passthrough_data[f"{k}_embeds"] = [v]
|
||||
elif is_list_of(v, torch.Tensor) and v[0].ndim == 2:
|
||||
# Pass through embedding inputs (multi)
|
||||
passthrough_data[f"{k}_embeds"] = v
|
||||
else:
|
||||
# Map keys to plural form, e.g.: image -> images
|
||||
processor_data[f"{k}s"] = v
|
||||
else:
|
||||
processor_data[k] = v
|
||||
|
||||
try:
|
||||
hf_inputs = hf_processor(
|
||||
text=prompt, # type: ignore
|
||||
**processor_data,
|
||||
**mm_processor_kwargs,
|
||||
return_tensors="pt",
|
||||
)
|
||||
except Exception as exc:
|
||||
data = dict(text=prompt, **processor_data)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Failed to apply {type(hf_processor).__name__} "
|
||||
f"on data={data} with kwargs={mm_processor_kwargs}") from exc
|
||||
|
||||
hf_inputs.update(passthrough_data)
|
||||
|
||||
return hf_inputs
|
||||
|
||||
def _bind_prompt_replacements(
|
||||
self,
|
||||
mm_data: MultiModalDataDict,
|
||||
) -> list[_BoundPromptReplacement[Any]]:
|
||||
tokenizer = self.ctx.tokenizer
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
return [
|
||||
prompt_repl.bind(modality, tokenizer)
|
||||
@ -604,7 +638,7 @@ class MultiModalProcessor:
|
||||
token_ids: list[int],
|
||||
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
|
||||
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
|
||||
tokenizer = self.ctx.tokenizer
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
mm_items = to_multi_format(mm_data)
|
||||
token_matches = find_token_matches(token_ids, prompt_repls)
|
||||
@ -620,7 +654,7 @@ class MultiModalProcessor:
|
||||
# of the search text in the prompt, we instead perform string
|
||||
# replacement on the decoded token IDs, then encode them back.
|
||||
if all(
|
||||
len(matches) >= len(mm_data[modality])
|
||||
len(matches) >= len(mm_items[modality])
|
||||
for modality, matches in full_groupby_modality(token_matches)
|
||||
): # yapf: disable
|
||||
token_ids = replace_token_matches(
|
||||
@ -648,15 +682,6 @@ class MultiModalProcessor:
|
||||
|
||||
placeholders = self._find_placeholders(matched_repls, token_ids)
|
||||
|
||||
# Sanity check
|
||||
assert len(placeholders) == len(matched_repls), dict(
|
||||
# Log this information for easier debugging
|
||||
text=text,
|
||||
token_ids=token_ids,
|
||||
placeholders=placeholders,
|
||||
matched_repls=matched_repls,
|
||||
)
|
||||
|
||||
return token_ids, text, placeholders
|
||||
|
||||
def apply(
|
||||
@ -678,7 +703,7 @@ class MultiModalProcessor:
|
||||
3. Extract information about the placeholder tokens from the
|
||||
processed token IDs.
|
||||
"""
|
||||
tokenizer = self.ctx.tokenizer
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
|
||||
mm_processor_kwargs)
|
||||
@ -717,3 +742,59 @@ class MultiModalProcessor:
|
||||
mm_kwargs=mm_kwargs,
|
||||
mm_placeholders=mm_placeholders,
|
||||
)
|
||||
|
||||
@abstractmethod
def _get_dummy_mm_kwargs(
    self,
    mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
    """
    Construct the multi-modal keyword arguments that accompany the
    `mm_max_tokens` placeholders built by :meth:`get_dummy_data`.

    Subclasses must override this method; the base implementation always
    raises :class:`NotImplementedError`.
    """
    raise NotImplementedError
|
||||
|
||||
def get_dummy_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_max_tokens: Mapping[str, int],
) -> DummyData:
    """
    Build dummy data whose prompt is made of placeholder tokens.

    For each modality with a non-zero entry in ``mm_max_tokens``, the first
    prompt replacement registered for that modality is bound to the
    tokenizer and its token unit is repeated ``max_tokens // len(unit)``
    times. The placeholder spans are laid out back-to-back starting at
    index 0, and the prompt is then padded with token ID 0 up to
    ``seq_len``.

    Args:
        seq_len: Length to pad the dummy prompt to. The prompt is not
            truncated when the placeholders alone already exceed it.
        mm_counts: Forwarded unchanged to :meth:`_get_dummy_mm_kwargs`.
        mm_max_tokens: Token budget per modality. Modalities with a budget
            of 0 are skipped.

    Returns:
        The dummy sequence data, the dummy multi-modal keyword arguments and
        one placeholder range per non-skipped modality.

    Raises:
        ValueError: If the replacement unit of a modality has no tokens, in
            which case the number of repetitions cannot be computed.
    """
    tokenizer = self._get_tokenizer()

    mm_placeholders: dict[str, _PlaceholderInfo] = {}
    offset = 0

    for modality, max_tokens in mm_max_tokens.items():
        if max_tokens == 0:
            continue

        metadata = self.metadata[modality]
        repl = metadata.prompt_repls[0].bind(modality, tokenizer)
        repl_token_ids = repl.repl_unit.token_ids

        # Fail with a clear message instead of a bare ZeroDivisionError
        # from the floor division below.
        if not repl_token_ids:
            raise ValueError(
                f"The replacement unit for modality {modality!r} has no "
                "tokens, so dummy placeholders cannot be built from it")

        placeholders = _PlaceholderInfo(
            modality=modality,
            start_idx=offset,
            unit=repl_token_ids,
            unit_count=max_tokens // len(repl_token_ids),
        )

        mm_placeholders[modality] = placeholders
        offset += placeholders.length

    prompt_token_ids = flatten_2d_lists(
        [p.unit * p.unit_count for p in mm_placeholders.values()])
    # Pad with token ID 0; a non-positive difference adds nothing.
    prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

    # Avoid circular import
    from vllm.sequence import SequenceData

    return DummyData(
        seq_data=SequenceData.from_seqs(prompt_token_ids),
        multi_modal_data=self._get_dummy_mm_kwargs(mm_counts),
        multi_modal_placeholders={
            modality: [p.to_range()]
            for modality, p in mm_placeholders.items()
        },
    )
|
||||
|
||||
@ -15,7 +15,7 @@ from .audio import AudioPlugin
|
||||
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
|
||||
from .image import ImagePlugin
|
||||
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
|
||||
from .processing import MultiModalProcessor
|
||||
from .processing import MultiModalProcessingMetadata, MultiModalProcessor
|
||||
from .video import VideoPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -200,6 +200,27 @@ class MultiModalRegistry:
|
||||
"""
|
||||
return self.register_max_multimodal_tokens("image", max_mm_tokens)
|
||||
|
||||
def get_max_tokens_by_modality(
    self,
    model_config: "ModelConfig",
) -> Mapping[str, int]:
    """
    Compute, for every registered plugin, the maximum number of tokens its
    modality can contribute when profiling the memory usage of a model.

    Each value is the per-prompt limit stored for the modality multiplied
    by the plugin's own maximum number of multi-modal tokens.

    See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

    Note:
        This should be called after :meth:`init_mm_limits_per_prompt`.
    """
    limits_per_plugin = self._limits_by_model[model_config]

    max_tokens_by_modality = dict[str, int]()
    for modality, plugin in self._plugins.items():
        limit = limits_per_plugin[modality]
        tokens_per_item = plugin.get_max_multimodal_tokens(model_config)
        max_tokens_by_modality[modality] = limit * tokens_per_item

    return max_tokens_by_modality
|
||||
|
||||
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
|
||||
"""
|
||||
Get the maximum number of multi-modal tokens
|
||||
@ -210,11 +231,7 @@ class MultiModalRegistry:
|
||||
Note:
|
||||
This should be called after :meth:`init_mm_limits_per_prompt`.
|
||||
"""
|
||||
limits_per_plugin = self._limits_by_model[model_config]
|
||||
|
||||
return sum((limits_per_plugin[key] *
|
||||
plugin.get_max_multimodal_tokens(model_config))
|
||||
for key, plugin in self._plugins.items())
|
||||
return sum(self.get_max_tokens_by_modality(model_config).values())
|
||||
|
||||
def init_mm_limits_per_prompt(
|
||||
self,
|
||||
@ -270,7 +287,8 @@ class MultiModalRegistry:
|
||||
factory: MultiModalProcessorFactory,
|
||||
):
|
||||
"""
|
||||
Register a multi-modal processor to a model class.
|
||||
Register a multi-modal processor to a model class. The processor
|
||||
is constructed lazily, hence a factory method should be passed.
|
||||
|
||||
When the model receives multi-modal data, the provided function is
|
||||
invoked to transform the data into a dictionary of model inputs.
|
||||
@ -293,6 +311,41 @@ class MultiModalRegistry:
|
||||
|
||||
return wrapper
|
||||
|
||||
def register_processor_by_metadata(
    self,
    metadata_factory: Callable[[InputProcessingContext],
                               MultiModalProcessingMetadata],
    get_dummy_mm_kwargs: Callable[
        [InputProcessingContext, Mapping[str, int]], MultiModalKwargs],
):
    """
    Shortcut for registering a multi-modal processor to a model class when
    the processor is described by a function that builds its metadata.

    A :class:`MultiModalProcessor` subclass is defined on the fly. Its
    metadata is obtained by calling ``metadata_factory`` with the processing
    context, and its :meth:`_get_dummy_mm_kwargs` delegates to
    ``get_dummy_mm_kwargs``. The factory creating that processor is passed
    to :meth:`register_processor`, whose result is returned.

    When the model receives multi-modal data, the provided function is
    invoked to transform the data into a dictionary of model inputs.

    See also:
        - :ref:`input_processing_pipeline`
        - :ref:`enabling_multimodal_inputs`
    """

    class ConcreteMultiModalProcessor(MultiModalProcessor):

        def _get_dummy_mm_kwargs(
            self,
            mm_counts: Mapping[str, int],
        ) -> MultiModalKwargs:
            return get_dummy_mm_kwargs(self.ctx, mm_counts)

    def factory(ctx: InputProcessingContext):
        # The metadata is only built when the processor is constructed.
        metadata = metadata_factory(ctx)
        return ConcreteMultiModalProcessor(ctx=ctx, metadata=metadata)

    return self.register_processor(factory)
|
||||
|
||||
def has_processor(self, model_config: "ModelConfig") -> bool:
|
||||
"""
|
||||
Test whether a multi-modal processor is defined for a specific model.
|
||||
|
||||
@ -12,6 +12,7 @@ class MMInputMapper:
|
||||
model_config: ModelConfig,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.mm_registry = mm_registry
|
||||
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
|
||||
model_config)
|
||||
|
||||
@ -7,7 +7,8 @@ from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
MultiModalRegistry)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@ -101,10 +102,15 @@ class Processor:
|
||||
self.generation_config_fields, eos_token_id)
|
||||
|
||||
# Preprocess multi-modal data
|
||||
mm_inputs = self.mm_input_mapper.process_inputs(
|
||||
decoder_inputs.multi_modal_data,
|
||||
decoder_inputs.mm_processor_kwargs) if len(
|
||||
decoder_inputs.multi_modal_data) > 0 else None
|
||||
if len(decoder_inputs.multi_modal_data) == 0:
|
||||
mm_inputs = None
|
||||
elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
|
||||
mm_inputs = [decoder_inputs.multi_modal_data]
|
||||
else:
|
||||
mm_inputs = self.mm_input_mapper.process_inputs(
|
||||
decoder_inputs.multi_modal_data,
|
||||
decoder_inputs.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
# Make Request for Detokenizer.
|
||||
detokenizer_request = DetokenizerRequest(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user