Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 04:34:57 +08:00)

[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Parent: 300acb8347
Commit: eed11ebee9
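All of the new tests in this commit drive the merged multi-modal processors the same way: build a model context, wrap its model_config and the HF tokenizer in an InputProcessingContext, construct the model-specific processor, and call apply() on a prompt plus raw multi-modal data. The condensed sketch below is an illustration distilled from those tests, not part of the diff; the model id and image size are arbitrary examples, and build_model_context is the shared test helper imported as "....utils" in the test files that follow.

    from PIL import Image
    from transformers import AutoTokenizer

    from vllm.inputs import InputProcessingContext
    from vllm.model_executor.models.llava_next import LlavaNextMultiModalProcessor

    model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
    ctx = build_model_context(  # test helper from the vLLM test suite
        model_name=model_id,
        tokenizer_name=model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": 1},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    ctx = InputProcessingContext(ctx.model_config, tokenizer)

    processor = LlavaNextMultiModalProcessor(ctx)
    processed_inputs = processor.apply(
        "<image>",                                       # prompt text
        {"image": [Image.new("RGB", size=(488, 183))]},  # mm_data
        {},                                              # hf_processor_mm_kwargs
    )
    # processed_inputs["mm_placeholders"]["image"] holds the offset/length
    # ranges that the tests below assert on.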
@@ -0,0 +1,58 @@
import pytest
from PIL import Image
from transformers import AutoTokenizer

from vllm.inputs import InputProcessingContext

from ....utils import build_model_context


# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_next():
    from vllm.model_executor.models.llava_next import (
        LlavaNextMultiModalProcessor)
    return LlavaNextMultiModalProcessor


# FIXME: image_size [(198, 176), (176, 198)]
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
                                        (488, 183)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
    processor_for_llava_next,
    model_id: str,
    image_size: tuple[int, int],
    num_imgs: int,
):
    """
    Ensure LlavaNextMultiModalProcessor handles prompt replacement properly.
    """
    ctx = build_model_context(
        model_name=model_id,
        tokenizer_name=model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    ctx = InputProcessingContext(ctx.model_config, tokenizer)

    # Build the image str / prompt based on the number of images we pass
    prompt = "<image>" * num_imgs
    mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}

    # The processor will throw an error if there is a mismatch
    # in the prompt replacements
    processor = processor_for_llava_next(ctx)
    processed_inputs = processor.apply(prompt, mm_data, {})

    image_placeholders = processed_inputs["mm_placeholders"]["image"]
    assert len(image_placeholders) == num_imgs

    first_placeholder = image_placeholders[0]

    # NOTE: There is a BOS token
    assert first_placeholder["offset"] == 1
    assert first_placeholder["length"] == (
        len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
@@ -0,0 +1,59 @@
import pytest
from PIL import Image
from transformers import AutoTokenizer

from vllm.inputs import InputProcessingContext

from ....utils import build_model_context


# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_onevision():
    from vllm.model_executor.models.llava_onevision import (
        LlavaOnevisionMultiModalProcessor)
    return LlavaOnevisionMultiModalProcessor


@pytest.mark.parametrize("model_id",
                         ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
                                        (488, 183), (198, 176), (176, 198)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
    processor_for_llava_onevision,
    model_id: str,
    image_size: tuple[int, int],
    num_imgs: int,
):
    """
    Ensure LlavaOnevisionMultiModalProcessor handles prompt replacement
    properly.
    """
    ctx = build_model_context(
        model_name=model_id,
        tokenizer_name=model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    ctx = InputProcessingContext(ctx.model_config, tokenizer)

    # Build the image str / prompt based on the number of images we pass
    prompt = "<image>" * num_imgs
    mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}

    # The processor will throw an error if there is a mismatch
    # in the prompt replacements
    processor = processor_for_llava_onevision(ctx)
    processed_inputs = processor.apply(prompt, mm_data, {})

    image_placeholders = processed_inputs["mm_placeholders"]["image"]
    assert len(image_placeholders) == num_imgs

    first_placeholder = image_placeholders[0]

    # NOTE: There is a BOS token
    assert first_placeholder["offset"] == 0
    assert first_placeholder["length"] == len(
        processed_inputs["prompt_token_ids"]) // num_imgs
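The two new tests above pin different placeholder offsets (1 for LLaVA-NeXT, 0 for LLaVA-OneVision) because the underlying tokenizers differ: the Mistral-based LLaVA-NeXT checkpoint prepends a BOS token to the tokenized prompt, while the Qwen2-based LLaVA-OneVision checkpoint does not, so the first expanded <image> placeholder starts one position later in the former. A small check of that assumption (illustrative only, not part of the diff; it relies on the default settings of the two HF tokenizers):

    from transformers import AutoTokenizer

    mistral_tok = AutoTokenizer.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf")
    qwen2_tok = AutoTokenizer.from_pretrained(
        "llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

    # Expect a leading BOS id from the Mistral tokenizer and none from the
    # Qwen2 tokenizer, which is what shifts the placeholder offset from 0 to 1.
    print(mistral_tok("<image>").input_ids)
    print(qwen2_tok("<image>").input_ids)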
@@ -1,6 +1,4 @@
"""Tests for phi3v's multimodal preprocessing kwargs."""
from typing import Optional

import pytest
from transformers import AutoTokenizer

@@ -10,8 +8,6 @@ from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from .....conftest import _ImageAssets
from ....utils import build_model_context

models = ["microsoft/Phi-3.5-vision-instruct"]


# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
@@ -20,40 +16,40 @@ def processor_for_phi3v():
    return Phi3VMultiModalProcessor


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
# yapf: disable
@pytest.mark.parametrize(
    "num_crops,expected_toks_per_img",
    ("mm_processor_kwargs", "expected_toks_per_img"),
    [
        (4, 757),
        (16, 1921),
        ({"num_crops": 4}, 757),
        ({"num_crops": 16}, 1921),
        # the default num_crops of phi-3.5-vision is 4
        (None, 757),
        ({}, 757),
    ])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
                            model: str, num_crops: Optional[int],
                            expected_toks_per_img: int, num_imgs: int):
def test_processor_override(
    processor_for_phi3v,
    image_assets: _ImageAssets,
    model_id: str,
    mm_processor_kwargs: dict[str, int],
    expected_toks_per_img: int,
    num_imgs: int,
):
    """Ensure input_processor_for_phi3v handles num_crops properly."""
    # Same as the previous test - don't initialize mm_processor_kwargs
    # in this test and assume that the kwargs will be correctly expanded by
    # the partial when calling the custom input processor.
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        model_name=model_id,
        tokenizer_name=model_id,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": num_imgs},
    )
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    ctx = InputProcessingContext(ctx.model_config, tokenizer)

    # Build the image str / prompt based on the number of images we pass
    img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
    prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
    images = [image_assets[0].pil_image] * num_imgs

    mm_data = {"image": images}
    mm_processor_kwargs = {}
    if num_crops is not None:
        mm_processor_kwargs = {"num_crops": num_crops}
    mm_data = {"image": [image_assets[0].pil_image] * num_imgs}

    processor = processor_for_phi3v(ctx)
    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
@@ -1,5 +1,3 @@
from typing import Any, Dict, Tuple

import pytest
from transformers import AutoTokenizer

@@ -8,56 +6,45 @@ from vllm.inputs import InputProcessingContext
from .....conftest import _ImageAssets
from ....utils import build_model_context

MODEL = "Qwen/Qwen2-VL-2B-Instruct"
MIN_PIXELS = "min_pixels"
MAX_PIXELS = "max_pixels"


# Fixtures lazy import to avoid initializing CUDA during test collection
# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
# input mappers.
@pytest.fixture()
def processor_for_qwen2_vl():
    from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
    return Qwen2VLMultiModalProcessor


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
    "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
    ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), [
        ({}, 1426, (5704, 1176)),
        ({
            MIN_PIXELS: 64**2,
            MAX_PIXELS: 512**2
        }, 330, (1320, 1176)),
        ({"min_pixels": 64**2, "max_pixels": 512**2}, 330, (1320, 1176)),
    ])
@pytest.mark.parametrize("model", [MODEL])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
    processor_for_qwen2_vl,
    image_assets: _ImageAssets,
    model: str,
    mm_processor_kwargs: Dict[str, Any],
    model_id: str,
    mm_processor_kwargs: dict[str, object],
    expected_toks_per_img: int,
    expected_pixels_shape: Tuple[int, int],
    expected_pixels_shape: tuple[int, int],
    num_imgs: int,
):
    """Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
    # Same as the previous test - don't initialize mm_processor_kwargs
    # in this test and assume that the kwargs will be correctly expanded by
    # the partial when calling the custom input processor.
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        model_name=model_id,
        tokenizer_name=model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    ctx = InputProcessingContext(ctx.model_config, tokenizer)

    # Build the image str / prompt based on the number of images we pass
    prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
    images = [image_assets[0].pil_image] * num_imgs

    mm_data = {"image": images}
    mm_data = {"image": [image_assets[0].pil_image] * num_imgs}

    processor = processor_for_qwen2_vl(ctx)
    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
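The expected values in the Qwen2-VL test above are internally consistent: with the default spatial merge size of 2, each placeholder token covers a 2x2 block of patches, and every flattened patch has 14 * 14 * 3 * 2 = 1176 elements (14x14 patch, 3 channels, temporal patch size 2). A quick sanity check (illustrative only, not part of the diff; the merge/patch sizes are the Qwen2-VL defaults assumed here):

    merge_size = 2
    patch_elems = 14 * 14 * 3 * 2  # 1176
    for toks_per_img, (num_patches, feat_dim) in [(1426, (5704, 1176)),
                                                  (330, (1320, 1176))]:
        assert num_patches == toks_per_img * merge_size**2
        assert feat_dim == patch_elems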
@@ -274,10 +274,8 @@ VLM_TEST_SETTINGS = {
            ),
            limit_mm_per_prompt={"image": 4},
        )],
        # Llava-next tests fixed sizes & the default size factors
        image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
    ),
    "llava_one_vision": VLMTestInfo(
    "llava_onevision": VLMTestInfo(
        models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
        test_type=VLMTestType.CUSTOM_INPUTS,
        prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
@@ -288,8 +286,6 @@ VLM_TEST_SETTINGS = {
        ),
        auto_cls=AutoModelForVision2Seq,
        vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
        # Llava-one-vision tests fixed sizes & the default size factors
        image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
        custom_test_opts=[CustomTestOptions(
            inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs(
                formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
@@ -306,7 +302,6 @@ VLM_TEST_SETTINGS = {
        max_model_len=4096,
        auto_cls=AutoModelForVision2Seq,
        vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
        image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
    ),
    "mantis": VLMTestInfo(
        models=["TIGER-Lab/Mantis-8B-siglip-llama3"],
@@ -431,7 +426,7 @@ VLM_TEST_SETTINGS = {
            ) for inp in custom_inputs.different_patch_input_cases_internvl()
        ],
    ),
    "llava_one_vision-multiple-images": VLMTestInfo(
    "llava_onevision-multiple-images": VLMTestInfo(
        models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
        test_type=VLMTestType.CUSTOM_INPUTS,
        max_model_len=16384,

@ -427,130 +427,3 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
def run_chunked_prefill_test(
|
||||
vllm_runner: Type[VllmRunner],
|
||||
inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
mm_limit: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
"""Compare inference result between
|
||||
chunked prefill disabled and chunked prefill enabled
|
||||
"""
|
||||
|
||||
# NOTE:
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
task="generate",
|
||||
max_model_len=4000,
|
||||
max_num_seqs=4,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={
|
||||
"image": mm_limit,
|
||||
"video": mm_limit
|
||||
},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend
|
||||
) as vllm_model:
|
||||
|
||||
outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images or None,
|
||||
videos=videos or None)
|
||||
for prompts, images, videos in inputs
|
||||
]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
task="generate",
|
||||
max_model_len=4000,
|
||||
max_num_seqs=4,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={
|
||||
"image": mm_limit,
|
||||
"video": mm_limit
|
||||
},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enable_chunked_prefill=True,
|
||||
# should be small enough to ensure prefilling is chunked
|
||||
max_num_batched_tokens=32,
|
||||
mm_processor_kwargs={
|
||||
"max_pixels": 16 * 28 * 28,
|
||||
}) as vllm_model_chunked:
|
||||
outputs_per_case_chunked = [
|
||||
vllm_model_chunked.generate_greedy_logprobs(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images or None,
|
||||
videos=videos or None) for prompts, images, videos in inputs
|
||||
]
|
||||
|
||||
for outputs, \
|
||||
outputs_chunked \
|
||||
in zip(outputs_per_case,
|
||||
outputs_per_case_chunked):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=outputs,
|
||||
outputs_1_lst=outputs_chunked,
|
||||
name_0="non_chunked",
|
||||
name_1="chunked",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [1])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts,
|
||||
model: str, dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
"""
|
||||
Test Qwen2-VL's chunked prefill with M-RoPE
|
||||
"""
|
||||
prompts = [
|
||||
qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt)
|
||||
for prompt in example_prompts[:1]
|
||||
]
|
||||
|
||||
# 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs,
|
||||
# so an image is included in the inputs
|
||||
# 2. however, Qwen2-VL currently won't work properly
|
||||
# when chunked prefill is enabled and there are some multi-modal inputs,
|
||||
# here use a hacky way: provide a **zero-length** image to make it happy
|
||||
#
|
||||
# and finally we achieved:
|
||||
# (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests
|
||||
zero_len_image = {
|
||||
"image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)),
|
||||
"image_grid_thw": torch.tensor([[0, 0, 0]])
|
||||
}
|
||||
images = [zero_len_image] * len(prompts)
|
||||
|
||||
inputs_per_case: List[Tuple[List[str], PromptImageInput,
|
||||
PromptVideoInput]] = [
|
||||
(prompts, images, []),
|
||||
]
|
||||
|
||||
run_chunked_prefill_test(
|
||||
vllm_runner,
|
||||
inputs_per_case,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
@ -11,8 +11,8 @@ from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
|
||||
_PlaceholderInfo, find_text_matches,
|
||||
find_token_matches, iter_placeholders,
|
||||
_PlaceholderInfo, find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
iter_token_matches,
|
||||
replace_text_matches,
|
||||
replace_token_matches)
|
||||
@ -314,21 +314,27 @@ def test_find_replace_text(
|
||||
# Should not be used since there is nothing to convert to text
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
|
||||
mm_prompt_repls = {
|
||||
key: [
|
||||
PromptReplacement(key, target,
|
||||
repl_by_key[key]).bind(mock_tokenizer)
|
||||
]
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
matches = find_text_matches(prompt, prompt_repls)
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_text_matches(prompt, prompt_repls)
|
||||
for key, prompt_repls in mm_prompt_repls.items()
|
||||
}
|
||||
|
||||
result = replace_text_matches(
|
||||
prompt,
|
||||
matches,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("matches:", matches)
|
||||
print("mm_matches:", mm_matches)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
@ -380,21 +386,27 @@ def test_find_replace_tokens(
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
|
||||
mm_prompt_repls = {
|
||||
key: [
|
||||
PromptReplacement(key, target,
|
||||
repl_by_key[key]).bind(mock_tokenizer)
|
||||
]
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
matches = find_token_matches(prompt, prompt_repls)
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_token_matches(prompt, prompt_repls)
|
||||
for key, prompt_repls in mm_prompt_repls.items()
|
||||
}
|
||||
|
||||
result = replace_token_matches(
|
||||
prompt,
|
||||
matches,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("matches:", matches)
|
||||
print("mm_matches:", mm_matches)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
@ -417,58 +429,76 @@ def test_find_replace_tokens(
|
||||
[
|
||||
(
|
||||
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
[
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=6,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
],
|
||||
{
|
||||
"pattern_1": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
),
|
||||
(
|
||||
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
|
||||
[
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=5,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
start_idx=7,
|
||||
replacement=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
{
|
||||
"pattern_1": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=5,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
item_idx=0,
|
||||
start_idx=7,
|
||||
replacement=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
}
|
||||
),
|
||||
(
|
||||
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
|
||||
[
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=3,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
start_idx=6,
|
||||
replacement=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
{
|
||||
"pattern_1": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=3,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
replacement=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
def test_iter_placeholders(
|
||||
def test_find_mm_placeholders(
|
||||
repl_by_key,
|
||||
prompt,
|
||||
expected,
|
||||
@ -476,19 +506,18 @@ def test_iter_placeholders(
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement(key, [], repl).bind(mock_tokenizer)
|
||||
mm_prompt_repls = {
|
||||
key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
|
||||
for key, repl in repl_by_key.items()
|
||||
]
|
||||
}
|
||||
|
||||
result = list(
|
||||
iter_placeholders(
|
||||
prompt_repls,
|
||||
prompt,
|
||||
# Effectively match all occurrences in the prompt
|
||||
{key: 3
|
||||
for key in repl_by_key},
|
||||
))
|
||||
result = find_mm_placeholders(
|
||||
mm_prompt_repls,
|
||||
prompt,
|
||||
# Effectively match all occurrences in the prompt
|
||||
{key: 3
|
||||
for key in repl_by_key},
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
@ -694,7 +723,10 @@ def _test_processing_cache_correctness(
|
||||
}
|
||||
|
||||
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
||||
prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text
|
||||
prompt = baseline_processor._get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
mm_counts,
|
||||
).prompt_text
|
||||
|
||||
# Drop unnecessary keys and test single -> multi conversion
|
||||
if rng.rand() < simplify_rate:
|
||||
@ -728,6 +760,8 @@ def _test_processing_cache_correctness(
|
||||
("adept/fuyu-8b", {"image": False}),
|
||||
("llava-hf/llava-1.5-7b-hf", {"image": True}),
|
||||
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
|
||||
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
|
||||
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
|
||||
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
|
||||
("mistral-community/pixtral-12b", {"image": True}),
|
||||
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
|
||||
|
||||
@@ -456,7 +456,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
        hf_config = self.ctx.get_hf_config()
        return max(hf_config.projector_patch_to_query_dict.values())

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}

    def _get_mm_fields_config(
@@ -488,8 +488,9 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
            )
        ]

    def _get_dummy_mm_inputs(
    def _get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        hf_config = self.ctx.get_hf_config()

@@ -405,7 +405,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
        hf_config = self.ctx.get_hf_config(Blip2Config)
        return hf_config.num_query_tokens

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}

    def _get_hf_processor(self) -> Blip2Processor:
@@ -457,8 +457,9 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):

        return result

    def _get_dummy_mm_inputs(
    def _get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        hf_config = self.ctx.get_hf_config(Blip2Config)

@@ -57,7 +57,7 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
        processor = self._get_hf_processor()
        return processor.image_seq_length

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}

    def _get_hf_processor(self) -> ChameleonProcessor:
@@ -90,8 +90,9 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
            )
        ]

    def _get_dummy_mm_inputs(
    def _get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        config = self.ctx.get_hf_config(ChameleonConfig)

@@ -164,15 +164,18 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    def get_max_image_tokens(self) -> int:
        return get_max_clip_image_tokens(self.vision_config)

    def get_num_patches(self) -> int:
    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        return get_clip_patch_grid_length(
            image_size=self.vision_config.image_size,
            patch_size=self.vision_config.patch_size,
        )

    def get_image_size(self) -> int:
        return self.vision_config.image_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

@@ -96,7 +96,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
        nrows = math.ceil(image_height / 30)
        return ncols, nrows

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        target_width, target_height = self._get_image_target_size()

        max_ncols, max_nrows = self._get_image_feature_grid_size(
@@ -208,8 +208,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):

        return result

    def _get_dummy_mm_inputs(
    def _get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        target_width, target_height = self._get_image_target_size()
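The hunks above (Aria, BLIP-2, Chameleon, CLIP, Fuyu) all apply the same interface change: get_mm_max_tokens_per_item now receives the sequence length, and _get_dummy_mm_inputs is renamed to _get_dummy_processor_inputs with a leading seq_len parameter, so processors can size their dummy data against the context window. A minimal sketch of a subclass overriding just these two hooks (a hypothetical fixed-size image model invented for illustration; a real subclass also implements the other abstract methods shown in this diff, and the _get_dummy_images helper is assumed to exist alongside the _get_dummy_videos helper used later in this commit):

    from collections.abc import Mapping

    from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                            ProcessorInputs)


    class ToyMultiModalProcessor(BaseMultiModalProcessor):
        """Hypothetical processor with a fixed 576-token image placeholder."""

        def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
            # The budget may depend on seq_len (video processors cap the number
            # of frames against it); this toy image-only model ignores it.
            return {"image": 576}

        def _get_dummy_processor_inputs(
            self,
            seq_len: int,
            mm_counts: Mapping[str, int],
        ) -> ProcessorInputs:
            num_images = mm_counts.get("image", 0)
            mm_data = {
                "image": self._get_dummy_images(width=336,
                                                height=336,
                                                num_images=num_images)
            }
            return ProcessorInputs(
                prompt_text="<image>" * num_images,
                mm_data=mm_data,
            )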
@ -25,11 +25,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
InputProcessingContext,
|
||||
from vllm.multimodal.processing import (InputProcessingContext,
|
||||
MultiModalDataItems, ProcessingCache,
|
||||
ProcessorInputs, PromptReplacement,
|
||||
full_groupby_modality)
|
||||
ProcessorInputs, PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .clip import CLIPVisionModel
|
||||
@ -39,7 +37,7 @@ from .pixtral import (PixtralHFVisionModel,
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import vision_encoder_info
|
||||
from .vision import BaseVisionLanguageMultiModalProcessor
|
||||
|
||||
|
||||
class LlavaImagePixelInputs(TypedDict):
|
||||
@ -100,19 +98,7 @@ class LlavaLikeConfig(Protocol):
|
||||
vision_feature_layer: Final[Union[int, List[int]]]
|
||||
|
||||
|
||||
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def __init__(self,
|
||||
ctx: InputProcessingContext,
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True) -> None:
|
||||
super().__init__(ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks)
|
||||
|
||||
vision_config = self._get_hf_config().vision_config
|
||||
self._vision_encoder_info = vision_encoder_info(vision_config)
|
||||
class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
|
||||
|
||||
@abstractmethod
|
||||
def _get_hf_config(self) -> LlavaLikeConfig:
|
||||
@ -121,6 +107,19 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
return {"image": self._get_max_image_tokens()}
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _apply_feature_select_strategy(
|
||||
self,
|
||||
strategy: str,
|
||||
@ -142,19 +141,6 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
self._vision_encoder_info.get_max_image_tokens(),
|
||||
)
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
return {"image": self._get_max_image_tokens()}
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_dummy_image_size(self) -> ImageSize:
|
||||
image_size = self._vision_encoder_info.get_image_size()
|
||||
return ImageSize(image_size, image_size)
|
||||
@ -163,8 +149,9 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
def _get_image_token(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
def _get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
@ -709,7 +696,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
"</Image>)", # 3 tokens
|
||||
])
|
||||
|
||||
mantis_repls = self._bind_prompt_replacements([
|
||||
mantis_mm_repls = self._bind_and_group_repls([
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id] * num_image_tokens,
|
||||
@ -719,7 +706,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
|
||||
prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
|
||||
result["prompt_token_ids"],
|
||||
mantis_repls,
|
||||
mantis_mm_repls,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
@ -728,15 +715,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
hf_processor_mm_kwargs,
|
||||
mm_kwargs,
|
||||
)
|
||||
orig_repls = self._bind_prompt_replacements(unbound_orig_repls)
|
||||
orig_repls = self._bind_and_group_repls(unbound_orig_repls)
|
||||
|
||||
all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
|
||||
mm_item_counts)
|
||||
assert len(all_placeholders) == mm_item_counts.get("image", 0)
|
||||
mm_placeholders = self._find_mm_placeholders(
|
||||
orig_repls,
|
||||
prompt_ids,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
mm_placeholders = {
|
||||
modality: [item.to_range() for item in items]
|
||||
for modality, items in full_groupby_modality(all_placeholders)
|
||||
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
|
||||
|
||||
mm_placeholder_ranges = {
|
||||
modality: [item.to_range() for item in placeholders]
|
||||
for modality, placeholders in mm_placeholders.items()
|
||||
}
|
||||
|
||||
return MultiModalInputsV2(
|
||||
@ -744,7 +735,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_ids,
|
||||
mm_kwargs=mm_kwargs,
|
||||
mm_placeholders=mm_placeholders,
|
||||
mm_placeholders=mm_placeholder_ranges,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -67,9 +67,6 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
def _get_hf_processor(self) -> LlavaNextProcessor:
|
||||
return self.ctx.get_hf_processor(LlavaNextProcessor)
|
||||
|
||||
def _get_image_token(self) -> str:
|
||||
return self._get_hf_processor().image_token
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
@ -81,6 +78,9 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_image_token(self) -> str:
|
||||
return self._get_hf_processor().image_token
|
||||
|
||||
def _get_max_image_tokens(self) -> int:
|
||||
largest_feature_size, _ = self._get_pinpoint_with_most_features()
|
||||
return largest_feature_size
|
||||
@ -97,20 +97,20 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self._get_hf_config()
|
||||
vision_encoder_info = self._vision_encoder_info
|
||||
|
||||
base_feature_size = self._apply_feature_select_strategy(
|
||||
hf_config.vision_feature_select_strategy,
|
||||
self._vision_encoder_info.get_num_image_tokens(
|
||||
vision_encoder_info.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
),
|
||||
)
|
||||
num_patches = self._vision_encoder_info.get_num_patches()
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_size=(image_height, image_width),
|
||||
grid_pinpoints=hf_config.image_grid_pinpoints,
|
||||
patch_size=self._vision_encoder_info.get_image_size(),
|
||||
patch_size=vision_encoder_info.get_image_size(),
|
||||
)
|
||||
|
||||
(
|
||||
@ -119,7 +119,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
) = self._get_num_unpadded_features(
|
||||
original_height=image_height,
|
||||
original_width=image_width,
|
||||
npatches=num_patches,
|
||||
npatches=vision_encoder_info.get_patch_grid_length(),
|
||||
num_patch_height=num_patch_height,
|
||||
num_patch_width=num_patch_width,
|
||||
)
|
||||
@ -155,6 +155,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
|
||||
|
||||
@ -3,38 +3,32 @@ from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
|
||||
SiglipVisionConfig)
|
||||
from transformers import (BatchFeature, LlavaNextVideoConfig,
|
||||
LlavaNextVideoProcessor)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
|
||||
VideoEmbeddingItems, VideoProcessorItems)
|
||||
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .llava import init_vision_tower_for_llava
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 32
|
||||
_MAX_NUM_VIDEOS = 1
|
||||
from .vision import BaseVisionLanguageMultiModalProcessor
|
||||
|
||||
|
||||
class LlavaNextVideoPixelInputs(TypedDict):
|
||||
@ -50,144 +44,149 @@ class LlavaNextVideoPixelInputs(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
def get_llava_next_video_frame_feature_size(
|
||||
hf_config: LlavaNextVideoConfig) -> int:
|
||||
# Support both CLIPVisionConfig and SiglipVisionConfig
|
||||
image_size = hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
spatial_pool_stride = hf_config.spatial_pool_stride
|
||||
class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
|
||||
|
||||
return int((image_size / patch_size / spatial_pool_stride)**2)
|
||||
def _get_hf_config(self) -> LlavaNextVideoConfig:
|
||||
return self.ctx.get_hf_config(LlavaNextVideoConfig)
|
||||
|
||||
def _get_hf_processor(self) -> LlavaNextVideoProcessor:
|
||||
return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
|
||||
|
||||
def _get_max_llm_tokens(ctx: InputContext) -> int:
|
||||
"""
|
||||
Calculated from the maximum video frames under the context length
|
||||
constraints of the language model.
|
||||
"""
|
||||
hf_text_config = ctx.model_config.hf_text_config
|
||||
model_config = ctx.model_config
|
||||
max_tokens = model_config.max_model_len
|
||||
rope_scaling = model_config.rope_scaling
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"video": 1}
|
||||
|
||||
if rope_scaling:
|
||||
rope_scaling_factor = hf_text_config.rope_scaling["factor"]
|
||||
else:
|
||||
rope_scaling_factor = 1
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
num_frames = self._get_dummy_num_frames(seq_len)
|
||||
max_video_tokens = self._get_max_video_tokens(num_frames)
|
||||
|
||||
max_tokens *= rope_scaling_factor
|
||||
return {"video": max_video_tokens}
|
||||
|
||||
return max_tokens
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
|
||||
|
||||
def _get_num_frame_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self._get_hf_config()
|
||||
spatial_pool_stride = hf_config.spatial_pool_stride
|
||||
|
||||
def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
|
||||
# Currently set to 32 frames
|
||||
# TODO: max_tokens = _get_max_llm_tokens(ctx)
|
||||
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
|
||||
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
|
||||
return _MAX_FRAMES_PER_VIDEO * tokens_per_frame
|
||||
patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
|
||||
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
|
||||
|
||||
return pooled_grid_length * pooled_grid_length
|
||||
|
||||
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
# TODO: support multiple videos
|
||||
num_videos = mm_counts["video"]
|
||||
if num_videos != _MAX_NUM_VIDEOS:
|
||||
raise NotImplementedError(
|
||||
f"Only {_MAX_NUM_VIDEOS} videos are supported")
|
||||
|
||||
# TODO: support configuring the number of frames
|
||||
frames_per_video = _MAX_FRAMES_PER_VIDEO
|
||||
# num_images = num_videos * frames_per_video
|
||||
|
||||
# fills the sequence with as longer video data as possible
|
||||
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
|
||||
video_feature_size = frames_per_video * tokens_per_frame
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_videos,
|
||||
image_token_id=hf_config.video_token_index,
|
||||
image_feature_size_override=video_feature_size,
|
||||
mm_key="video",
|
||||
def _get_num_video_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
num_frames: int,
|
||||
) -> int:
|
||||
num_frame_tokens = self._get_num_frame_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
)
|
||||
|
||||
pil_frame = dummy_image_for_clip(vision_config, num_images=1)
|
||||
np_frame = np.array(pil_frame["image"])
|
||||
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
|
||||
mm_data = {"video": mm_data_per_video}
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_siglip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_videos,
|
||||
image_token_id=hf_config.video_token_index,
|
||||
image_feature_size_override=video_feature_size,
|
||||
mm_key="video",
|
||||
return num_frame_tokens * num_frames
|
||||
|
||||
def _get_max_video_tokens(self, num_frames: int) -> int:
|
||||
return self._get_num_video_tokens(image_width=999999,
|
||||
image_height=999999,
|
||||
num_frames=num_frames)
|
||||
|
||||
def _get_max_video_frames(self, max_tokens: int) -> int:
|
||||
num_frames = 0
|
||||
|
||||
while True:
|
||||
next_num_frames = num_frames + 1
|
||||
|
||||
if self._get_max_video_tokens(next_num_frames) > max_tokens:
|
||||
break
|
||||
|
||||
num_frames = next_num_frames
|
||||
|
||||
return num_frames
|
||||
|
||||
def _get_dummy_num_frames(self, seq_len: int) -> int:
|
||||
mm_config = self.ctx.get_mm_config()
|
||||
max_videos = mm_config.limit_per_prompt.get("video", 1)
|
||||
|
||||
max_total_frames = self._get_max_video_frames(seq_len)
|
||||
|
||||
return max(max_total_frames // max(max_videos, 1), 1)
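For LLaVA-NeXT-Video each frame contributes a fixed number of tokens, so the frame search above reduces to a floor division. A worked example (illustrative only, not part of the diff; it assumes the released checkpoint's CLIP ViT-L/14 vision tower at 336 px and spatial_pool_stride = 2):

    image_size, patch_size, pool_stride = 336, 14, 2
    patch_grid = image_size // patch_size        # 24 patches per side
    pooled_grid = -(-patch_grid // pool_stride)  # ceil(24 / 2) = 12
    tokens_per_frame = pooled_grid**2            # 144

    seq_len = 4096
    max_total_frames = seq_len // tokens_per_frame  # 28, matching the loop above
    dummy_frames = max(max_total_frames // 1, 1)    # one video allowed -> 28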
|
||||
|
||||
def _get_dummy_image_size(self) -> ImageSize:
|
||||
image_size = self._vision_encoder_info.get_image_size()
|
||||
return ImageSize(image_size, image_size)
|
||||
|
||||
def _get_video_token(self) -> str:
|
||||
return self._get_hf_processor().video_token
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self._get_hf_config()
|
||||
video_token_id = hf_config.video_token_index
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
videos = mm_items.get_items(
|
||||
"video", (VideoEmbeddingItems, VideoProcessorItems))
|
||||
|
||||
if isinstance(videos, VideoEmbeddingItems):
|
||||
num_video_tokens = videos.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = videos.get_frame_size(item_idx)
|
||||
num_video_tokens = self._get_num_video_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
num_frames=videos.get_num_frames(item_idx),
|
||||
)
|
||||
|
||||
return [video_token_id] * num_video_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="video",
|
||||
target=[video_token_id],
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
def _get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
|
||||
video_token = self._get_video_token()
|
||||
target_width, target_height = self._get_dummy_image_size()
|
||||
|
||||
mm_data = {
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_frames=self._get_dummy_num_frames(seq_len),
|
||||
num_videos=num_videos,
|
||||
)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=video_token * num_videos,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
|
||||
np_frame = np.array(pil_frame["image"])
|
||||
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
|
||||
mm_data = {"video": mm_data_per_video}
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_llava_next_video(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "video" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
if "multi_modal_placeholders" in inputs and "video" in inputs[
|
||||
"multi_modal_placeholders"]:
|
||||
# The inputs already have placeholders.
|
||||
return inputs
|
||||
|
||||
video_data = multi_modal_data["video"]
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(video_data, np.ndarray):
|
||||
# Supports both CLIP and Siglip
|
||||
num_frames = video_data.shape[0]
|
||||
frame_feature_size = \
|
||||
get_llava_next_video_frame_feature_size(hf_config)
|
||||
video_feature_size = num_frames * frame_feature_size
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=hf_config.video_token_index,
|
||||
repeat_count=video_feature_size,
|
||||
)
|
||||
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
multi_modal_placeholders={"video": ranges})
|
||||
|
||||
elif is_list_of(video_data, np.ndarray):
|
||||
raise NotImplementedError(
|
||||
"Processing multiple videos is not supported")
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
# adopted from transformers modeling_llava_next_video.py
|
||||
class LlavaNextVideoPooler(nn.Module):
|
||||
@ -246,11 +245,7 @@ class LlavaNextMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_input_mapper("video")
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"video", get_max_llava_next_video_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
|
||||
@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
|
||||
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
|
||||
@ -3,47 +3,36 @@ from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import (CLIPVisionConfig, LlavaOnevisionConfig,
|
||||
SiglipVisionConfig)
|
||||
from transformers import (BatchFeature, LlavaOnevisionConfig,
|
||||
LlavaOnevisionProcessor)
|
||||
from transformers.models.llava_onevision.modeling_llava_onevision import (
|
||||
get_anyres_image_grid_shape, unpad_image)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||
from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
|
||||
VideoProcessorItems)
|
||||
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
|
||||
dummy_video_for_clip, get_clip_image_feature_size,
|
||||
get_clip_patch_grid_length, input_processor_for_clip)
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .llava import init_vision_tower_for_llava
|
||||
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
|
||||
dummy_video_for_siglip, get_siglip_image_feature_size,
|
||||
get_siglip_patch_grid_length, input_processor_for_siglip)
|
||||
from .llava_next import LlavaNextMultiModalProcessor
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
|
||||
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 16
|
||||
|
||||
|
||||
class LlavaOnevisionVideoPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values_videos"]
|
||||
@ -92,286 +81,251 @@ LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
|
||||
LlavaOnevisionVideoPixelInputs]
|
||||
|
||||
|
||||
def _get_llava_onevision_image_unppaded_feature_size(height, width, patches,
|
||||
scale_height,
|
||||
scale_width):
|
||||
current_height = patches * scale_height
|
||||
current_width = patches * scale_width
|
||||
class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
|
||||
|
||||
original_aspect_ratio = width / height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
new_height = int(height * (current_width / width))
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height -= padding * 2
|
||||
else:
|
||||
new_width = int(width * (current_height / height))
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width -= padding * 2
|
||||
def _get_hf_config(self) -> LlavaOnevisionConfig:
|
||||
return self.ctx.get_hf_config(LlavaOnevisionConfig)
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
def _get_hf_processor(self) -> LlavaOnevisionProcessor:
|
||||
return self.ctx.get_hf_processor(LlavaOnevisionProcessor)
|
||||
|
||||
ratio = math.sqrt(current_height * current_width / (9 * patches**2))
|
||||
if ratio > 1.1:
|
||||
unpadded_features = int(current_height // ratio) * int(
|
||||
current_width // ratio)
|
||||
newline_features = int(current_height // ratio)
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None, "video": None}
|
||||
|
||||
return (unpadded_features, newline_features)
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
max_image_tokens = self._get_max_image_tokens()
|
||||
|
||||
num_frames = self._get_dummy_num_frames(seq_len)
|
||||
max_video_tokens = self._get_max_video_tokens(num_frames)
|
||||
|
||||
def get_llava_onevision_image_feature_size(
|
||||
hf_config: LlavaOnevisionConfig,
|
||||
*,
|
||||
input_height: int,
|
||||
input_width: int,
|
||||
) -> int:
|
||||
vision_config = hf_config.vision_config
|
||||
return {
|
||||
"image": max_image_tokens,
|
||||
"video": max_video_tokens,
|
||||
}
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
num_patches = get_clip_patch_grid_length(
|
||||
image_size=vision_config.image_size,
|
||||
patch_size=vision_config.patch_size,
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_sizes=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
pixel_values_videos=MultiModalFieldConfig.batched("video"),
|
||||
)
|
||||
base_feature_size = get_clip_image_feature_size(vision_config)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
num_patches = get_siglip_patch_grid_length(
|
||||
image_size=vision_config.image_size,
|
||||
patch_size=vision_config.patch_size,
|
||||
|
||||
def _get_num_unpadded_features(
|
||||
self,
|
||||
*,
|
||||
original_height: int,
|
||||
original_width: int,
|
||||
npatches: int,
|
||||
num_patch_height: int,
|
||||
num_patch_width: int,
|
||||
) -> tuple[int, int]:
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
new_height = int(original_height *
|
||||
(current_width / original_width))
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height -= padding * 2
|
||||
else:
|
||||
new_width = int(original_width *
|
||||
(current_height / original_height))
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width -= padding * 2
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
|
||||
ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
|
||||
if ratio > 1.1:
|
||||
unpadded_features = int(current_height // ratio) * int(
|
||||
current_width // ratio)
|
||||
newline_features = int(current_height // ratio)
|
||||
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
    def _get_num_frame_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self._get_hf_config()
        spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)

        patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
        pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

        return pooled_grid_length * pooled_grid_length

    def _get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        num_frame_tokens = self._get_num_frame_tokens(
            image_width=image_width,
            image_height=image_height,
        )
|
||||
base_feature_size = get_siglip_image_feature_size(vision_config)
|
||||
else:
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
strategy = hf_config.vision_feature_select_strategy
|
||||
if strategy == "default":
|
||||
base_feature_size -= 1
|
||||
elif strategy == "full":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||
return num_frame_tokens * num_frames + 1 # Newline token
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_size=(input_height, input_width),
|
||||
grid_pinpoints=hf_config.image_grid_pinpoints,
|
||||
patch_size=vision_config.image_size,
|
||||
)
|
||||
def _get_max_video_tokens(self, num_frames: int) -> int:
|
||||
return self._get_num_video_tokens(image_width=999999,
|
||||
image_height=999999,
|
||||
num_frames=num_frames)
|
||||
|
||||
(
|
||||
unpadded_feature_size,
|
||||
newline_feature_size,
|
||||
) = _get_llava_onevision_image_unppaded_feature_size(
|
||||
input_height, input_width, num_patches, num_patch_height,
|
||||
num_patch_width)
|
||||
def _get_max_video_frames(self, max_tokens: int) -> int:
|
||||
num_frames = 0
|
||||
|
||||
return unpadded_feature_size + newline_feature_size + base_feature_size
|
||||
while True:
|
||||
next_num_frames = num_frames + 1
|
||||
|
||||
if self._get_max_video_tokens(next_num_frames) > max_tokens:
|
||||
break
|
||||
|
||||
def get_max_llava_onevision_image_tokens(ctx: InputContext):
|
||||
return get_llava_onevision_image_feature_size(
|
||||
ctx.get_hf_config(LlavaOnevisionConfig),
|
||||
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
)
|
||||
num_frames = next_num_frames
|
||||
|
||||
return num_frames
|
||||
|
||||
def get_llava_onevision_video_frame_feature_size(
|
||||
hf_config: LlavaOnevisionConfig) -> int:
|
||||
# Support both CLIPVisionConfig and SiglipVisionConfig
|
||||
image_size = hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
spatial_pool_stride = hf_config.spatial_pool_stride if hasattr(
|
||||
hf_config, "spatial_pool_stride") else 2
|
||||
def _get_dummy_num_frames(self, seq_len: int) -> int:
|
||||
mm_config = self.ctx.get_mm_config()
|
||||
max_images = mm_config.limit_per_prompt.get("image", 1)
|
||||
max_videos = mm_config.limit_per_prompt.get("video", 1)
|
||||
|
||||
height = width = image_size // patch_size
|
||||
return math.ceil(height / spatial_pool_stride) * math.ceil(
|
||||
width / spatial_pool_stride)
|
||||
max_image_tokens = self._get_max_image_tokens() * max_images
|
||||
max_total_frames = self._get_max_video_frames(seq_len -
|
||||
max_image_tokens)
|
||||
|
||||
return max(max_total_frames // max(max_videos, 1), 1)
|
||||
|
||||
def get_llava_onevision_video_tokens(ctx: InputContext,
|
||||
num_frames: int) -> int:
|
||||
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
|
||||
def _get_video_token(self) -> str:
|
||||
return self._get_hf_processor().video_token
|
||||
|
||||
# TODO: support configuring (not supported by HF right now)
|
||||
num_token_image_newline = 1
|
||||
tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config)
|
||||
video_feature_size = num_frames * tokens_per_frame + num_token_image_newline
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
mm_data = dict(mm_data)
|
||||
videos = mm_data.pop("videos", [])
|
||||
assert isinstance(videos, list)
|
||||
|
||||
return video_feature_size
|
||||
if not videos:
|
||||
return super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
video_token = self._get_video_token()
|
||||
|
||||
def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int:
|
||||
return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO)
|
||||
|
||||
|
||||
def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
num_videos = mm_counts["video"]
|
||||
|
||||
# TODO: support configuring the number of frames
|
||||
num_frames = _MAX_FRAMES_PER_VIDEO
|
||||
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_videos,
|
||||
image_token_id=hf_config.video_token_index,
|
||||
image_feature_size_override=video_feature_size,
|
||||
mm_key="video")
|
||||
|
||||
mm_data = dummy_video_for_clip(vision_config,
|
||||
num_frames=num_frames,
|
||||
num_videos=num_videos)
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_siglip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_videos,
|
||||
image_token_id=hf_config.video_token_index,
|
||||
image_feature_size_override=video_feature_size,
|
||||
mm_key="video")
|
||||
|
||||
mm_data = dummy_video_for_siglip(vision_config,
|
||||
num_frames=num_frames,
|
||||
num_videos=num_videos)
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_when_multimodal_input_image(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
width, height = image_data.size
|
||||
|
||||
image_feature_size = get_llava_onevision_image_feature_size(
|
||||
hf_config,
|
||||
input_height=height,
|
||||
input_width=width,
|
||||
# LLaVA-OneVision processor doesn't support multiple videos
|
||||
# with different sizes when converting back to tensors
|
||||
text_image_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
elif is_list_of(image_data, Image.Image):
|
||||
image_feature_size = [
|
||||
get_llava_onevision_image_feature_size(hf_config,
|
||||
input_height=img.height,
|
||||
input_width=img.width)
|
||||
for img in image_data
|
||||
|
||||
pixel_values_videos = []
|
||||
for video in videos:
|
||||
item_processor_data = dict(prompt=video_token, videos=video)
|
||||
|
||||
item_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=item_processor_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
pixel_values_videos.append(
|
||||
item_outputs.pop("pixel_values_videos")[0])
|
||||
|
||||
combined_outputs = dict(
|
||||
**text_image_outputs,
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
)
|
||||
return BatchFeature(combined_outputs)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
image_repls = super()._get_prompt_replacements(
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
out_mm_kwargs=out_mm_kwargs,
|
||||
)
|
||||
|
||||
hf_config = self._get_hf_config()
|
||||
video_token_id = hf_config.video_token_index
|
||||
|
||||
def get_video_replacement(item_idx: int):
|
||||
videos = mm_items.get_items(
|
||||
"video", (VideoEmbeddingItems, VideoProcessorItems))
|
||||
|
||||
if isinstance(videos, VideoEmbeddingItems):
|
||||
num_video_tokens = videos.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = videos.get_frame_size(item_idx)
|
||||
num_video_tokens = self._get_num_video_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
num_frames=videos.get_num_frames(item_idx),
|
||||
)
|
||||
|
||||
return [video_token_id] * num_video_tokens
|
||||
|
||||
return image_repls + [
|
||||
PromptReplacement(
|
||||
modality="video",
|
||||
target=[video_token_id],
|
||||
replacement=get_video_replacement,
|
||||
),
|
||||
]
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
num_images, image_feature_size, hidden_size = image_data.shape
|
||||
elif is_list_of(image_data, torch.Tensor):
|
||||
image_feature_size = [item.shape[1] for item in image_data]
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
vision_config = hf_config.vision_config
|
||||
def _get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return input_processor_for_clip(
|
||||
model_config,
|
||||
vision_config,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
image_token = self._get_image_token()
|
||||
video_token = self._get_video_token()
|
||||
target_width, target_height = self._get_dummy_image_size()
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images),
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_frames=self._get_dummy_num_frames(seq_len),
|
||||
num_videos=num_videos,
|
||||
)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images + video_token * num_videos,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
return input_processor_for_siglip(
|
||||
model_config,
|
||||
vision_config,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_when_multimodal_input_video(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "video" not in multi_modal_data:
|
||||
return inputs
|
||||
video_data = multi_modal_data["video"]
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
|
||||
|
||||
if isinstance(video_data, np.ndarray):
|
||||
# Supports both CLIP and Siglip
|
||||
num_frames = video_data.shape[0]
|
||||
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=hf_config.video_token_index,
|
||||
repeat_count=video_feature_size,
|
||||
)
|
||||
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
multi_modal_placeholders={"video": ranges})
|
||||
|
||||
elif is_list_of(video_data, np.ndarray):
|
||||
video_feature_size = []
|
||||
for video in video_data:
|
||||
num_frames = video.shape[0]
|
||||
video_feature_size.append(
|
||||
get_llava_onevision_video_tokens(ctx, num_frames))
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=hf_config.video_token_index,
|
||||
repeat_count=video_feature_size,
|
||||
)
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
multi_modal_placeholders={"video": ranges})
|
||||
else:
|
||||
raise TypeError(f"Invalid video type: {type(video_data)}")
|
||||
|
||||
msg = f"Unsupported video type: {type(video_data)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_llava_onevision(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or ("video" not in multi_modal_data
|
||||
and "image" not in multi_modal_data):
|
||||
return inputs
|
||||
if "image" in multi_modal_data:
|
||||
return input_processor_when_multimodal_input_image(ctx, inputs)
|
||||
if "video" in multi_modal_data:
|
||||
return input_processor_when_multimodal_input_video(ctx, inputs)
|
||||
|
||||
msg = "Unsupported multi data type"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
class LlavaOnevisionMultiModalProjector(nn.Module):
|
||||
@ -394,14 +348,7 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_input_mapper("video")
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"image", get_max_llava_onevision_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"video", get_max_llava_onevision_video_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
|
||||
@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
|
||||
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
|
||||
@ -323,7 +323,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
height=image_height,
|
||||
)
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
max_image_tokens = self._get_num_image_tokens(
|
||||
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
@ -415,12 +415,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
def _apply_prompt_replacements(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
prompt_repls: Sequence[_BoundPromptReplacement],
|
||||
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
|
||||
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
|
||||
token_ids, text, placeholders = super()._apply_prompt_replacements(
|
||||
token_ids=token_ids,
|
||||
prompt_repls=prompt_repls,
|
||||
mm_prompt_repls=mm_prompt_repls,
|
||||
mm_item_counts=mm_item_counts,
|
||||
)
|
||||
|
||||
@ -428,15 +428,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
if text.startswith("<s> <|image|>"):
|
||||
text = text.replace("<s> <|image|>", "<s><|image|>", 1)
|
||||
token_ids = [token_ids[0], *token_ids[2:]]
|
||||
placeholders = [
|
||||
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
|
||||
for p in placeholders
|
||||
]
|
||||
placeholders = {
|
||||
modality: [
|
||||
_PlaceholderInfo(
|
||||
modality=p.modality,
|
||||
item_idx=p.item_idx,
|
||||
start_idx=p.start_idx - 1,
|
||||
replacement=p.replacement,
|
||||
) for p in ps
|
||||
]
|
||||
for modality, ps in placeholders.items()
|
||||
}
|
||||
|
||||
return token_ids, text, placeholders
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
def _get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
@ -780,15 +780,18 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
|
||||
def get_max_image_tokens(self) -> int:
|
||||
return get_max_pixtral_hf_image_tokens(self.vision_config)
|
||||
|
||||
def get_num_patches(self) -> int:
|
||||
def get_image_size(self) -> int:
|
||||
return self.vision_config.image_size
|
||||
|
||||
def get_patch_size(self) -> int:
|
||||
return self.vision_config.patch_size
|
||||
|
||||
def get_patch_grid_length(self) -> int:
|
||||
return get_pixtral_hf_patch_grid_length(
|
||||
image_size=self.vision_config.image_size,
|
||||
patch_size=self.vision_config.patch_size,
|
||||
)
|
||||
|
||||
def get_image_size(self) -> int:
|
||||
return self.vision_config.image_size
|
||||
|
||||
|
||||
class PixtralHFMLP(nn.Module):
|
||||
|
||||
|
||||
@ -84,7 +84,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
|
||||
max_source_positions = hf_config.audio_config.max_source_positions
|
||||
max_output_lengths = (max_source_positions - 2) // 2 + 1
|
||||
@ -184,15 +184,16 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
]
|
||||
|
||||
def _always_apply_prompt_replacements(self) -> bool:
|
||||
# HF never applies prompt replacements, so we have to do it ourselves
|
||||
# _find_placeholders may incorrectly think that HF has already performed
|
||||
# processing for multi-audio input when the input audios are short
|
||||
# (the corresponding placeholders may take up fewer tokens than
|
||||
# the number of audio items)
|
||||
# HF never applies prompt replacements, so we have to do it ourselves.
|
||||
# NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
|
||||
# has already performed processing for multi-audio input when the input
|
||||
# audios are short (the corresponding placeholders may take up fewer
|
||||
# tokens than the number of audio items)
|
||||
return True
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
def _get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
|
||||
@ -56,7 +56,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (ImageItem, ModalityData,
|
||||
MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors, VideoItem)
|
||||
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
|
||||
from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
@ -641,58 +642,6 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
# === Vision input helpers === #
|
||||
|
||||
|
||||
def _get_vision_info(
|
||||
vision_config: Qwen2VLVisionConfig,
|
||||
height: int,
|
||||
width: int,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
*,
|
||||
do_resize: bool = True,
|
||||
modality: str = "image",
|
||||
mm_count: int = 1,
|
||||
):
|
||||
"""Get information (resized height / width and number of vision tokens)
|
||||
of input image / video frame."""
|
||||
patch_size = vision_config.patch_size
|
||||
merge_size = vision_config.spatial_merge_size
|
||||
temporal_patch_size = vision_config.temporal_patch_size
|
||||
|
||||
if do_resize:
|
||||
resized_height, resized_width = smart_resize(
|
||||
height=height,
|
||||
width=width,
|
||||
factor=patch_size * merge_size,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
else:
|
||||
resized_height, resized_width = height, width
|
||||
|
||||
if modality == "image":
|
||||
grid_t = mm_count
|
||||
elif modality == "video":
|
||||
grid_t = max(mm_count // temporal_patch_size, 1)
|
||||
else:
|
||||
raise ValueError(f"Modality {modality} is not supported")
|
||||
|
||||
grid_h = resized_height // patch_size
|
||||
grid_w = resized_width // patch_size
|
||||
vision_tokens = grid_t * grid_h * grid_w
|
||||
llm_num_vision_tokens = vision_tokens // (merge_size**2)
|
||||
|
||||
return resized_height, resized_width, llm_num_vision_tokens
|
||||
|
||||
|
||||
def _get_image_processor(hf_processor: Qwen2VLProcessor):
|
||||
image_processor = hf_processor.image_processor # type: ignore
|
||||
assert isinstance(image_processor, Qwen2VLImageProcessor)
|
||||
return image_processor
|
||||
|
||||
|
||||
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
|
||||
dict[str, torch.Tensor]]):
|
||||
|
||||
@ -764,32 +713,111 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None, "video": None}
|
||||
|
||||
def _get_max_mm_tokens(self, modality: str) -> int:
|
||||
def _get_vision_info(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
num_frames: int = 1,
|
||||
do_resize: bool = True,
|
||||
) -> tuple[ImageSize, int]:
|
||||
hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
patch_size = vision_config.patch_size
|
||||
merge_size = vision_config.spatial_merge_size
|
||||
temporal_patch_size = vision_config.temporal_patch_size
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_processor = _get_image_processor(hf_processor)
|
||||
image_processor = self._get_image_processor(hf_processor)
|
||||
|
||||
_, _, max_llm_image_tokens = _get_vision_info(
|
||||
vision_config,
|
||||
height=9999999,
|
||||
width=9999999,
|
||||
min_pixels=image_processor.min_pixels,
|
||||
max_pixels=image_processor.max_pixels,
|
||||
modality=modality,
|
||||
        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        grid_t = max(num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens
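For orientation, a worked example of the grid arithmetic above. The config values (patch_size=14, merge_size=2, temporal_patch_size=2) match typical Qwen2-VL vision settings, and the 1288x728 post-resize frame size is an assumption chosen to be a multiple of patch_size * merge_size; neither comes from this diff.

# Hypothetical numbers only.
patch_size, merge_size, temporal_patch_size = 14, 2, 2
width, height, num_frames = 1288, 728, 1

grid_t = max(num_frames // temporal_patch_size, 1)  # 1
grid_h = height // patch_size                       # 52
grid_w = width // patch_size                        # 92

num_patches = grid_t * grid_h * grid_w              # 4784 vision patches
num_vision_tokens = num_patches // (merge_size**2)  # 1196 LLM-side tokens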
|
||||
|
||||
def _get_dummy_image_size(self) -> ImageSize:
|
||||
max_image_size, _ = self._get_vision_info(
|
||||
image_width=9999999,
|
||||
image_height=9999999,
|
||||
)
|
||||
return max_llm_image_tokens
|
||||
return max_image_size
|
||||
|
||||
def _get_max_image_tokens(self) -> int:
|
||||
_, max_image_tokens = self._get_vision_info(
|
||||
image_width=9999999,
|
||||
image_height=9999999,
|
||||
)
|
||||
return max_image_tokens
|
||||
|
||||
def _get_max_video_tokens(self, num_frames: int) -> int:
|
||||
_, max_video_tokens = self._get_vision_info(
|
||||
image_width=9999999,
|
||||
image_height=9999999,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
return max_video_tokens
|
||||
|
||||
    def _get_max_video_frames(self, max_tokens: int) -> int:
        num_frames = 0

        while True:
            next_num_frames = num_frames + 1

            if self._get_max_video_tokens(next_num_frames) > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def _get_dummy_num_frames(self, seq_len: int) -> int:
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)
        max_videos = mm_config.limit_per_prompt.get("video", 1)

        max_image_tokens = self._get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)

        return max(max_total_frames // max(max_videos, 1), 1)
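The loop above is effectively a linear search for the largest frame count whose token estimate still fits the sequence budget that remains after reserving room for the images. A minimal standalone sketch with made-up numbers (the helper name and the ~200 tokens-per-frame figure are assumptions, not values from this diff):

def max_frames_that_fit(seq_len: int, image_tokens: int,
                        tokens_per_frame: int) -> int:
    # Grow the frame count until the per-video token estimate no longer fits
    # in what is left of the sequence after the image budget is subtracted.
    budget = seq_len - image_tokens
    num_frames = 0
    while (num_frames + 1) * tokens_per_frame <= budget:
        num_frames += 1
    return num_frames

# e.g. a 32768-token context with 7300 tokens reserved for one image and an
# assumed ~200 tokens per frame leaves room for 127 dummy video frames.
print(max_frames_that_fit(32768, 7300, 200))  # 127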
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
max_image_tokens = self._get_max_image_tokens()
|
||||
|
||||
num_frames = self._get_dummy_num_frames(seq_len)
|
||||
max_video_tokens = self._get_max_video_tokens(num_frames)
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
return {
|
||||
"image": self._get_max_mm_tokens("image"),
|
||||
"video": self._get_max_mm_tokens("video"),
|
||||
"image": max_image_tokens,
|
||||
"video": max_video_tokens,
|
||||
}
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
return Qwen2MultiModalDataParser()
|
||||
|
||||
def _get_image_processor(self, hf_processor: Qwen2VLProcessor):
|
||||
image_processor = hf_processor.image_processor # type: ignore
|
||||
assert isinstance(image_processor, Qwen2VLImageProcessor)
|
||||
return image_processor
|
||||
|
||||
def _get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
@ -797,7 +825,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
max_pixels: Optional[int] = None,
|
||||
) -> Qwen2VLProcessor:
|
||||
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
|
||||
image_processor = _get_image_processor(hf_processor)
|
||||
image_processor = self._get_image_processor(hf_processor)
|
||||
|
||||
if min_pixels:
|
||||
image_processor.min_pixels = min_pixels
|
||||
@ -818,7 +846,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_processor = _get_image_processor(hf_processor)
|
||||
image_processor = self._get_image_processor(hf_processor)
|
||||
|
||||
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
|
||||
# image_token and video_token registered
|
||||
@ -873,32 +901,35 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
||||
)
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
def _get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_processor = _get_image_processor(hf_processor)
|
||||
|
||||
image_token: str = hf_processor.image_token
|
||||
resized_height, resized_width = smart_resize(
|
||||
height=9999999,
|
||||
width=9999999,
|
||||
factor=image_processor.patch_size * image_processor.merge_size,
|
||||
min_pixels=image_processor.min_pixels,
|
||||
max_pixels=image_processor.max_pixels,
|
||||
)
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_token: str = hf_processor.image_token
|
||||
video_token: str = hf_processor.video_token
|
||||
target_width, target_height = self._get_dummy_image_size()
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=resized_width,
|
||||
height=resized_height,
|
||||
num_images=num_images)
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images),
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_frames=self._get_dummy_num_frames(seq_len),
|
||||
num_videos=num_videos,
|
||||
)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
prompt_text=image_token * num_images + video_token * num_videos,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
@ -171,15 +171,18 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
def get_max_image_tokens(self) -> int:
|
||||
return get_max_siglip_image_tokens(self.vision_config)
|
||||
|
||||
def get_num_patches(self) -> int:
|
||||
def get_image_size(self) -> int:
|
||||
return self.vision_config.image_size
|
||||
|
||||
def get_patch_size(self) -> int:
|
||||
return self.vision_config.patch_size
|
||||
|
||||
def get_patch_grid_length(self) -> int:
|
||||
return get_siglip_patch_grid_length(
|
||||
image_size=self.vision_config.image_size,
|
||||
patch_size=self.vision_config.patch_size,
|
||||
)
|
||||
|
||||
def get_image_size(self) -> int:
|
||||
return self.vision_config.image_size
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
|
||||
class SiglipVisionEmbeddings(nn.Module):
|
||||
|
||||
@ -6,7 +6,6 @@ from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
@ -31,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
@ -62,7 +60,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
|
||||
_AUDIO_TOKENS_PER_SECOND)
|
||||
@ -103,6 +101,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
mm_data = dict(mm_data)
|
||||
audios = mm_data.pop("audios", [])
|
||||
assert isinstance(audios, list)
|
||||
|
||||
if not audios:
|
||||
return super()._call_hf_processor(
|
||||
@ -117,9 +116,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
|
||||
# Already resampled by _get_hf_mm_data
|
||||
assert is_list_of(audios, np.ndarray)
|
||||
|
||||
# Ultravox processor doesn't support multiple inputs,
|
||||
# therefore we need to input text and audio one by one
|
||||
audio_features, audio_token_len = [], []
|
||||
@ -177,8 +173,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
)
|
||||
]
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
def _get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
InputProcessingContext,
|
||||
ProcessingCache)
|
||||
|
||||
_C = TypeVar("_C", bound=PretrainedConfig)
|
||||
|
||||
|
||||
@ -27,11 +31,15 @@ class VisionEncoderInfo(ABC, Generic[_C]):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_num_patches(self) -> int:
|
||||
def get_image_size(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_image_size(self) -> int:
|
||||
def get_patch_size(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_patch_grid_length(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -50,3 +58,26 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
class VisionLanguageConfig(Protocol):
|
||||
vision_config: Final[PretrainedConfig]
|
||||
|
||||
|
||||
class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def __init__(self,
|
||||
ctx: InputProcessingContext,
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True) -> None:
|
||||
super().__init__(ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks)
|
||||
|
||||
vision_config = self._get_hf_config().vision_config
|
||||
self._vision_encoder_info = vision_encoder_info(vision_config)
|
||||
|
||||
@abstractmethod
|
||||
def _get_hf_config(self) -> VisionLanguageConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -146,6 +146,20 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
|
||||
def __init__(self, data: Sequence[HfVideoItem]) -> None:
|
||||
super().__init__(data, "video")
|
||||
|
||||
def get_num_frames(self, item_idx: int) -> int:
|
||||
return len(self.get(item_idx))
|
||||
|
||||
def get_frame_size(self, item_idx: int) -> ImageSize:
|
||||
image = self.get(item_idx)[0] # Assume that the video isn't empty
|
||||
|
||||
if isinstance(image, Image):
|
||||
return ImageSize(*image.size)
|
||||
if isinstance(image, (np.ndarray, torch.Tensor)):
|
||||
_, h, w = image.shape
|
||||
return ImageSize(w, h)
|
||||
|
||||
assert_never(image)
|
||||
|
||||
|
||||
class VideoEmbeddingItems(EmbeddingItems):
|
||||
|
||||
|
||||
@ -16,7 +16,8 @@ from transformers import BatchFeature, ProcessorMixin
|
||||
|
||||
from vllm.inputs import DummyData, InputProcessingContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
|
||||
encode_tokens)
|
||||
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
|
||||
|
||||
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
@ -69,19 +70,6 @@ def _cached_encode(
|
||||
add_special_tokens=add_special_tokens)
|
||||
|
||||
|
||||
def _decode(
|
||||
tokenizer: AnyTokenizer,
|
||||
token_ids: list[int],
|
||||
*,
|
||||
skip_special_tokens: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Backend-agnostic equivalent of HF's
|
||||
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
|
||||
"""
|
||||
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def _cached_decode(
|
||||
tokenizer: AnyTokenizer,
|
||||
@ -89,9 +77,9 @@ def _cached_decode(
|
||||
*,
|
||||
skip_special_tokens: bool = False,
|
||||
) -> str:
|
||||
return _decode(tokenizer,
|
||||
list(token_ids),
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
return decode_tokens(tokenizer,
|
||||
list(token_ids),
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
|
||||
|
||||
class _HasModalityAttr(Protocol):
|
||||
@ -269,8 +257,10 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
|
||||
return self.match.end()
|
||||
|
||||
|
||||
class _PlaceholderInfo(NamedTuple):
@dataclass
class _PlaceholderInfo:
    modality: str
    item_idx: int
    start_idx: int
    replacement: list[int]
|
||||
@ -311,12 +301,14 @@ def find_text_matches(
|
||||
|
||||
def _resolve_matches(
|
||||
prompt: _PromptSeq,
|
||||
matches: Sequence[_PromptReplacementMatch],
|
||||
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
|
||||
) -> list[_PromptReplacementMatch]:
|
||||
"""
|
||||
Resolve :code:`matches` to ensure that there are no overlapping matches,
|
||||
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
|
||||
and sort them such that earlier matches take priority over later ones.
|
||||
"""
|
||||
matches = [m for matches in mm_matches.values() for m in matches]
|
||||
|
||||
seen_matches: list[Optional[_PromptReplacementMatch]] = [None
|
||||
] * len(prompt)
|
||||
|
||||
@ -334,14 +326,15 @@ def _resolve_matches(
|
||||
|
||||
def _replace_matches(
|
||||
prompt: _S,
|
||||
matches: Sequence[_PromptReplacementMatch],
|
||||
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> list[_S]:
|
||||
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
|
||||
out_seqs = list[_S]()
|
||||
prev_end_idx = 0
|
||||
next_idx_by_modality = defaultdict[str, int](lambda: 0)
|
||||
|
||||
for match in _resolve_matches(prompt, matches):
|
||||
for match in _resolve_matches(prompt, mm_matches):
|
||||
modality = match.modality
|
||||
|
||||
item_idx = next_idx_by_modality[modality]
|
||||
@ -371,28 +364,28 @@ def _replace_matches(
|
||||
|
||||
def replace_token_matches(
|
||||
prompt: list[int],
|
||||
matches: Sequence[_PromptReplacementTokenMatch],
|
||||
mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> list[int]:
|
||||
"""Apply :code:`prompt_repls` to :code:`prompt`."""
|
||||
if not matches:
|
||||
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
|
||||
if not mm_matches:
|
||||
return prompt
|
||||
|
||||
token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)
|
||||
token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts)
|
||||
|
||||
return flatten_2d_lists(token_id_seqs)
|
||||
|
||||
|
||||
def replace_text_matches(
|
||||
prompt: str,
|
||||
matches: Sequence[_PromptReplacementTextMatch],
|
||||
mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> str:
|
||||
"""Apply :code:`prompt_repls` to :code:`prompt`."""
|
||||
if not matches:
|
||||
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
|
||||
if not mm_matches:
|
||||
return prompt
|
||||
|
||||
texts = _replace_matches(prompt, matches, mm_item_counts)
|
||||
texts = _replace_matches(prompt, mm_matches, mm_item_counts)
|
||||
|
||||
return "".join(texts)
|
||||
|
||||
@ -407,14 +400,14 @@ def _iter_modality_placeholders(
|
||||
return
|
||||
|
||||
prompt_len = len(prompt)
|
||||
item_index = 0
|
||||
item_idx = 0
|
||||
|
||||
start_idx = 0
|
||||
while start_idx < prompt_len:
|
||||
found = False
|
||||
|
||||
for repl_info in modality_repls:
|
||||
replacement = repl_info.get_replacement(item_index)
|
||||
replacement = repl_info.get_replacement(item_idx)
|
||||
repl_tokens = replacement.token_ids
|
||||
repl_len = len(repl_tokens)
|
||||
end_idx = start_idx + repl_len
|
||||
@ -425,12 +418,13 @@ def _iter_modality_placeholders(
|
||||
if prompt[start_idx:end_idx] == repl_tokens:
|
||||
yield _PlaceholderInfo(
|
||||
modality=modality,
|
||||
item_idx=item_idx,
|
||||
start_idx=start_idx,
|
||||
replacement=repl_tokens,
|
||||
)
|
||||
|
||||
item_index += 1
|
||||
if item_index >= modal_item_count:
|
||||
item_idx += 1
|
||||
if item_idx >= modal_item_count:
|
||||
return
|
||||
|
||||
# Exclude overlapping matches
|
||||
@ -442,28 +436,36 @@ def _iter_modality_placeholders(
|
||||
start_idx += 1
|
||||
|
||||
|
||||
def iter_placeholders(
|
||||
prompt_repls: Sequence[_BoundPromptReplacement],
|
||||
def _iter_placeholders(
|
||||
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
|
||||
prompt: list[int],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> Iterable[_PlaceholderInfo]:
|
||||
"""
|
||||
Yield each set of placeholder tokens found in :code:`prompt`.
|
||||
For each modality, yield each set of placeholder tokens found in
|
||||
:code:`prompt`.
|
||||
|
||||
Note that empty matches are ignored.
|
||||
"""
|
||||
repls_by_modality = dict(full_groupby_modality(prompt_repls))
|
||||
|
||||
for modality, modal_item_count in mm_item_counts.items():
|
||||
if modality in repls_by_modality:
|
||||
if modality in mm_prompt_repls:
|
||||
yield from _iter_modality_placeholders(
|
||||
prompt,
|
||||
modality,
|
||||
repls_by_modality[modality],
|
||||
mm_prompt_repls[modality],
|
||||
modal_item_count,
|
||||
)
|
||||
|
||||
|
||||
def find_mm_placeholders(
|
||||
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
|
||||
prompt: list[int],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> Mapping[str, list[_PlaceholderInfo]]:
|
||||
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
|
||||
return dict(full_groupby_modality(it))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessorInputs:
|
||||
"""Keyword arguments to :meth:`BaseMultiModalProcessor`."""
|
||||
@ -620,7 +622,7 @@ class BaseMultiModalProcessor(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
"""
|
||||
Get the maximum possible number of tokens per data item
|
||||
for each modality.
|
||||
@ -703,14 +705,14 @@ class BaseMultiModalProcessor(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _find_placeholders(
|
||||
def _find_mm_placeholders(
|
||||
self,
|
||||
all_prompt_repls: Sequence[_BoundPromptReplacement],
|
||||
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
|
||||
new_token_ids: list[int],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> list[_PlaceholderInfo]:
|
||||
return list(
|
||||
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
|
||||
) -> Mapping[str, list[_PlaceholderInfo]]:
|
||||
return find_mm_placeholders(mm_prompt_repls, new_token_ids,
|
||||
mm_item_counts)
|
||||
|
||||
def _get_hf_mm_data(
|
||||
self,
|
||||
@ -797,7 +799,10 @@ class BaseMultiModalProcessor(ABC):
|
||||
|
||||
# Some HF processors (e.g. Qwen2-VL) expect corresponding
|
||||
# multi-modal tokens to be in the prompt text
|
||||
dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts)
|
||||
dummy_inputs = self._get_dummy_processor_inputs(
|
||||
self.ctx.model_config.max_model_len,
|
||||
mm_missing_counts,
|
||||
)
|
||||
|
||||
_, mm_missing_kwargs = self._apply_hf_processor(
|
||||
prompt_text=dummy_inputs.prompt_text,
|
||||
@ -889,50 +894,44 @@ class BaseMultiModalProcessor(ABC):
|
||||
|
||||
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
|
||||
|
||||
if self.enable_sanity_checks:
|
||||
mm_item_counts = mm_data_items.get_all_counts()
|
||||
|
||||
for modality, item_count in mm_item_counts.items():
|
||||
for item_idx in range(item_count):
|
||||
try:
|
||||
mm_kwargs.get_item(modality, item_idx)
|
||||
except Exception as e:
|
||||
# Make it easy to set a breakpoint in the debugger
|
||||
raise e
|
||||
|
||||
return prompt_ids, mm_kwargs
|
||||
|
||||
def _bind_prompt_replacements(
|
||||
def _bind_and_group_repls(
|
||||
self,
|
||||
prompt_repls: list[PromptReplacement],
|
||||
) -> list[_BoundPromptReplacement]:
|
||||
) -> dict[str, list[_BoundPromptReplacement]]:
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]
|
||||
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
|
||||
return dict(full_groupby_modality(it))
|
||||
|
||||
def _always_apply_prompt_replacements(self) -> bool:
|
||||
"""
|
||||
A flag which can be overridden so that
|
||||
:meth:`_apply_prompt_replacements` is always called even if we
|
||||
detect that HF has performed processing via :meth:`_find_placeholders`.
|
||||
detect that HF has performed processing via
|
||||
:meth:`_find_placeholders_by_modality`.
|
||||
|
||||
This is useful in cases where :meth:`_find_placeholders` cannot be
|
||||
reliably used to detect whether HF has performed processing or not.
|
||||
This is useful in cases where :meth:`_find_placeholders_by_modality`
|
||||
cannot be reliably used to detect whether HF has performed processing.
|
||||
"""
|
||||
return False
|
||||
|
||||
def _apply_prompt_replacements(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
prompt_repls: Sequence[_BoundPromptReplacement],
|
||||
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
|
||||
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
token_matches = find_token_matches(token_ids, prompt_repls)
|
||||
mm_token_matches = {
|
||||
modality: find_token_matches(token_ids, prompt_repls)
|
||||
for modality, prompt_repls in mm_prompt_repls.items()
|
||||
}
|
||||
mm_match_counts = {
|
||||
modality: len(matches)
|
||||
for modality, matches in full_groupby_modality(token_matches)
|
||||
for modality, matches in mm_token_matches.items()
|
||||
}
|
||||
|
||||
# If the search text does not represent a special token,
|
||||
@ -951,32 +950,92 @@ class BaseMultiModalProcessor(ABC):
|
||||
): # yapf: disable
|
||||
token_ids = replace_token_matches(
|
||||
token_ids,
|
||||
token_matches,
|
||||
mm_token_matches,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
text = _decode(tokenizer, token_ids)
|
||||
matched_repls = [match.prompt_repl for match in token_matches]
|
||||
text = decode_tokens(tokenizer, token_ids)
|
||||
matched_repls = {
|
||||
modality: [match.prompt_repl for match in token_matches]
|
||||
for modality, token_matches in mm_token_matches.items()
|
||||
}
|
||||
else:
|
||||
text = _decode(tokenizer, token_ids)
|
||||
text = decode_tokens(tokenizer, token_ids)
|
||||
|
||||
text_matches = find_text_matches(text, prompt_repls)
|
||||
mm_text_matches = {
|
||||
modality: find_text_matches(text, prompt_repls)
|
||||
for modality, prompt_repls in mm_prompt_repls.items()
|
||||
}
|
||||
text = replace_text_matches(
|
||||
text,
|
||||
text_matches,
|
||||
mm_text_matches,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
token_ids = encode_tokens(tokenizer,
|
||||
text,
|
||||
add_special_tokens=False)
|
||||
matched_repls = [match.prompt_repl for match in text_matches]
|
||||
matched_repls = {
|
||||
modality: [match.prompt_repl for match in token_matches]
|
||||
for modality, token_matches in mm_text_matches.items()
|
||||
}
|
||||
|
||||
placeholders = self._find_placeholders(matched_repls, token_ids,
|
||||
mm_item_counts)
|
||||
placeholders = self._find_mm_placeholders(
|
||||
matched_repls,
|
||||
token_ids,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
return token_ids, text, placeholders
|
||||
|
||||
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but only found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[_PlaceholderInfo]],
        mm_item_counts: Mapping[str, int],
        *,
        allow_missing: bool = False,
    ) -> Mapping[str, int]:
        missing_repl_counts = dict[str, int]()

        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

            if len(placeholders) != item_count and not allow_missing:
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt replacements "
                    f"corresponding to {item_count} {modality} items, but only "
                    f"found {len(placeholders)} prompt replacements! Either "
                    "the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_replacements`).")

            missing_repl_counts[modality] = item_count - len(placeholders)

        return missing_repl_counts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt_text: str,
|
||||
@ -1009,56 +1068,69 @@ class BaseMultiModalProcessor(ABC):
|
||||
hf_processor_mm_kwargs,
|
||||
mm_kwargs,
|
||||
)
|
||||
prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls)
|
||||
mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)
|
||||
|
||||
mm_item_counts = mm_items.get_all_counts()
|
||||
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
|
||||
|
||||
hf_mm_placeholders = self._find_mm_placeholders(
|
||||
mm_prompt_repls,
|
||||
prompt_ids,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
if self._always_apply_prompt_replacements():
|
||||
mm_missing_repl_counts = mm_item_counts
|
||||
mm_missing_repls = dict(mm_prompt_repls)
|
||||
else:
|
||||
mm_missing_repl_counts = self._validate_mm_placeholders(
|
||||
hf_mm_placeholders,
|
||||
mm_item_counts,
|
||||
allow_missing=True,
|
||||
)
|
||||
|
||||
mm_missing_repls = dict[str, list[_BoundPromptReplacement]]()
|
||||
for modality, missing_repl_count in mm_missing_repl_counts.items():
|
||||
if missing_repl_count == 0:
|
||||
mm_missing_repls[modality] = []
|
||||
elif missing_repl_count == mm_item_counts.get(modality, 0):
|
||||
mm_missing_repls[modality] = mm_prompt_repls[modality]
|
||||
else:
|
||||
raise ValueError("Partial prompt replacement within "
|
||||
f"{modality=} is not supported")
|
||||
|
||||
# If HF processor already inserts placeholder tokens,
|
||||
# there is no need for us to insert them
|
||||
mm_item_counts = mm_items.get_all_counts()
|
||||
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
|
||||
mm_item_counts)
|
||||
|
||||
if all_placeholders and not self._always_apply_prompt_replacements():
|
||||
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
|
||||
tokenizer = self._get_tokenizer()
|
||||
prompt_text = _decode(tokenizer, prompt_ids)
|
||||
prompt_text = decode_tokens(tokenizer, prompt_ids)
|
||||
mm_placeholders = hf_mm_placeholders
|
||||
else:
|
||||
(
|
||||
prompt_ids,
|
||||
prompt_text,
|
||||
all_placeholders,
|
||||
missing_mm_placeholders,
|
||||
) = self._apply_prompt_replacements(
|
||||
prompt_ids,
|
||||
prompt_repls,
|
||||
mm_item_counts,
|
||||
mm_missing_repls,
|
||||
mm_missing_repl_counts,
|
||||
)
|
||||
|
||||
mm_placeholders = dict[str, list[PlaceholderRange]]()
|
||||
err_suffix = ("This suggests a problem with your implementation of "
|
||||
"the merged multi-modal processor for this model, "
|
||||
"particularly in the `_get_prompt_replacements` method.")
|
||||
mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}
|
||||
|
||||
for modality, placeholders in full_groupby_modality(all_placeholders):
|
||||
if modality not in mm_items:
|
||||
raise AssertionError(
|
||||
f"Expected no placeholders for {modality=}, "
|
||||
f"but found {placeholders=}. Input items: {mm_items}"
|
||||
f"\n{err_suffix}")
|
||||
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
|
||||
|
||||
if len(placeholders) != len(mm_items[modality]):
|
||||
raise AssertionError(
|
||||
f"Expected length of {placeholders=} for {modality=} "
|
||||
f"to equal that of input items: {mm_items[modality]}"
|
||||
f"\n{err_suffix}")
|
||||
|
||||
mm_placeholders[modality] = [
|
||||
item.to_range() for item in placeholders
|
||||
]
|
||||
mm_placeholder_ranges = {
|
||||
modality: [item.to_range() for item in placeholders]
|
||||
for modality, placeholders in mm_placeholders.items()
|
||||
}
|
||||
|
||||
return MultiModalInputsV2(
|
||||
type="multimodal",
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_ids,
|
||||
mm_kwargs=mm_kwargs,
|
||||
mm_placeholders=mm_placeholders,
|
||||
mm_placeholders=mm_placeholder_ranges,
|
||||
)
|
||||
|
||||
def _get_dummy_audios(
|
||||
@ -1092,8 +1164,9 @@ class BaseMultiModalProcessor(ABC):
|
||||
return [video] * num_videos
|
||||
|
||||
@abstractmethod
|
||||
def _get_dummy_mm_inputs(
|
||||
def _get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
"""
|
||||
@ -1121,12 +1194,25 @@ class BaseMultiModalProcessor(ABC):
|
||||
|
||||
return mm_limits
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalInputsV2:
|
||||
processor_inputs = self._get_dummy_processor_inputs(seq_len, mm_counts)
|
||||
|
||||
return self.apply(
|
||||
prompt_text=processor_inputs.prompt_text,
|
||||
mm_data=processor_inputs.mm_data,
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
)
|
||||
|
||||
def get_dummy_data(self, seq_len: int) -> DummyData:
|
||||
# Avoid circular import
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
mm_counts = self._get_and_validate_dummy_mm_counts()
|
||||
mm_max_tokens_per_item = self.get_mm_max_tokens_per_item()
|
||||
mm_max_tokens_per_item = self.get_mm_max_tokens_per_item(seq_len)
|
||||
if mm_counts.keys() != mm_max_tokens_per_item.keys():
|
||||
raise AssertionError(
|
||||
"The keys returned by `get_supported_mm_limits`"
|
||||
@ -1134,13 +1220,7 @@ class BaseMultiModalProcessor(ABC):
|
||||
"returned by `get_mm_max_tokens_per_item` "
|
||||
f"({set(mm_max_tokens_per_item.keys())})")
|
||||
|
||||
processor_inputs = self._get_dummy_mm_inputs(mm_counts)
|
||||
mm_inputs = self.apply(
|
||||
prompt_text=processor_inputs.prompt_text,
|
||||
mm_data=processor_inputs.mm_data,
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
)
|
||||
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
|
||||
@ -1171,6 +1251,12 @@ class BaseMultiModalProcessor(ABC):
|
||||
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
|
||||
total_len, total_placeholders_by_modality)
|
||||
|
||||
return DummyData(
|
||||
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
|
||||
multi_modal_data=None,
|
||||
multi_modal_placeholders=None,
|
||||
)
|
||||
|
||||
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
|
||||
|
||||
return DummyData(
|
||||
|
||||
@ -223,7 +223,8 @@ class MultiModalRegistry:
|
||||
if self.has_processor(model_config):
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
processor = self.create_processor(model_config, tokenizer)
|
||||
return processor.get_mm_max_tokens_per_item()
|
||||
seq_len = model_config.max_model_len
|
||||
return processor.get_mm_max_tokens_per_item(seq_len)
|
||||
|
||||
return {
|
||||
key: plugin.get_max_multimodal_tokens(model_config)
|
||||
|
||||
@ -21,6 +21,19 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
|
||||
MistralTokenizer]
|
||||
|
||||
|
||||
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
    """
    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)


def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
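A quick usage sketch for the helper above (the model id and prompt are placeholders chosen for illustration; the import path follows the `vllm.transformers_utils.tokenizer` module touched in this diff):

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import decode_tokens

tok = AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
ids = tok("<image>\nWhat is shown here?")["input_ids"]

# Special tokens are kept by default, mirroring
# tokenizer.decode(..., skip_special_tokens=False).
print(decode_tokens(tok, ids))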