[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-01-04 19:40:53 +08:00 committed by GitHub
parent 300acb8347
commit eed11ebee9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 1104 additions and 973 deletions

View File

@ -0,0 +1,58 @@
import pytest
from PIL import Image
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_next():
from vllm.model_executor.models.llava_next import (
LlavaNextMultiModalProcessor)
return LlavaNextMultiModalProcessor
# FIXME: image_size [(198, 176), (176, 198)]
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
processor_for_llava_next,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaNextMultiModalProcessor handles prompt replacement properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processor = processor_for_llava_next(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
first_placeholder = image_placeholders[0]
# NOTE: There is a BOS token
assert first_placeholder["offset"] == 1
assert first_placeholder["length"] == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs

View File

@ -0,0 +1,59 @@
import pytest
from PIL import Image
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_onevision():
from vllm.model_executor.models.llava_onevision import (
LlavaOnevisionMultiModalProcessor)
return LlavaOnevisionMultiModalProcessor
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183), (198, 176), (176, 198)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
processor_for_llava_onevision,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaOnevisionMultiModalProcessor handles prompt replacement
properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processor = processor_for_llava_onevision(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
first_placeholder = image_placeholders[0]
# NOTE: There is a BOS token
assert first_placeholder["offset"] == 0
assert first_placeholder["length"] == len(
processed_inputs["prompt_token_ids"]) // num_imgs

View File

@ -1,6 +1,4 @@
"""Tests for phi3v's multimodal preprocessing kwargs.""" """Tests for phi3v's multimodal preprocessing kwargs."""
from typing import Optional
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -10,8 +8,6 @@ from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from .....conftest import _ImageAssets from .....conftest import _ImageAssets
from ....utils import build_model_context from ....utils import build_model_context
models = ["microsoft/Phi-3.5-vision-instruct"]
# Wrap lazy imports to avoid initializing CUDA during test collection # Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture() @pytest.fixture()
@ -20,40 +16,40 @@ def processor_for_phi3v():
return Phi3VMultiModalProcessor return Phi3VMultiModalProcessor
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
# yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_crops,expected_toks_per_img", ("mm_processor_kwargs", "expected_toks_per_img"),
[ [
(4, 757), ({"num_crops": 4}, 757),
(16, 1921), ({"num_crops": 16}, 1921),
# the default num_crops of phi-3.5-vision is 4 # the default num_crops of phi-3.5-vision is 4
(None, 757), ({}, 757),
]) ])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets, def test_processor_override(
model: str, num_crops: Optional[int], processor_for_phi3v,
expected_toks_per_img: int, num_imgs: int): image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
expected_toks_per_img: int,
num_imgs: int,
):
"""Ensure input_processor_for_phi3v handles num_crops properly.""" """Ensure input_processor_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context( ctx = build_model_context(
model_name=model, model_name=model_id,
tokenizer_name=model, tokenizer_name=model_id,
trust_remote_code=True, trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs}, limit_mm_per_prompt={"image": num_imgs},
) )
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer) ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass # Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
images = [image_assets[0].pil_image] * num_imgs mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
mm_data = {"image": images}
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs = {"num_crops": num_crops}
processor = processor_for_phi3v(ctx) processor = processor_for_phi3v(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)

View File

@ -1,5 +1,3 @@
from typing import Any, Dict, Tuple
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -8,56 +6,45 @@ from vllm.inputs import InputProcessingContext
from .....conftest import _ImageAssets from .....conftest import _ImageAssets
from ....utils import build_model_context from ....utils import build_model_context
MODEL = "Qwen/Qwen2-VL-2B-Instruct"
MIN_PIXELS = "min_pixels"
MAX_PIXELS = "max_pixels"
# Fixtures lazy import to avoid initializing CUDA during test collection # Fixtures lazy import to avoid initializing CUDA during test collection
# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
# input mappers.
@pytest.fixture() @pytest.fixture()
def processor_for_qwen2_vl(): def processor_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
return Qwen2VLMultiModalProcessor return Qwen2VLMultiModalProcessor
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [ ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), [
({}, 1426, (5704, 1176)), ({}, 1426, (5704, 1176)),
({ ({"min_pixels": 64**2, "max_pixels": 512**2}, 330, (1320, 1176)),
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 330, (1320, 1176)),
]) ])
@pytest.mark.parametrize("model", [MODEL]) # yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override( def test_processor_override(
processor_for_qwen2_vl, processor_for_qwen2_vl,
image_assets: _ImageAssets, image_assets: _ImageAssets,
model: str, model_id: str,
mm_processor_kwargs: Dict[str, Any], mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int, expected_toks_per_img: int,
expected_pixels_shape: Tuple[int, int], expected_pixels_shape: tuple[int, int],
num_imgs: int, num_imgs: int,
): ):
"""Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly.""" """Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context( ctx = build_model_context(
model_name=model, model_name=model_id,
tokenizer_name=model, tokenizer_name=model_id,
mm_processor_kwargs=None, mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs}, limit_mm_per_prompt={"image": num_imgs},
) )
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer) ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass # Build the image str / prompt based on the number of images we pass
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
images = [image_assets[0].pil_image] * num_imgs mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
mm_data = {"image": images}
processor = processor_for_qwen2_vl(ctx) processor = processor_for_qwen2_vl(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)

View File

@ -274,10 +274,8 @@ VLM_TEST_SETTINGS = {
), ),
limit_mm_per_prompt={"image": 4}, limit_mm_per_prompt={"image": 4},
)], )],
# Llava-next tests fixed sizes & the default size factors
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
), ),
"llava_one_vision": VLMTestInfo( "llava_onevision": VLMTestInfo(
models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
test_type=VLMTestType.CUSTOM_INPUTS, test_type=VLMTestType.CUSTOM_INPUTS,
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
@ -288,8 +286,6 @@ VLM_TEST_SETTINGS = {
), ),
auto_cls=AutoModelForVision2Seq, auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
# Llava-one-vision tests fixed sizes & the default size factors
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
custom_test_opts=[CustomTestOptions( custom_test_opts=[CustomTestOptions(
inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs( inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs(
formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
@ -306,7 +302,6 @@ VLM_TEST_SETTINGS = {
max_model_len=4096, max_model_len=4096,
auto_cls=AutoModelForVision2Seq, auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
), ),
"mantis": VLMTestInfo( "mantis": VLMTestInfo(
models=["TIGER-Lab/Mantis-8B-siglip-llama3"], models=["TIGER-Lab/Mantis-8B-siglip-llama3"],
@ -431,7 +426,7 @@ VLM_TEST_SETTINGS = {
) for inp in custom_inputs.different_patch_input_cases_internvl() ) for inp in custom_inputs.different_patch_input_cases_internvl()
], ],
), ),
"llava_one_vision-multiple-images": VLMTestInfo( "llava_onevision-multiple-images": VLMTestInfo(
models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
test_type=VLMTestType.CUSTOM_INPUTS, test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=16384, max_model_len=16384,

View File

@ -427,130 +427,3 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
mm_limit=1, mm_limit=1,
tensor_parallel_size=1, tensor_parallel_size=1,
) )
def run_chunked_prefill_test(
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
mm_limit: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Compare inference result between
chunked prefill disabled and chunked prefill enabled
"""
# NOTE:
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
task="generate",
max_model_len=4000,
max_num_seqs=4,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:
outputs_per_case = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images or None,
videos=videos or None)
for prompts, images, videos in inputs
]
with vllm_runner(
model,
task="generate",
max_model_len=4000,
max_num_seqs=4,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enable_chunked_prefill=True,
# should be small enough to ensure prefilling is chunked
max_num_batched_tokens=32,
mm_processor_kwargs={
"max_pixels": 16 * 28 * 28,
}) as vllm_model_chunked:
outputs_per_case_chunked = [
vllm_model_chunked.generate_greedy_logprobs(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images or None,
videos=videos or None) for prompts, images, videos in inputs
]
for outputs, \
outputs_chunked \
in zip(outputs_per_case,
outputs_per_case_chunked):
check_logprobs_close(
outputs_0_lst=outputs,
outputs_1_lst=outputs_chunked,
name_0="non_chunked",
name_1="chunked",
)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [1])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
"""
Test Qwen2-VL's chunked prefill with M-RoPE
"""
prompts = [
qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt)
for prompt in example_prompts[:1]
]
# 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs,
# so an image is included in the inputs
# 2. however, Qwen2-VL currently won't work properly
# when chunked prefill is enabled and there are some multi-modal inputs,
# here use a hacky way: provide a **zero-length** image to make it happy
#
# and finally we achieved:
# (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests
zero_len_image = {
"image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)),
"image_grid_thw": torch.tensor([[0, 0, 0]])
}
images = [zero_len_image] * len(prompts)
inputs_per_case: List[Tuple[List[str], PromptImageInput,
PromptVideoInput]] = [
(prompts, images, []),
]
run_chunked_prefill_test(
vllm_runner,
inputs_per_case,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=1,
tensor_parallel_size=1,
)

View File

@ -11,8 +11,8 @@ from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
_PlaceholderInfo, find_text_matches, _PlaceholderInfo, find_mm_placeholders,
find_token_matches, iter_placeholders, find_text_matches, find_token_matches,
iter_token_matches, iter_token_matches,
replace_text_matches, replace_text_matches,
replace_token_matches) replace_token_matches)
@ -314,21 +314,27 @@ def test_find_replace_text(
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ mm_prompt_repls = {
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer) key: [
for key, target in target_by_key.items() PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
] ]
matches = find_text_matches(prompt, prompt_repls) for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
result = replace_text_matches( result = replace_text_matches(
prompt, prompt,
matches, mm_matches,
{key: mm_count {key: mm_count
for key in repl_by_key}, for key in repl_by_key},
) )
# Only displayed on error # Only displayed on error
print("matches:", matches) print("mm_matches:", mm_matches)
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
@ -380,21 +386,27 @@ def test_find_replace_tokens(
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ mm_prompt_repls = {
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer) key: [
for key, target in target_by_key.items() PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
] ]
matches = find_token_matches(prompt, prompt_repls) for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
result = replace_token_matches( result = replace_token_matches(
prompt, prompt,
matches, mm_matches,
{key: mm_count {key: mm_count
for key in repl_by_key}, for key in repl_by_key},
) )
# Only displayed on error # Only displayed on error
print("matches:", matches) print("mm_matches:", mm_matches)
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
@ -417,58 +429,76 @@ def test_find_replace_tokens(
[ [
( (
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
[ {
"pattern_1": [
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
item_idx=0,
start_idx=6, start_idx=6,
replacement=[32000, 32000], replacement=[32000, 32000],
), ),
], ],
}
), ),
( (
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
[ {
"pattern_1": [
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
item_idx=0,
start_idx=1, start_idx=1,
replacement=[32000, 32000], replacement=[32000, 32000],
), ),
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
item_idx=1,
start_idx=5, start_idx=5,
replacement=[32000, 32000], replacement=[32000, 32000],
), ),
],
"pattern_3": [
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_3", modality="pattern_3",
item_idx=0,
start_idx=7, start_idx=7,
replacement=[1550, 918, 1550], replacement=[1550, 918, 1550],
), ),
], ],
}
), ),
( (
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
[ {
"pattern_1": [
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
item_idx=0,
start_idx=1, start_idx=1,
replacement=[32000, 32000], replacement=[32000, 32000],
), ),
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_1", modality="pattern_1",
item_idx=1,
start_idx=3, start_idx=3,
replacement=[32000, 32000], replacement=[32000, 32000],
), ),
],
"pattern_3": [
_PlaceholderInfo( _PlaceholderInfo(
modality="pattern_3", modality="pattern_3",
item_idx=0,
start_idx=6, start_idx=6,
replacement=[1550, 918, 1550], replacement=[1550, 918, 1550],
), ),
], ],
}
), ),
] ]
) )
# yapf: enable # yapf: enable
def test_iter_placeholders( def test_find_mm_placeholders(
repl_by_key, repl_by_key,
prompt, prompt,
expected, expected,
@ -476,19 +506,18 @@ def test_iter_placeholders(
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ mm_prompt_repls = {
PromptReplacement(key, [], repl).bind(mock_tokenizer) key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
for key, repl in repl_by_key.items() for key, repl in repl_by_key.items()
] }
result = list( result = find_mm_placeholders(
iter_placeholders( mm_prompt_repls,
prompt_repls,
prompt, prompt,
# Effectively match all occurrences in the prompt # Effectively match all occurrences in the prompt
{key: 3 {key: 3
for key in repl_by_key}, for key in repl_by_key},
)) )
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
@ -694,7 +723,10 @@ def _test_processing_cache_correctness(
} }
mm_counts = {k: len(vs) for k, vs in mm_data.items()} mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text prompt = baseline_processor._get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text
# Drop unnecessary keys and test single -> multi conversion # Drop unnecessary keys and test single -> multi conversion
if rng.rand() < simplify_rate: if rng.rand() < simplify_rate:
@ -728,6 +760,8 @@ def _test_processing_cache_correctness(
("adept/fuyu-8b", {"image": False}), ("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}), ("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}), ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}), ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
("mistral-community/pixtral-12b", {"image": True}), ("mistral-community/pixtral-12b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}), ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),

View File

@ -456,7 +456,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
hf_config = self.ctx.get_hf_config() hf_config = self.ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values()) return max(hf_config.projector_patch_to_query_dict.values())
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()} return {"image": self._get_num_image_tokens()}
def _get_mm_fields_config( def _get_mm_fields_config(
@ -488,8 +488,9 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
) )
] ]
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config() hf_config = self.ctx.get_hf_config()

View File

@ -405,7 +405,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
hf_config = self.ctx.get_hf_config(Blip2Config) hf_config = self.ctx.get_hf_config(Blip2Config)
return hf_config.num_query_tokens return hf_config.num_query_tokens
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()} return {"image": self._get_num_image_tokens()}
def _get_hf_processor(self) -> Blip2Processor: def _get_hf_processor(self) -> Blip2Processor:
@ -457,8 +457,9 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
return result return result
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(Blip2Config) hf_config = self.ctx.get_hf_config(Blip2Config)

View File

@ -57,7 +57,7 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
processor = self._get_hf_processor() processor = self._get_hf_processor()
return processor.image_seq_length return processor.image_seq_length
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()} return {"image": self._get_num_image_tokens()}
def _get_hf_processor(self) -> ChameleonProcessor: def _get_hf_processor(self) -> ChameleonProcessor:
@ -90,8 +90,9 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
) )
] ]
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
config = self.ctx.get_hf_config(ChameleonConfig) config = self.ctx.get_hf_config(ChameleonConfig)

View File

@ -164,15 +164,18 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
return get_max_clip_image_tokens(self.vision_config) return get_max_clip_image_tokens(self.vision_config)
def get_num_patches(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_clip_patch_grid_length( return get_clip_patch_grid_length(
image_size=self.vision_config.image_size, image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size, patch_size=self.vision_config.patch_size,
) )
def get_image_size(self) -> int:
return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module): class CLIPVisionEmbeddings(nn.Module):

View File

@ -96,7 +96,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
nrows = math.ceil(image_height / 30) nrows = math.ceil(image_height / 30)
return ncols, nrows return ncols, nrows
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_target_size() target_width, target_height = self._get_image_target_size()
max_ncols, max_nrows = self._get_image_feature_grid_size( max_ncols, max_nrows = self._get_image_feature_grid_size(
@ -208,8 +208,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
return result return result
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
target_width, target_height = self._get_image_target_size() target_width, target_height = self._get_image_target_size()

View File

@ -25,11 +25,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize) ImageSize)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (InputProcessingContext,
InputProcessingContext,
MultiModalDataItems, ProcessingCache, MultiModalDataItems, ProcessingCache,
ProcessorInputs, PromptReplacement, ProcessorInputs, PromptReplacement)
full_groupby_modality)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
@ -39,7 +37,7 @@ from .pixtral import (PixtralHFVisionModel,
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import vision_encoder_info from .vision import BaseVisionLanguageMultiModalProcessor
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
@ -100,19 +98,7 @@ class LlavaLikeConfig(Protocol):
vision_feature_layer: Final[Union[int, List[int]]] vision_feature_layer: Final[Union[int, List[int]]]
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor): class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks)
vision_config = self._get_hf_config().vision_config
self._vision_encoder_info = vision_encoder_info(vision_config)
@abstractmethod @abstractmethod
def _get_hf_config(self) -> LlavaLikeConfig: def _get_hf_config(self) -> LlavaLikeConfig:
@ -121,6 +107,19 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _apply_feature_select_strategy( def _apply_feature_select_strategy(
self, self,
strategy: str, strategy: str,
@ -142,19 +141,6 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
self._vision_encoder_info.get_max_image_tokens(), self._vision_encoder_info.get_max_image_tokens(),
) )
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_dummy_image_size(self) -> ImageSize: def _get_dummy_image_size(self) -> ImageSize:
image_size = self._vision_encoder_info.get_image_size() image_size = self._vision_encoder_info.get_image_size()
return ImageSize(image_size, image_size) return ImageSize(image_size, image_size)
@ -163,8 +149,9 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
def _get_image_token(self) -> str: def _get_image_token(self) -> str:
raise NotImplementedError raise NotImplementedError
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
@ -709,7 +696,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
"</Image>)", # 3 tokens "</Image>)", # 3 tokens
]) ])
mantis_repls = self._bind_prompt_replacements([ mantis_mm_repls = self._bind_and_group_repls([
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[image_token_id] * num_image_tokens, target=[image_token_id] * num_image_tokens,
@ -719,7 +706,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt_ids, prompt_text, _ = self._apply_prompt_replacements( prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
result["prompt_token_ids"], result["prompt_token_ids"],
mantis_repls, mantis_mm_repls,
mm_item_counts, mm_item_counts,
) )
@ -728,15 +715,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
orig_repls = self._bind_prompt_replacements(unbound_orig_repls) orig_repls = self._bind_and_group_repls(unbound_orig_repls)
all_placeholders = self._find_placeholders(orig_repls, prompt_ids, mm_placeholders = self._find_mm_placeholders(
mm_item_counts) orig_repls,
assert len(all_placeholders) == mm_item_counts.get("image", 0) prompt_ids,
mm_item_counts,
)
mm_placeholders = { self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
modality: [item.to_range() for item in items]
for modality, items in full_groupby_modality(all_placeholders) mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]
for modality, placeholders in mm_placeholders.items()
} }
return MultiModalInputsV2( return MultiModalInputsV2(
@ -744,7 +735,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt=prompt_text, prompt=prompt_text,
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders, mm_placeholders=mm_placeholder_ranges,
) )

View File

@ -67,9 +67,6 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
def _get_hf_processor(self) -> LlavaNextProcessor: def _get_hf_processor(self) -> LlavaNextProcessor:
return self.ctx.get_hf_processor(LlavaNextProcessor) return self.ctx.get_hf_processor(LlavaNextProcessor)
def _get_image_token(self) -> str:
return self._get_hf_processor().image_token
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
@ -81,6 +78,9 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_image_token(self) -> str:
return self._get_hf_processor().image_token
def _get_max_image_tokens(self) -> int: def _get_max_image_tokens(self) -> int:
largest_feature_size, _ = self._get_pinpoint_with_most_features() largest_feature_size, _ = self._get_pinpoint_with_most_features()
return largest_feature_size return largest_feature_size
@ -97,20 +97,20 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
image_height: int, image_height: int,
) -> int: ) -> int:
hf_config = self._get_hf_config() hf_config = self._get_hf_config()
vision_encoder_info = self._vision_encoder_info
base_feature_size = self._apply_feature_select_strategy( base_feature_size = self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy, hf_config.vision_feature_select_strategy,
self._vision_encoder_info.get_num_image_tokens( vision_encoder_info.get_num_image_tokens(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
), ),
) )
num_patches = self._vision_encoder_info.get_num_patches()
num_patch_height, num_patch_width = get_anyres_image_grid_shape( num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(image_height, image_width), image_size=(image_height, image_width),
grid_pinpoints=hf_config.image_grid_pinpoints, grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=self._vision_encoder_info.get_image_size(), patch_size=vision_encoder_info.get_image_size(),
) )
( (
@ -119,7 +119,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
) = self._get_num_unpadded_features( ) = self._get_num_unpadded_features(
original_height=image_height, original_height=image_height,
original_width=image_width, original_width=image_width,
npatches=num_patches, npatches=vision_encoder_info.get_patch_grid_length(),
num_patch_height=num_patch_height, num_patch_height=num_patch_height,
num_patch_width=num_patch_width, num_patch_width=num_patch_width,
) )
@ -155,6 +155,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
unpadded_features = current_height * current_width unpadded_features = current_height * current_width
newline_features = current_height newline_features = current_height
return (unpadded_features, newline_features) return (unpadded_features, newline_features)
def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]: def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:

View File

@ -3,38 +3,32 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, from transformers import (BatchFeature, LlavaNextVideoConfig,
SiglipVisionConfig) LlavaNextVideoProcessor)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
repeat_and_pad_placeholder_tokens) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import SiglipVisionModel
dummy_seq_data_for_siglip)
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import BaseVisionLanguageMultiModalProcessor
# For profile run
_MAX_FRAMES_PER_VIDEO = 32
_MAX_NUM_VIDEOS = 1
class LlavaNextVideoPixelInputs(TypedDict): class LlavaNextVideoPixelInputs(TypedDict):
@ -50,144 +44,149 @@ class LlavaNextVideoPixelInputs(TypedDict):
""" """
def get_llava_next_video_frame_feature_size( class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
hf_config: LlavaNextVideoConfig) -> int:
# Support both CLIPVisionConfig and SiglipVisionConfig def _get_hf_config(self) -> LlavaNextVideoConfig:
image_size = hf_config.vision_config.image_size return self.ctx.get_hf_config(LlavaNextVideoConfig)
patch_size = hf_config.vision_config.patch_size
def _get_hf_processor(self) -> LlavaNextVideoProcessor:
return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)
return {"video": max_video_tokens}
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
def _get_num_frame_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
spatial_pool_stride = hf_config.spatial_pool_stride spatial_pool_stride = hf_config.spatial_pool_stride
return int((image_size / patch_size / spatial_pool_stride)**2) patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length
def _get_max_llm_tokens(ctx: InputContext) -> int: def _get_num_video_tokens(
""" self,
Calculated from the maximum video frames under the context length *,
constraints of the language model. image_width: int,
""" image_height: int,
hf_text_config = ctx.model_config.hf_text_config num_frames: int,
model_config = ctx.model_config ) -> int:
max_tokens = model_config.max_model_len num_frame_tokens = self._get_num_frame_tokens(
rope_scaling = model_config.rope_scaling image_width=image_width,
image_height=image_height,
)
if rope_scaling: return num_frame_tokens * num_frames
rope_scaling_factor = hf_text_config.rope_scaling["factor"]
def _get_max_video_tokens(self, num_frames: int) -> int:
return self._get_num_video_tokens(image_width=999999,
image_height=999999,
num_frames=num_frames)
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
while True:
next_num_frames = num_frames + 1
if self._get_max_video_tokens(next_num_frames) > max_tokens:
break
num_frames = next_num_frames
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_total_frames = self._get_max_video_frames(seq_len)
return max(max_total_frames // max(max_videos, 1), 1)
def _get_dummy_image_size(self) -> ImageSize:
image_size = self._vision_encoder_info.get_image_size()
return ImageSize(image_size, image_size)
def _get_video_token(self) -> str:
return self._get_hf_processor().video_token
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
video_token_id = hf_config.video_token_index
def get_replacement(item_idx: int):
videos = mm_items.get_items(
"video", (VideoEmbeddingItems, VideoProcessorItems))
if isinstance(videos, VideoEmbeddingItems):
num_video_tokens = videos.get_feature_size(item_idx)
else: else:
rope_scaling_factor = 1 image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens(
max_tokens *= rope_scaling_factor image_width=image_size.width,
image_height=image_size.height,
return max_tokens num_frames=videos.get_num_frames(item_idx),
def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
# Currently set to 32 frames
# TODO: max_tokens = _get_max_llm_tokens(ctx)
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
return _MAX_FRAMES_PER_VIDEO * tokens_per_frame
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
vision_config = hf_config.vision_config
# TODO: support multiple videos
num_videos = mm_counts["video"]
if num_videos != _MAX_NUM_VIDEOS:
raise NotImplementedError(
f"Only {_MAX_NUM_VIDEOS} videos are supported")
# TODO: support configuring the number of frames
frames_per_video = _MAX_FRAMES_PER_VIDEO
# num_images = num_videos * frames_per_video
# fills the sequence with as longer video data as possible
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
video_feature_size = frames_per_video * tokens_per_frame
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video",
) )
pil_frame = dummy_image_for_clip(vision_config, num_images=1) return [video_token_id] * num_video_tokens
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) return [
mm_data = {"video": mm_data_per_video} PromptReplacement(
return DummyData(seq_data, mm_data, ranges) modality="video",
elif isinstance(vision_config, SiglipVisionConfig): target=[video_token_id],
seq_data, ranges = dummy_seq_data_for_siglip( replacement=get_replacement,
vision_config, ),
seq_len, ]
num_videos,
image_token_id=hf_config.video_token_index, def _get_dummy_processor_inputs(
image_feature_size_override=video_feature_size, self,
mm_key="video", seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0)
video_token = self._get_video_token()
target_width, target_height = self._get_dummy_image_size()
mm_data = {
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
) )
}
pil_frame = dummy_image_for_siglip(vision_config, num_images=1) return ProcessorInputs(
np_frame = np.array(pil_frame["image"]) prompt_text=video_token * num_videos,
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data=mm_data,
mm_data = {"video": mm_data_per_video}
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_llava_next_video(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "video" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
vision_config = hf_config.vision_config
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
num_frames = video_data.shape[0]
frame_feature_size = \
get_llava_next_video_frame_feature_size(hf_config)
video_feature_size = num_frames * frame_feature_size
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
) )
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray):
raise NotImplementedError(
"Processing multiple videos is not supported")
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
# adopted from transformers modeling_llava_next_video.py # adopted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module): class LlavaNextVideoPooler(nn.Module):
@ -246,11 +245,7 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states return hidden_states
@MULTIMODAL_REGISTRY.register_input_mapper("video") @MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):

View File

@ -3,47 +3,36 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from transformers import (BatchFeature, LlavaOnevisionConfig,
from transformers import (CLIPVisionConfig, LlavaOnevisionConfig, LlavaOnevisionProcessor)
SiglipVisionConfig)
from transformers.models.llava_onevision.modeling_llava_onevision import ( from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image) get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
repeat_and_pad_placeholder_tokens) VideoProcessorItems)
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip, from .clip import CLIPVisionModel
dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, from .llava_next import LlavaNextMultiModalProcessor
dummy_video_for_siglip, get_siglip_image_feature_size, from .siglip import SiglipVisionModel
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
class LlavaOnevisionVideoPixelInputs(TypedDict): class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"] type: Literal["pixel_values_videos"]
@ -92,27 +81,69 @@ LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
LlavaOnevisionVideoPixelInputs] LlavaOnevisionVideoPixelInputs]
def _get_llava_onevision_image_unppaded_feature_size(height, width, patches, class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
scale_height,
scale_width):
current_height = patches * scale_height
current_width = patches * scale_width
original_aspect_ratio = width / height def _get_hf_config(self) -> LlavaOnevisionConfig:
return self.ctx.get_hf_config(LlavaOnevisionConfig)
def _get_hf_processor(self) -> LlavaOnevisionProcessor:
return self.ctx.get_hf_processor(LlavaOnevisionProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_max_image_tokens()
num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)
return {
"image": max_image_tokens,
"video": max_video_tokens,
}
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.batched("video"),
)
def _get_num_unpadded_features(
self,
*,
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio: if original_aspect_ratio > current_aspect_ratio:
new_height = int(height * (current_width / width)) new_height = int(original_height *
(current_width / original_width))
padding = (current_height - new_height) // 2 padding = (current_height - new_height) // 2
current_height -= padding * 2 current_height -= padding * 2
else: else:
new_width = int(width * (current_height / height)) new_width = int(original_width *
(current_height / original_height))
padding = (current_width - new_width) // 2 padding = (current_width - new_width) // 2
current_width -= padding * 2 current_width -= padding * 2
unpadded_features = current_height * current_width unpadded_features = current_height * current_width
newline_features = current_height newline_features = current_height
ratio = math.sqrt(current_height * current_width / (9 * patches**2)) ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
if ratio > 1.1: if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int( unpadded_features = int(current_height // ratio) * int(
current_width // ratio) current_width // ratio)
@ -120,259 +151,182 @@ def _get_llava_onevision_image_unppaded_feature_size(height, width, patches,
return (unpadded_features, newline_features) return (unpadded_features, newline_features)
def _get_num_frame_tokens(
def get_llava_onevision_image_feature_size( self,
hf_config: LlavaOnevisionConfig,
*, *,
input_height: int, image_width: int,
input_width: int, image_height: int,
) -> int: ) -> int:
vision_config = hf_config.vision_config hf_config = self._get_hf_config()
spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)
if isinstance(vision_config, CLIPVisionConfig): patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
num_patches = get_clip_patch_grid_length( pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
image_size=vision_config.image_size,
patch_size=vision_config.patch_size, return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens(
self,
*,
image_width: int,
image_height: int,
num_frames: int,
) -> int:
num_frame_tokens = self._get_num_frame_tokens(
image_width=image_width,
image_height=image_height,
) )
base_feature_size = get_clip_image_feature_size(vision_config)
elif isinstance(vision_config, SiglipVisionConfig): return num_frame_tokens * num_frames + 1 # Newline token
num_patches = get_siglip_patch_grid_length(
image_size=vision_config.image_size, def _get_max_video_tokens(self, num_frames: int) -> int:
patch_size=vision_config.patch_size, return self._get_num_video_tokens(image_width=999999,
image_height=999999,
num_frames=num_frames)
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
while True:
next_num_frames = num_frames + 1
if self._get_max_video_tokens(next_num_frames) > max_tokens:
break
num_frames = next_num_frames
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
return max(max_total_frames // max(max_videos, 1), 1)
def _get_video_token(self) -> str:
return self._get_hf_processor().video_token
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
videos = mm_data.pop("videos", [])
assert isinstance(videos, list)
if not videos:
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
) )
base_feature_size = get_siglip_image_feature_size(vision_config)
video_token = self._get_video_token()
# LLaVA-OneVision processor doesn't support multiple videos
# with different sizes when converting back to tensors
text_image_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
pixel_values_videos = []
for video in videos:
item_processor_data = dict(prompt=video_token, videos=video)
item_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=item_processor_data,
mm_kwargs=mm_kwargs,
)
pixel_values_videos.append(
item_outputs.pop("pixel_values_videos")[0])
combined_outputs = dict(
**text_image_outputs,
pixel_values_videos=pixel_values_videos,
)
return BatchFeature(combined_outputs)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
image_repls = super()._get_prompt_replacements(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs,
)
hf_config = self._get_hf_config()
video_token_id = hf_config.video_token_index
def get_video_replacement(item_idx: int):
videos = mm_items.get_items(
"video", (VideoEmbeddingItems, VideoProcessorItems))
if isinstance(videos, VideoEmbeddingItems):
num_video_tokens = videos.get_feature_size(item_idx)
else: else:
msg = f"Unsupported vision config: {type(vision_config)}" image_size = videos.get_frame_size(item_idx)
raise NotImplementedError(msg) num_video_tokens = self._get_num_video_tokens(
image_width=image_size.width,
strategy = hf_config.vision_feature_select_strategy image_height=image_size.height,
if strategy == "default": num_frames=videos.get_num_frames(item_idx),
base_feature_size -= 1
elif strategy == "full":
pass
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
) )
( return [video_token_id] * num_video_tokens
unpadded_feature_size,
newline_feature_size,
) = _get_llava_onevision_image_unppaded_feature_size(
input_height, input_width, num_patches, num_patch_height,
num_patch_width)
return unpadded_feature_size + newline_feature_size + base_feature_size return image_repls + [
PromptReplacement(
modality="video",
def get_max_llava_onevision_image_tokens(ctx: InputContext): target=[video_token_id],
return get_llava_onevision_image_feature_size( replacement=get_video_replacement,
ctx.get_hf_config(LlavaOnevisionConfig), ),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
def get_llava_onevision_video_frame_feature_size(
hf_config: LlavaOnevisionConfig) -> int:
# Support both CLIPVisionConfig and SiglipVisionConfig
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
spatial_pool_stride = hf_config.spatial_pool_stride if hasattr(
hf_config, "spatial_pool_stride") else 2
height = width = image_size // patch_size
return math.ceil(height / spatial_pool_stride) * math.ceil(
width / spatial_pool_stride)
def get_llava_onevision_video_tokens(ctx: InputContext,
num_frames: int) -> int:
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
# TODO: support configuring (not supported by HF right now)
num_token_image_newline = 1
tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config)
video_feature_size = num_frames * tokens_per_frame + num_token_image_newline
return video_feature_size
def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int:
return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO)
def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
num_videos = mm_counts["video"]
# TODO: support configuring the number of frames
num_frames = _MAX_FRAMES_PER_VIDEO
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video")
mm_data = dummy_video_for_clip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video")
mm_data = dummy_video_for_siglip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_image(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image_feature_size = get_llava_onevision_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif is_list_of(image_data, Image.Image):
image_feature_size = [
get_llava_onevision_image_feature_size(hf_config,
input_height=img.height,
input_width=img.width)
for img in image_data
] ]
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
vision_config = hf_config.vision_config def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
if isinstance(vision_config, CLIPVisionConfig): image_token = self._get_image_token()
return input_processor_for_clip( video_token = self._get_video_token()
model_config, target_width, target_height = self._get_dummy_image_size()
vision_config,
inputs, mm_data = {
image_token_id=hf_config.image_token_index, "image":
image_feature_size_override=image_feature_size, self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
) )
elif isinstance(vision_config, SiglipVisionConfig): }
return input_processor_for_siglip(
model_config, return ProcessorInputs(
vision_config, prompt_text=image_token * num_images + video_token * num_videos,
inputs, mm_data=mm_data,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
) )
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_video(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
num_frames = video_data.shape[0]
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray):
video_feature_size = []
for video in video_data:
num_frames = video.shape[0]
video_feature_size.append(
get_llava_onevision_video_tokens(ctx, num_frames))
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
else:
raise TypeError(f"Invalid video type: {type(video_data)}")
msg = f"Unsupported video type: {type(video_data)}"
raise NotImplementedError(msg)
def input_processor_for_llava_onevision(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or ("video" not in multi_modal_data
and "image" not in multi_modal_data):
return inputs
if "image" in multi_modal_data:
return input_processor_when_multimodal_input_image(ctx, inputs)
if "video" in multi_modal_data:
return input_processor_when_multimodal_input_video(ctx, inputs)
msg = "Unsupported multi data type"
raise NotImplementedError(msg)
class LlavaOnevisionMultiModalProjector(nn.Module): class LlavaOnevisionMultiModalProjector(nn.Module):
@ -394,14 +348,7 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
return hidden_states return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"image", get_max_llava_onevision_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):

View File

@ -323,7 +323,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
height=image_height, height=image_height,
) )
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_num_image_tokens( max_image_tokens = self._get_num_image_tokens(
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
@ -415,12 +415,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
token_ids: list[int], token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement], mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]: ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements( token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids, token_ids=token_ids,
prompt_repls=prompt_repls, mm_prompt_repls=mm_prompt_repls,
mm_item_counts=mm_item_counts, mm_item_counts=mm_item_counts,
) )
@ -428,15 +428,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
if text.startswith("<s> <|image|>"): if text.startswith("<s> <|image|>"):
text = text.replace("<s> <|image|>", "<s><|image|>", 1) text = text.replace("<s> <|image|>", "<s><|image|>", 1)
token_ids = [token_ids[0], *token_ids[2:]] token_ids = [token_ids[0], *token_ids[2:]]
placeholders = [ placeholders = {
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement) modality: [
for p in placeholders _PlaceholderInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
replacement=p.replacement,
) for p in ps
] ]
for modality, ps in placeholders.items()
}
return token_ids, text, placeholders return token_ids, text, placeholders
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)

View File

@ -780,15 +780,18 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
return get_max_pixtral_hf_image_tokens(self.vision_config) return get_max_pixtral_hf_image_tokens(self.vision_config)
def get_num_patches(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_pixtral_hf_patch_grid_length( return get_pixtral_hf_patch_grid_length(
image_size=self.vision_config.image_size, image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size, patch_size=self.vision_config.patch_size,
) )
def get_image_size(self) -> int:
return self.vision_config.image_size
class PixtralHFMLP(nn.Module): class PixtralHFMLP(nn.Module):

View File

@ -84,7 +84,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
max_source_positions = hf_config.audio_config.max_source_positions max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1 max_output_lengths = (max_source_positions - 2) // 2 + 1
@ -184,15 +184,16 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
] ]
def _always_apply_prompt_replacements(self) -> bool: def _always_apply_prompt_replacements(self) -> bool:
# HF never applies prompt replacements, so we have to do it ourselves # HF never applies prompt replacements, so we have to do it ourselves.
# _find_placeholders may incorrectly think that HF has already performed # NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
# processing for multi-audio input when the input audios are short # has already performed processing for multi-audio input when the input
# (the corresponding placeholders may take up fewer tokens than # audios are short (the corresponding placeholders may take up fewer
# the number of audio items) # tokens than the number of audio items)
return True return True
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()

View File

@ -56,7 +56,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData, from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs, MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem) NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs, MultiModalDataItems, ProcessorInputs,
PromptReplacement) PromptReplacement)
@ -641,58 +642,6 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params return loaded_params
# === Vision input helpers === #
def _get_vision_info(
vision_config: Qwen2VLVisionConfig,
height: int,
width: int,
min_pixels: int,
max_pixels: int,
*,
do_resize: bool = True,
modality: str = "image",
mm_count: int = 1,
):
"""Get information (resized height / width and number of vision tokens)
of input image / video frame."""
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
if do_resize:
resized_height, resized_width = smart_resize(
height=height,
width=width,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
else:
resized_height, resized_width = height, width
if modality == "image":
grid_t = mm_count
elif modality == "video":
grid_t = max(mm_count // temporal_patch_size, 1)
else:
raise ValueError(f"Modality {modality} is not supported")
grid_h = resized_height // patch_size
grid_w = resized_width // patch_size
vision_tokens = grid_t * grid_h * grid_w
llm_num_vision_tokens = vision_tokens // (merge_size**2)
return resized_height, resized_width, llm_num_vision_tokens
def _get_image_processor(hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]): dict[str, torch.Tensor]]):
@ -764,32 +713,111 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
def _get_max_mm_tokens(self, modality: str) -> int: def _get_vision_info(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
) -> tuple[ImageSize, int]:
hf_config = self.ctx.get_hf_config(Qwen2VLConfig) hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor) image_processor = self._get_image_processor(hf_processor)
_, _, max_llm_image_tokens = _get_vision_info( if do_resize:
vision_config, resized_height, resized_width = smart_resize(
height=9999999, height=image_height,
width=9999999, width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels, min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels, max_pixels=image_processor.max_pixels,
modality=modality,
) )
return max_llm_image_tokens preprocessed_size = ImageSize(width=resized_width,
height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width,
height=image_height)
grid_t = max(num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches // (merge_size**2)
return preprocessed_size, num_vision_tokens
def _get_dummy_image_size(self) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_size
def _get_max_image_tokens(self) -> int:
_, max_image_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
)
return max_image_tokens
def _get_max_video_tokens(self, num_frames: int) -> int:
_, max_video_tokens = self._get_vision_info(
image_width=9999999,
image_height=9999999,
num_frames=num_frames,
)
return max_video_tokens
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
while True:
next_num_frames = num_frames + 1
if self._get_max_video_tokens(next_num_frames) > max_tokens:
break
num_frames = next_num_frames
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
return max(max_total_frames // max(max_videos, 1), 1)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_max_image_tokens()
num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return { return {
"image": self._get_max_mm_tokens("image"), "image": max_image_tokens,
"video": self._get_max_mm_tokens("video"), "video": max_video_tokens,
} }
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser() return Qwen2MultiModalDataParser()
def _get_image_processor(self, hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
def _get_hf_processor( def _get_hf_processor(
self, self,
*, *,
@ -797,7 +825,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
) -> Qwen2VLProcessor: ) -> Qwen2VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor) hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = _get_image_processor(hf_processor) image_processor = self._get_image_processor(hf_processor)
if min_pixels: if min_pixels:
image_processor.min_pixels = min_pixels image_processor.min_pixels = min_pixels
@ -818,7 +846,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor) image_processor = self._get_image_processor(hf_processor)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered # image_token and video_token registered
@ -873,32 +901,35 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
video_grid_thw=MultiModalFieldConfig.batched("video"), video_grid_thw=MultiModalFieldConfig.batched("video"),
) )
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
image_token: str = hf_processor.image_token
resized_height, resized_width = smart_resize(
height=9999999,
width=9999999,
factor=image_processor.patch_size * image_processor.merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token
video_token: str = hf_processor.video_token
target_width, target_height = self._get_dummy_image_size()
mm_data = { mm_data = {
"image": "image":
self._get_dummy_images(width=resized_width, self._get_dummy_images(width=target_width,
height=resized_height, height=target_height,
num_images=num_images) num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
} }
return ProcessorInputs( return ProcessorInputs(
prompt_text=image_token * num_images, prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data, mm_data=mm_data,
) )

View File

@ -171,15 +171,18 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
return get_max_siglip_image_tokens(self.vision_config) return get_max_siglip_image_tokens(self.vision_config)
def get_num_patches(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_siglip_patch_grid_length( return get_siglip_patch_grid_length(
image_size=self.vision_config.image_size, image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size, patch_size=self.vision_config.patch_size,
) )
def get_image_size(self) -> int:
return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module): class SiglipVisionEmbeddings(nn.Module):

View File

@ -6,7 +6,6 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
@ -31,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement) PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
@ -62,7 +60,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length * max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND) _AUDIO_TOKENS_PER_SECOND)
@ -103,6 +101,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
mm_data = dict(mm_data) mm_data = dict(mm_data)
audios = mm_data.pop("audios", []) audios = mm_data.pop("audios", [])
assert isinstance(audios, list)
if not audios: if not audios:
return super()._call_hf_processor( return super()._call_hf_processor(
@ -117,9 +116,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
sampling_rate=feature_extractor.sampling_rate, sampling_rate=feature_extractor.sampling_rate,
) )
# Already resampled by _get_hf_mm_data
assert is_list_of(audios, np.ndarray)
# Ultravox processor doesn't support multiple inputs, # Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one # therefore we need to input text and audio one by one
audio_features, audio_token_len = [], [] audio_features, audio_token_len = [], []
@ -177,8 +173,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
) )
] ]
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()

View File

@ -1,8 +1,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, TypeVar from typing import Final, Generic, Optional, Protocol, TypeVar
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ProcessingCache)
_C = TypeVar("_C", bound=PretrainedConfig) _C = TypeVar("_C", bound=PretrainedConfig)
@ -27,11 +31,15 @@ class VisionEncoderInfo(ABC, Generic[_C]):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_num_patches(self) -> int: def get_image_size(self) -> int:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_image_size(self) -> int: def get_patch_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_patch_grid_length(self) -> int:
raise NotImplementedError raise NotImplementedError
@ -50,3 +58,26 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
class VisionLanguageConfig(Protocol):
vision_config: Final[PretrainedConfig]
class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks)
vision_config = self._get_hf_config().vision_config
self._vision_encoder_info = vision_encoder_info(vision_config)
@abstractmethod
def _get_hf_config(self) -> VisionLanguageConfig:
raise NotImplementedError

View File

@ -146,6 +146,20 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
def __init__(self, data: Sequence[HfVideoItem]) -> None: def __init__(self, data: Sequence[HfVideoItem]) -> None:
super().__init__(data, "video") super().__init__(data, "video")
def get_num_frames(self, item_idx: int) -> int:
return len(self.get(item_idx))
def get_frame_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx)[0] # Assume that the video isn't empty
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
class VideoEmbeddingItems(EmbeddingItems): class VideoEmbeddingItems(EmbeddingItems):

View File

@ -16,7 +16,8 @@ from transformers import BatchFeature, ProcessorMixin
from vllm.inputs import DummyData, InputProcessingContext from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .inputs import (MultiModalDataDict, MultiModalFieldConfig, from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
@ -69,19 +70,6 @@ def _cached_encode(
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
def _decode(
tokenizer: AnyTokenizer,
token_ids: list[int],
*,
skip_special_tokens: bool = False,
) -> str:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
"""
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
def _cached_decode( def _cached_decode(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
@ -89,7 +77,7 @@ def _cached_decode(
*, *,
skip_special_tokens: bool = False, skip_special_tokens: bool = False,
) -> str: ) -> str:
return _decode(tokenizer, return decode_tokens(tokenizer,
list(token_ids), list(token_ids),
skip_special_tokens=skip_special_tokens) skip_special_tokens=skip_special_tokens)
@ -269,8 +257,10 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
return self.match.end() return self.match.end()
class _PlaceholderInfo(NamedTuple): @dataclass
class _PlaceholderInfo:
modality: str modality: str
item_idx: int
start_idx: int start_idx: int
replacement: list[int] replacement: list[int]
@ -311,12 +301,14 @@ def find_text_matches(
def _resolve_matches( def _resolve_matches(
prompt: _PromptSeq, prompt: _PromptSeq,
matches: Sequence[_PromptReplacementMatch], mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]: ) -> list[_PromptReplacementMatch]:
""" """
Resolve :code:`matches` to ensure that there are no overlapping matches, Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones. and sort them such that earlier matches take priority over later ones.
""" """
matches = [m for matches in mm_matches.values() for m in matches]
seen_matches: list[Optional[_PromptReplacementMatch]] = [None seen_matches: list[Optional[_PromptReplacementMatch]] = [None
] * len(prompt) ] * len(prompt)
@ -334,14 +326,15 @@ def _resolve_matches(
def _replace_matches( def _replace_matches(
prompt: _S, prompt: _S,
matches: Sequence[_PromptReplacementMatch], mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> list[_S]: ) -> list[_S]:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
out_seqs = list[_S]() out_seqs = list[_S]()
prev_end_idx = 0 prev_end_idx = 0
next_idx_by_modality = defaultdict[str, int](lambda: 0) next_idx_by_modality = defaultdict[str, int](lambda: 0)
for match in _resolve_matches(prompt, matches): for match in _resolve_matches(prompt, mm_matches):
modality = match.modality modality = match.modality
item_idx = next_idx_by_modality[modality] item_idx = next_idx_by_modality[modality]
@ -371,28 +364,28 @@ def _replace_matches(
def replace_token_matches( def replace_token_matches(
prompt: list[int], prompt: list[int],
matches: Sequence[_PromptReplacementTokenMatch], mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> list[int]: ) -> list[int]:
"""Apply :code:`prompt_repls` to :code:`prompt`.""" """Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
if not matches: if not mm_matches:
return prompt return prompt
token_id_seqs = _replace_matches(prompt, matches, mm_item_counts) token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts)
return flatten_2d_lists(token_id_seqs) return flatten_2d_lists(token_id_seqs)
def replace_text_matches( def replace_text_matches(
prompt: str, prompt: str,
matches: Sequence[_PromptReplacementTextMatch], mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> str: ) -> str:
"""Apply :code:`prompt_repls` to :code:`prompt`.""" """Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
if not matches: if not mm_matches:
return prompt return prompt
texts = _replace_matches(prompt, matches, mm_item_counts) texts = _replace_matches(prompt, mm_matches, mm_item_counts)
return "".join(texts) return "".join(texts)
@ -407,14 +400,14 @@ def _iter_modality_placeholders(
return return
prompt_len = len(prompt) prompt_len = len(prompt)
item_index = 0 item_idx = 0
start_idx = 0 start_idx = 0
while start_idx < prompt_len: while start_idx < prompt_len:
found = False found = False
for repl_info in modality_repls: for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_index) replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids repl_tokens = replacement.token_ids
repl_len = len(repl_tokens) repl_len = len(repl_tokens)
end_idx = start_idx + repl_len end_idx = start_idx + repl_len
@ -425,12 +418,13 @@ def _iter_modality_placeholders(
if prompt[start_idx:end_idx] == repl_tokens: if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo( yield _PlaceholderInfo(
modality=modality, modality=modality,
item_idx=item_idx,
start_idx=start_idx, start_idx=start_idx,
replacement=repl_tokens, replacement=repl_tokens,
) )
item_index += 1 item_idx += 1
if item_index >= modal_item_count: if item_idx >= modal_item_count:
return return
# Exclude overlapping matches # Exclude overlapping matches
@ -442,28 +436,36 @@ def _iter_modality_placeholders(
start_idx += 1 start_idx += 1
def iter_placeholders( def _iter_placeholders(
prompt_repls: Sequence[_BoundPromptReplacement], mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]: ) -> Iterable[_PlaceholderInfo]:
""" """
Yield each set of placeholder tokens found in :code:`prompt`. For each modality, yield each set of placeholder tokens found in
:code:`prompt`.
Note that empty matches are ignored. Note that empty matches are ignored.
""" """
repls_by_modality = dict(full_groupby_modality(prompt_repls))
for modality, modal_item_count in mm_item_counts.items(): for modality, modal_item_count in mm_item_counts.items():
if modality in repls_by_modality: if modality in mm_prompt_repls:
yield from _iter_modality_placeholders( yield from _iter_modality_placeholders(
prompt, prompt,
modality, modality,
repls_by_modality[modality], mm_prompt_repls[modality],
modal_item_count, modal_item_count,
) )
def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
return dict(full_groupby_modality(it))
@dataclass @dataclass
class ProcessorInputs: class ProcessorInputs:
"""Keyword arguments to :meth:`BaseMultiModalProcessor`.""" """Keyword arguments to :meth:`BaseMultiModalProcessor`."""
@ -620,7 +622,7 @@ class BaseMultiModalProcessor(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
""" """
Get the maximum possible number of tokens per data item Get the maximum possible number of tokens per data item
for each modality. for each modality.
@ -703,14 +705,14 @@ class BaseMultiModalProcessor(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def _find_placeholders( def _find_mm_placeholders(
self, self,
all_prompt_repls: Sequence[_BoundPromptReplacement], mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> list[_PlaceholderInfo]: ) -> Mapping[str, list[_PlaceholderInfo]]:
return list( return find_mm_placeholders(mm_prompt_repls, new_token_ids,
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts)) mm_item_counts)
def _get_hf_mm_data( def _get_hf_mm_data(
self, self,
@ -797,7 +799,10 @@ class BaseMultiModalProcessor(ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding # Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text # multi-modal tokens to be in the prompt text
dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts) dummy_inputs = self._get_dummy_processor_inputs(
self.ctx.model_config.max_model_len,
mm_missing_counts,
)
_, mm_missing_kwargs = self._apply_hf_processor( _, mm_missing_kwargs = self._apply_hf_processor(
prompt_text=dummy_inputs.prompt_text, prompt_text=dummy_inputs.prompt_text,
@ -889,50 +894,44 @@ class BaseMultiModalProcessor(ABC):
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
if self.enable_sanity_checks:
mm_item_counts = mm_data_items.get_all_counts()
for modality, item_count in mm_item_counts.items():
for item_idx in range(item_count):
try:
mm_kwargs.get_item(modality, item_idx)
except Exception as e:
# Make it easy to set a breakpoint in the debugger
raise e
return prompt_ids, mm_kwargs return prompt_ids, mm_kwargs
def _bind_prompt_replacements( def _bind_and_group_repls(
self, self,
prompt_repls: list[PromptReplacement], prompt_repls: list[PromptReplacement],
) -> list[_BoundPromptReplacement]: ) -> dict[str, list[_BoundPromptReplacement]]:
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls] it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
return dict(full_groupby_modality(it))
def _always_apply_prompt_replacements(self) -> bool: def _always_apply_prompt_replacements(self) -> bool:
""" """
A flag which can be overridden so that A flag which can be overridden so that
:meth:`_apply_prompt_replacements` is always called even if we :meth:`_apply_prompt_replacements` is always called even if we
detect that HF has performed processing via :meth:`_find_placeholders`. detect that HF has performed processing via
:meth:`_find_placeholders_by_modality`.
This is useful in cases where :meth:`_find_placeholders` cannot be This is useful in cases where :meth:`_find_placeholders_by_modality`
reliably used to detect whether HF has performed processing or not. cannot be reliably used to detect whether HF has performed processing.
""" """
return False return False
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
token_ids: list[int], token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement], mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]: ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
token_matches = find_token_matches(token_ids, prompt_repls) mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
}
mm_match_counts = { mm_match_counts = {
modality: len(matches) modality: len(matches)
for modality, matches in full_groupby_modality(token_matches) for modality, matches in mm_token_matches.items()
} }
# If the search text does not represent a special token, # If the search text does not represent a special token,
@ -951,32 +950,92 @@ class BaseMultiModalProcessor(ABC):
): # yapf: disable ): # yapf: disable
token_ids = replace_token_matches( token_ids = replace_token_matches(
token_ids, token_ids,
token_matches, mm_token_matches,
mm_item_counts, mm_item_counts,
) )
text = _decode(tokenizer, token_ids) text = decode_tokens(tokenizer, token_ids)
matched_repls = [match.prompt_repl for match in token_matches] matched_repls = {
modality: [match.prompt_repl for match in token_matches]
for modality, token_matches in mm_token_matches.items()
}
else: else:
text = _decode(tokenizer, token_ids) text = decode_tokens(tokenizer, token_ids)
text_matches = find_text_matches(text, prompt_repls) mm_text_matches = {
modality: find_text_matches(text, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
}
text = replace_text_matches( text = replace_text_matches(
text, text,
text_matches, mm_text_matches,
mm_item_counts, mm_item_counts,
) )
token_ids = encode_tokens(tokenizer, token_ids = encode_tokens(tokenizer,
text, text,
add_special_tokens=False) add_special_tokens=False)
matched_repls = [match.prompt_repl for match in text_matches] matched_repls = {
modality: [match.prompt_repl for match in token_matches]
for modality, token_matches in mm_text_matches.items()
}
placeholders = self._find_placeholders(matched_repls, token_ids, placeholders = self._find_mm_placeholders(
mm_item_counts) matched_repls,
token_ids,
mm_item_counts,
)
return token_ids, text, placeholders return token_ids, text, placeholders
def _validate_mm_kwargs(
self,
mm_kwargs: MultiModalKwargs,
mm_item_counts: Mapping[str, int],
) -> None:
for modality, item_count in mm_item_counts.items():
if modality in mm_kwargs.modalities:
items = mm_kwargs.get_items(modality)
else:
items = []
if len(items) != item_count:
raise RuntimeError(
f"Expected there to be {item_count} {modality} items in "
f"keyword arguments corresponding to {item_count} "
f"{modality} data items, but only found {len(items)}! "
"There is likely a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_mm_fields_config`).")
def _validate_mm_placeholders(
self,
mm_placeholders: Mapping[str, list[_PlaceholderInfo]],
mm_item_counts: Mapping[str, int],
*,
allow_missing: bool = False,
) -> Mapping[str, int]:
missing_repl_counts = dict[str, int]()
for modality, item_count in mm_item_counts.items():
placeholders = mm_placeholders.get(modality, [])
if len(placeholders) != item_count and not allow_missing:
raise RuntimeError(
f"Expected there to be {item_count} prompt replacements "
f"corresponding to {item_count} {modality} items, but only "
f"found {len(placeholders)} prompt replacements! Either "
"the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_replacements`).")
missing_repl_counts[modality] = item_count - len(placeholders)
return missing_repl_counts
def apply( def apply(
self, self,
prompt_text: str, prompt_text: str,
@ -1009,56 +1068,69 @@ class BaseMultiModalProcessor(ABC):
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls) mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)
mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
hf_mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls,
prompt_ids,
mm_item_counts,
)
if self._always_apply_prompt_replacements():
mm_missing_repl_counts = mm_item_counts
mm_missing_repls = dict(mm_prompt_repls)
else:
mm_missing_repl_counts = self._validate_mm_placeholders(
hf_mm_placeholders,
mm_item_counts,
allow_missing=True,
)
mm_missing_repls = dict[str, list[_BoundPromptReplacement]]()
for modality, missing_repl_count in mm_missing_repl_counts.items():
if missing_repl_count == 0:
mm_missing_repls[modality] = []
elif missing_repl_count == mm_item_counts.get(modality, 0):
mm_missing_repls[modality] = mm_prompt_repls[modality]
else:
raise ValueError("Partial prompt replacement within "
f"{modality=} is not supported")
# If HF processor already inserts placeholder tokens, # If HF processor already inserts placeholder tokens,
# there is no need for us to insert them # there is no need for us to insert them
mm_item_counts = mm_items.get_all_counts() if all(len(repls) == 0 for repls in mm_missing_repls.items()):
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
mm_item_counts)
if all_placeholders and not self._always_apply_prompt_replacements():
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
prompt_text = _decode(tokenizer, prompt_ids) prompt_text = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
else: else:
( (
prompt_ids, prompt_ids,
prompt_text, prompt_text,
all_placeholders, missing_mm_placeholders,
) = self._apply_prompt_replacements( ) = self._apply_prompt_replacements(
prompt_ids, prompt_ids,
prompt_repls, mm_missing_repls,
mm_item_counts, mm_missing_repl_counts,
) )
mm_placeholders = dict[str, list[PlaceholderRange]]() mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}
err_suffix = ("This suggests a problem with your implementation of "
"the merged multi-modal processor for this model, "
"particularly in the `_get_prompt_replacements` method.")
for modality, placeholders in full_groupby_modality(all_placeholders): self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
if modality not in mm_items:
raise AssertionError(
f"Expected no placeholders for {modality=}, "
f"but found {placeholders=}. Input items: {mm_items}"
f"\n{err_suffix}")
if len(placeholders) != len(mm_items[modality]): mm_placeholder_ranges = {
raise AssertionError( modality: [item.to_range() for item in placeholders]
f"Expected length of {placeholders=} for {modality=} " for modality, placeholders in mm_placeholders.items()
f"to equal that of input items: {mm_items[modality]}" }
f"\n{err_suffix}")
mm_placeholders[modality] = [
item.to_range() for item in placeholders
]
return MultiModalInputsV2( return MultiModalInputsV2(
type="multimodal", type="multimodal",
prompt=prompt_text, prompt=prompt_text,
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders, mm_placeholders=mm_placeholder_ranges,
) )
def _get_dummy_audios( def _get_dummy_audios(
@ -1092,8 +1164,9 @@ class BaseMultiModalProcessor(ABC):
return [video] * num_videos return [video] * num_videos
@abstractmethod @abstractmethod
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
@ -1121,12 +1194,25 @@ class BaseMultiModalProcessor(ABC):
return mm_limits return mm_limits
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
processor_inputs = self._get_dummy_processor_inputs(seq_len, mm_counts)
return self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData: def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
mm_counts = self._get_and_validate_dummy_mm_counts() mm_counts = self._get_and_validate_dummy_mm_counts()
mm_max_tokens_per_item = self.get_mm_max_tokens_per_item() mm_max_tokens_per_item = self.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys(): if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError( raise AssertionError(
"The keys returned by `get_supported_mm_limits`" "The keys returned by `get_supported_mm_limits`"
@ -1134,13 +1220,7 @@ class BaseMultiModalProcessor(ABC):
"returned by `get_mm_max_tokens_per_item` " "returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})") f"({set(mm_max_tokens_per_item.keys())})")
processor_inputs = self._get_dummy_mm_inputs(mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
mm_inputs = self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"] placeholders_by_modality = mm_inputs["mm_placeholders"]
@ -1171,6 +1251,12 @@ class BaseMultiModalProcessor(ABC):
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality) total_len, total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData( return DummyData(

View File

@ -223,7 +223,8 @@ class MultiModalRegistry:
if self.has_processor(model_config): if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer) processor = self.create_processor(model_config, tokenizer)
return processor.get_mm_max_tokens_per_item() seq_len = model_config.max_model_len
return processor.get_mm_max_tokens_per_item(seq_len)
return { return {
key: plugin.get_max_multimodal_tokens(model_config) key: plugin.get_max_multimodal_tokens(model_config)

View File

@ -21,6 +21,19 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer] MistralTokenizer]
def decode_tokens(
tokenizer: AnyTokenizer,
token_ids: list[int],
*,
skip_special_tokens: bool = False,
) -> str:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
"""
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
def encode_tokens( def encode_tokens(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
text: str, text: str,