mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 04:29:08 +08:00
[3/N] Support and implement merged input processor for LLaVA model (#10676)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
acf092d348
commit
955fa9533a
@ -2,7 +2,7 @@ from contextlib import nullcontext
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from transformers import CLIPImageProcessor, LlavaNextImageProcessor
|
from transformers import LlavaNextImageProcessor
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.multimodal import MultiModalRegistry
|
from vllm.multimodal import MultiModalRegistry
|
||||||
@ -14,49 +14,6 @@ def mm_registry():
|
|||||||
return MultiModalRegistry()
|
return MultiModalRegistry()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", ["half", "float"])
|
|
||||||
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
|
|
||||||
def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
|
|
||||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
|
||||||
|
|
||||||
hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
|
|
||||||
assert isinstance(hf_processor, CLIPImageProcessor)
|
|
||||||
|
|
||||||
model_config = ModelConfig(
|
|
||||||
model=MODEL_NAME,
|
|
||||||
task="auto",
|
|
||||||
tokenizer=MODEL_NAME,
|
|
||||||
tokenizer_mode="auto",
|
|
||||||
trust_remote_code=False,
|
|
||||||
seed=0,
|
|
||||||
dtype=dtype,
|
|
||||||
revision=None,
|
|
||||||
limit_mm_per_prompt={"image": 1},
|
|
||||||
)
|
|
||||||
|
|
||||||
mm_registry.init_mm_limits_per_prompt(model_config)
|
|
||||||
|
|
||||||
for asset in image_assets:
|
|
||||||
image = rescale_image_size(asset.pil_image, size_factor)
|
|
||||||
|
|
||||||
hf_result = hf_processor.preprocess(
|
|
||||||
image,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
vllm_result = mm_registry.map_input(
|
|
||||||
model_config,
|
|
||||||
{"image": image},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert hf_result.keys() == vllm_result.keys()
|
|
||||||
for key, hf_tensor in hf_result.items():
|
|
||||||
hf_arr: np.ndarray = hf_tensor.numpy()
|
|
||||||
vllm_arr: np.ndarray = vllm_result[key].numpy()
|
|
||||||
|
|
||||||
assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
|
|
||||||
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", ["half", "float"])
|
@pytest.mark.parametrize("dtype", ["half", "float"])
|
||||||
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
|
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
|
||||||
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
|
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
|
||||||
@ -107,7 +64,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype,
|
|||||||
(2, 1, False), (2, 2, True)],
|
(2, 1, False), (2, 2, True)],
|
||||||
)
|
)
|
||||||
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
|
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
|
||||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
|
||||||
|
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
@ -138,7 +95,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
|
|||||||
# NOTE: We don't test zero images since the HF processor doesn't support it
|
# NOTE: We don't test zero images since the HF processor doesn't support it
|
||||||
@pytest.mark.parametrize("num_images", [1, 2])
|
@pytest.mark.parametrize("num_images", [1, 2])
|
||||||
def test_image_mapper_multi(image_assets, mm_registry, num_images):
|
def test_image_mapper_multi(image_assets, mm_registry, num_images):
|
||||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
|
||||||
|
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
|
|||||||
@ -3,50 +3,15 @@ from typing import cast
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
|
|
||||||
from vllm.multimodal.processing import (PromptReplacement, find_text_matches,
|
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
|
||||||
find_token_matches, iter_token_matches,
|
find_text_matches, find_token_matches,
|
||||||
iter_token_runs, replace_text_matches)
|
iter_placeholders, iter_token_matches,
|
||||||
|
replace_text_matches,
|
||||||
|
replace_token_matches)
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils import full_groupby
|
from vllm.utils import full_groupby
|
||||||
|
|
||||||
|
|
||||||
# yapf: disable
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("token_ids", "expected"),
|
|
||||||
[
|
|
||||||
([], []),
|
|
||||||
(
|
|
||||||
[32000, 32000, 32000],
|
|
||||||
[{ "token_id": 32000, "start_idx": 0, "length": 3 }],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
|
|
||||||
[
|
|
||||||
{ "token_id": 9833, "start_idx": 0, "length": 1 },
|
|
||||||
{ "token_id": 28747, "start_idx": 1, "length": 1 },
|
|
||||||
{ "token_id": 32000, "start_idx": 2, "length": 3 },
|
|
||||||
{ "token_id": 9833, "start_idx": 5, "length": 1 },
|
|
||||||
{ "token_id": 28747, "start_idx": 6, "length": 1 },
|
|
||||||
{ "token_id": 32000, "start_idx": 7, "length": 2 },
|
|
||||||
{ "token_id": 918, "start_idx": 9, "length": 1 },
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
# yapf: enable
|
|
||||||
def test_iter_token_runs(token_ids, expected):
|
|
||||||
result = list(iter_token_runs(token_ids))
|
|
||||||
|
|
||||||
# Only displayed on error
|
|
||||||
print("result:", result)
|
|
||||||
|
|
||||||
# Manually constructed results
|
|
||||||
assert [item._asdict() for item in result] == expected
|
|
||||||
|
|
||||||
# Invariants
|
|
||||||
assert sum(run_info.length for run_info in result) == len(token_ids)
|
|
||||||
|
|
||||||
|
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("token_ids", "match_ids", "expected"),
|
("token_ids", "match_ids", "expected"),
|
||||||
@ -170,13 +135,11 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
|
|||||||
# Should not be used since there is nothing to convert to token IDs
|
# Should not be used since there is nothing to convert to token IDs
|
||||||
mock_tokenizer = cast(AnyTokenizer, object())
|
mock_tokenizer = cast(AnyTokenizer, object())
|
||||||
|
|
||||||
result = find_token_matches(
|
prompt_repls = [
|
||||||
prompt,
|
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
|
||||||
[
|
for key, target in target_by_key.items()
|
||||||
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
|
]
|
||||||
for key, target in target_by_key.items()
|
result = find_token_matches(prompt, prompt_repls)
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only displayed on error
|
# Only displayed on error
|
||||||
print("result:", result)
|
print("result:", result)
|
||||||
@ -279,13 +242,11 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
|
|||||||
# Should not be used since there is nothing to convert to text
|
# Should not be used since there is nothing to convert to text
|
||||||
mock_tokenizer = cast(AnyTokenizer, object())
|
mock_tokenizer = cast(AnyTokenizer, object())
|
||||||
|
|
||||||
result = find_text_matches(
|
prompt_repls = [
|
||||||
prompt,
|
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
|
||||||
[
|
for key, target in target_by_key.items()
|
||||||
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
|
]
|
||||||
for key, target in target_by_key.items()
|
result = find_text_matches(prompt, prompt_repls)
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only displayed on error
|
# Only displayed on error
|
||||||
print("result:", result)
|
print("result:", result)
|
||||||
@ -303,7 +264,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
|
|||||||
|
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"),
|
("prompt", "target_by_key", "repl_by_key"),
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"Image:<image>Image:<image><image>!",
|
"Image:<image>Image:<image><image>!",
|
||||||
@ -322,49 +283,201 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
|
|||||||
# Test multiple repl_count
|
# Test multiple repl_count
|
||||||
"pattern_3": ("?", 2),
|
"pattern_3": ("?", 2),
|
||||||
},
|
},
|
||||||
{
|
|
||||||
# Test no replacement
|
|
||||||
0: "Image:<image>Image:<image><image>!",
|
|
||||||
# Test single replacement
|
|
||||||
1: "<image><image>Image:<image><image>??",
|
|
||||||
# Test repeated replacement
|
|
||||||
2: "<image><image><image><image><image>??",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("mm_count", "expected"),
|
||||||
|
[
|
||||||
|
(0, "Image:<image>Image:<image><image>!"),
|
||||||
|
(1, "<image><image>Image:<image><image>??"),
|
||||||
|
(2, "<image><image><image><image><image>??"),
|
||||||
|
]
|
||||||
|
)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
def test_find_replace_text(
|
def test_find_replace_text(
|
||||||
prompt,
|
prompt,
|
||||||
target_by_key,
|
target_by_key,
|
||||||
repl_by_key,
|
repl_by_key,
|
||||||
expected_by_mm_count,
|
mm_count,
|
||||||
|
expected,
|
||||||
):
|
):
|
||||||
# Should not be used since there is nothing to convert to text
|
# Should not be used since there is nothing to convert to text
|
||||||
mock_tokenizer = cast(AnyTokenizer, object())
|
mock_tokenizer = cast(AnyTokenizer, object())
|
||||||
|
|
||||||
matches = find_text_matches(
|
prompt_repls = [
|
||||||
|
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
|
||||||
|
for key, target in target_by_key.items()
|
||||||
|
]
|
||||||
|
matches = find_text_matches(prompt, prompt_repls)
|
||||||
|
|
||||||
|
result = replace_text_matches(
|
||||||
prompt,
|
prompt,
|
||||||
[
|
matches,
|
||||||
PromptReplacement(target, *repl_by_key[key]) \
|
{key: list(range(mm_count))
|
||||||
.bind(key, mock_tokenizer)
|
for key in repl_by_key},
|
||||||
for key, target in target_by_key.items()
|
BatchFeature(),
|
||||||
],
|
|
||||||
)
|
)
|
||||||
result_by_mm_count = {
|
|
||||||
mm_count: replace_text_matches(
|
|
||||||
prompt,
|
|
||||||
matches,
|
|
||||||
{key: list(range(mm_count))
|
|
||||||
for key in repl_by_key},
|
|
||||||
BatchFeature(),
|
|
||||||
)
|
|
||||||
for mm_count in expected_by_mm_count
|
|
||||||
}
|
|
||||||
|
|
||||||
# Only displayed on error
|
# Only displayed on error
|
||||||
print("matches:", matches)
|
print("matches:", matches)
|
||||||
print("result_by_mm_count:", result_by_mm_count)
|
print("result:", result)
|
||||||
|
|
||||||
# Manually constructed results
|
# Manually constructed results
|
||||||
assert result_by_mm_count == expected_by_mm_count
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("prompt", "target_by_key", "repl_by_key"),
|
||||||
|
[
|
||||||
|
# Tokenized test cases of `test_find_replace_text`
|
||||||
|
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
|
||||||
|
(
|
||||||
|
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||||
|
{
|
||||||
|
# We use `<image>` before `Image:` to test matches that
|
||||||
|
# occur out of order
|
||||||
|
"pattern_1": [32000],
|
||||||
|
"pattern_2": [9833, 28747],
|
||||||
|
"pattern_3": [918],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Test whether target is confused with repl_unit
|
||||||
|
"pattern_1": ([32000, 32000], 1),
|
||||||
|
# Test empty repl_unit
|
||||||
|
"pattern_2": ([], 1),
|
||||||
|
# Test multiple repl_count
|
||||||
|
"pattern_3": ([1550], 2),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("mm_count", "expected"),
|
||||||
|
[
|
||||||
|
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
|
||||||
|
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]),
|
||||||
|
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# yapf: enable
|
||||||
|
def test_find_replace_tokens(
|
||||||
|
prompt,
|
||||||
|
target_by_key,
|
||||||
|
repl_by_key,
|
||||||
|
mm_count,
|
||||||
|
expected,
|
||||||
|
):
|
||||||
|
# Should not be used since there is nothing to convert to tokens
|
||||||
|
mock_tokenizer = cast(AnyTokenizer, object())
|
||||||
|
|
||||||
|
prompt_repls = [
|
||||||
|
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
|
||||||
|
for key, target in target_by_key.items()
|
||||||
|
]
|
||||||
|
matches = find_token_matches(prompt, prompt_repls)
|
||||||
|
|
||||||
|
result = replace_token_matches(
|
||||||
|
prompt,
|
||||||
|
matches,
|
||||||
|
{key: list(range(mm_count))
|
||||||
|
for key in repl_by_key},
|
||||||
|
BatchFeature(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only displayed on error
|
||||||
|
print("matches:", matches)
|
||||||
|
print("result:", result)
|
||||||
|
|
||||||
|
# Manually constructed results
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"repl_by_key",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"pattern_1": ([32000, 32000], 1),
|
||||||
|
"pattern_2": ([], 1),
|
||||||
|
"pattern_3": ([1550], 2),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("prompt", "expected"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||||
|
[
|
||||||
|
_PlaceholderInfo(
|
||||||
|
modality="pattern_1",
|
||||||
|
start_idx=6,
|
||||||
|
unit=[32000, 32000],
|
||||||
|
unit_count=1,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550],
|
||||||
|
[
|
||||||
|
_PlaceholderInfo(
|
||||||
|
modality="pattern_1",
|
||||||
|
start_idx=1,
|
||||||
|
unit=[32000, 32000],
|
||||||
|
unit_count=1,
|
||||||
|
),
|
||||||
|
_PlaceholderInfo(
|
||||||
|
modality="pattern_1",
|
||||||
|
start_idx=5,
|
||||||
|
unit=[32000, 32000],
|
||||||
|
unit_count=1,
|
||||||
|
),
|
||||||
|
_PlaceholderInfo(
|
||||||
|
modality="pattern_3",
|
||||||
|
start_idx=7,
|
||||||
|
unit=[1550],
|
||||||
|
unit_count=2,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[1, 32000, 32000, 32000, 32000, 32000, 1550, 1550],
|
||||||
|
[
|
||||||
|
_PlaceholderInfo(
|
||||||
|
modality="pattern_1",
|
||||||
|
start_idx=1,
|
||||||
|
unit=[32000, 32000],
|
||||||
|
unit_count=2,
|
||||||
|
),
|
||||||
|
_PlaceholderInfo(
|
||||||
|
modality="pattern_3",
|
||||||
|
start_idx=6,
|
||||||
|
unit=[1550],
|
||||||
|
unit_count=2,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_iter_placeholders(
|
||||||
|
repl_by_key,
|
||||||
|
prompt,
|
||||||
|
expected,
|
||||||
|
):
|
||||||
|
# Should not be used since there is nothing to convert to tokens
|
||||||
|
mock_tokenizer = cast(AnyTokenizer, object())
|
||||||
|
|
||||||
|
prompt_repls = [
|
||||||
|
PromptReplacement([], *repl).bind(key, mock_tokenizer)
|
||||||
|
for key, repl in repl_by_key.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
result = list(iter_placeholders(prompt_repls, prompt))
|
||||||
|
|
||||||
|
# Only displayed on error
|
||||||
|
print("result:", result)
|
||||||
|
|
||||||
|
# Manually constructed results
|
||||||
|
assert result == expected
|
||||||
|
|||||||
@ -2,19 +2,17 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.inputs import INPUT_REGISTRY
|
|
||||||
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
|
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
|
||||||
dummy_data_for_llava,
|
create_metadata_for_llava,
|
||||||
get_max_llava_image_tokens,
|
dummy_mm_kwargs_for_llava,
|
||||||
input_processor_for_llava)
|
get_max_llava_image_tokens)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
|
||||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
||||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
|
@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava,
|
||||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
|
dummy_mm_kwargs_for_llava)
|
||||||
class MyLlava(LlavaForConditionalGeneration):
|
class MyLlava(LlavaForConditionalGeneration):
|
||||||
|
|
||||||
def compute_logits(
|
def compute_logits(
|
||||||
|
|||||||
@ -232,19 +232,35 @@ class InputRegistry:
|
|||||||
"""
|
"""
|
||||||
# Avoid circular import
|
# Avoid circular import
|
||||||
from vllm.model_executor.model_loader import get_model_architecture
|
from vllm.model_executor.model_loader import get_model_architecture
|
||||||
|
from vllm.multimodal import MultiModalKwargs
|
||||||
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
|
|
||||||
model_cls, _ = get_model_architecture(model_config)
|
if mm_registry.has_processor(model_config):
|
||||||
if is_encoder_data:
|
tokenizer = cached_get_tokenizer(
|
||||||
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
|
model_config.tokenizer,
|
||||||
|
trust_remote_code=model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
processor = mm_registry.create_processor(model_config, tokenizer)
|
||||||
|
|
||||||
|
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
|
||||||
|
mm_max_tokens = mm_registry.get_max_tokens_by_modality(
|
||||||
|
model_config)
|
||||||
|
|
||||||
|
dummy_data = processor.get_dummy_data(seq_len, mm_counts,
|
||||||
|
mm_max_tokens)
|
||||||
else:
|
else:
|
||||||
dummy_factory = self._get_dummy_data_factory(model_cls)
|
model_cls, _ = get_model_architecture(model_config)
|
||||||
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
|
if is_encoder_data:
|
||||||
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
|
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
|
||||||
dummy_factory, overrides=model_config.mm_processor_kwargs)
|
else:
|
||||||
|
dummy_factory = self._get_dummy_data_factory(model_cls)
|
||||||
|
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
|
||||||
|
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
|
||||||
|
dummy_factory, overrides=model_config.mm_processor_kwargs)
|
||||||
|
|
||||||
dummy_data = dummy_factory(InputContext(model_config), seq_len,
|
dummy_data = dummy_factory(InputContext(model_config), seq_len,
|
||||||
_MultiModalCounts(mm_counts),
|
_MultiModalCounts(mm_counts),
|
||||||
**mm_processor_kwargs)
|
**mm_processor_kwargs)
|
||||||
|
|
||||||
# Having more tokens is over-conservative but otherwise fine
|
# Having more tokens is over-conservative but otherwise fine
|
||||||
num_tokens = dummy_data.seq_data.prompt_token_ids
|
num_tokens = dummy_data.seq_data.prompt_token_ids
|
||||||
@ -257,7 +273,9 @@ class InputRegistry:
|
|||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
f"Expected at least {seq_len} dummy tokens for profiling, "
|
f"Expected at least {seq_len} dummy tokens for profiling, "
|
||||||
f"but found {len(num_tokens)} tokens instead.")
|
f"but found {len(num_tokens)} tokens instead.")
|
||||||
if dummy_data.multi_modal_data is not None:
|
|
||||||
|
if (dummy_data.multi_modal_data is not None and
|
||||||
|
not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)):
|
||||||
for k, v in dummy_data.multi_modal_data.items():
|
for k, v in dummy_data.multi_modal_data.items():
|
||||||
num_items = len(v) if isinstance(v, list) else 1
|
num_items = len(v) if isinstance(v, list) else 1
|
||||||
num_expected = mm_counts[k]
|
num_expected = mm_counts[k]
|
||||||
|
|||||||
@ -1,17 +1,19 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
from types import MethodType
|
||||||
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
|
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
|
||||||
Tuple, TypedDict, Union)
|
Tuple, TypedDict, Union)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL.Image import Image
|
||||||
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
|
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
|
||||||
PretrainedConfig, SiglipVisionConfig)
|
PixtralVisionConfig, PretrainedConfig,
|
||||||
|
ProcessorMixin, SiglipVisionConfig)
|
||||||
|
from transformers.models.pixtral import PixtralProcessor
|
||||||
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
from vllm.inputs import InputContext
|
||||||
InputContext)
|
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
@ -19,21 +21,20 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
|||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import NestedTensors
|
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||||
|
from vllm.multimodal.processing import (InputProcessingContext,
|
||||||
|
ModalityProcessingMetadata,
|
||||||
|
MultiModalProcessingMetadata,
|
||||||
|
MultiModalProcessor, PromptReplacement)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import is_list_of
|
|
||||||
|
|
||||||
from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
||||||
dummy_seq_data_for_clip, get_max_clip_image_tokens,
|
get_max_clip_image_tokens)
|
||||||
input_processor_for_clip)
|
|
||||||
from .interfaces import SupportsMultiModal, SupportsPP
|
from .interfaces import SupportsMultiModal, SupportsPP
|
||||||
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
|
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
|
||||||
dummy_seq_data_for_pixtral_hf,
|
get_max_pixtral_hf_image_tokens)
|
||||||
get_max_pixtral_hf_image_tokens,
|
|
||||||
input_processor_for_pixtral_hf)
|
|
||||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||||
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
|
get_max_siglip_image_tokens)
|
||||||
input_processor_for_siglip)
|
|
||||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||||
maybe_prefix, merge_multimodal_embeddings)
|
maybe_prefix, merge_multimodal_embeddings)
|
||||||
|
|
||||||
@ -113,102 +114,86 @@ def get_max_llava_image_tokens(ctx: InputContext):
|
|||||||
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||||
|
|
||||||
|
|
||||||
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
|
def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
|
||||||
mm_counts: Mapping[str, int]):
|
mm_counts: Mapping[str, int]):
|
||||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||||
vision_config = hf_config.vision_config
|
vision_config = hf_config.vision_config
|
||||||
num_images = mm_counts["image"]
|
num_images = mm_counts["image"]
|
||||||
|
|
||||||
image_feature_size = get_max_llava_image_tokens(ctx)
|
|
||||||
|
|
||||||
if isinstance(vision_config, CLIPVisionConfig):
|
if isinstance(vision_config, CLIPVisionConfig):
|
||||||
seq_data, ranges = dummy_seq_data_for_clip(
|
data = dummy_image_for_clip(vision_config, num_images)
|
||||||
vision_config,
|
|
||||||
seq_len,
|
|
||||||
num_images,
|
|
||||||
image_token_id=hf_config.image_token_index,
|
|
||||||
image_feature_size_override=image_feature_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
mm_data = dummy_image_for_clip(vision_config, num_images)
|
|
||||||
return DummyData(seq_data, mm_data, ranges)
|
|
||||||
elif isinstance(vision_config, SiglipVisionConfig):
|
elif isinstance(vision_config, SiglipVisionConfig):
|
||||||
seq_data, ranges = dummy_seq_data_for_siglip(
|
data = dummy_image_for_siglip(vision_config, num_images)
|
||||||
vision_config,
|
|
||||||
seq_len,
|
|
||||||
num_images,
|
|
||||||
image_token_id=hf_config.image_token_index,
|
|
||||||
image_feature_size_override=image_feature_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
mm_data = dummy_image_for_siglip(vision_config, num_images)
|
|
||||||
return DummyData(seq_data, mm_data, ranges)
|
|
||||||
elif isinstance(vision_config, PixtralVisionConfig):
|
elif isinstance(vision_config, PixtralVisionConfig):
|
||||||
seq_data, ranges = dummy_seq_data_for_pixtral_hf(
|
data = dummy_image_for_pixtral_hf(vision_config, num_images)
|
||||||
vision_config,
|
|
||||||
seq_len,
|
|
||||||
num_images,
|
|
||||||
image_token_id=hf_config.image_token_index,
|
|
||||||
image_feature_size_override=image_feature_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
|
|
||||||
return DummyData(seq_data, mm_data, ranges)
|
|
||||||
|
|
||||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
|
|
||||||
multi_modal_data = inputs.get("multi_modal_data")
|
|
||||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
model_config = ctx.model_config
|
|
||||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
|
||||||
vision_config = hf_config.vision_config
|
|
||||||
|
|
||||||
image_data = multi_modal_data["image"]
|
|
||||||
if isinstance(image_data, Image.Image):
|
|
||||||
image_feature_size = get_max_llava_image_tokens(ctx)
|
|
||||||
elif is_list_of(image_data, Image.Image):
|
|
||||||
image_feature_size = [get_max_llava_image_tokens(ctx)
|
|
||||||
] * len(image_data)
|
|
||||||
elif isinstance(image_data, torch.Tensor):
|
|
||||||
num_images, image_feature_size, hidden_size = image_data.shape
|
|
||||||
elif is_list_of(image_data, torch.Tensor):
|
|
||||||
image_feature_size = [item.shape[1] for item in image_data]
|
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
if isinstance(vision_config, CLIPVisionConfig):
|
hf_processor = ctx.get_hf_processor()
|
||||||
return input_processor_for_clip(
|
image_processor = hf_processor.image_processor # type: ignore
|
||||||
model_config,
|
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
|
||||||
vision_config,
|
is_pixtral = isinstance(hf_processor, PixtralProcessor)
|
||||||
inputs,
|
|
||||||
image_token_id=hf_config.image_token_index,
|
|
||||||
image_feature_size_override=image_feature_size,
|
|
||||||
)
|
|
||||||
elif isinstance(vision_config, SiglipVisionConfig):
|
|
||||||
return input_processor_for_siglip(
|
|
||||||
model_config,
|
|
||||||
vision_config,
|
|
||||||
inputs,
|
|
||||||
image_token_id=hf_config.image_token_index,
|
|
||||||
image_feature_size_override=image_feature_size,
|
|
||||||
)
|
|
||||||
elif isinstance(vision_config, PixtralVisionConfig):
|
|
||||||
# We ignore image_feature_size_override since we have non-uniform
|
|
||||||
# image sizes for Pixtral
|
|
||||||
return input_processor_for_pixtral_hf(
|
|
||||||
model_config,
|
|
||||||
vision_config,
|
|
||||||
inputs,
|
|
||||||
image_token_id=hf_config.image_token_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
return MultiModalKwargs(
|
||||||
raise NotImplementedError(msg)
|
**hf_inputs,
|
||||||
|
is_pixtral=torch.tensor(is_pixtral),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_metadata_for_llava(
|
||||||
|
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
|
||||||
|
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||||
|
image_token_id = hf_config.image_token_index
|
||||||
|
|
||||||
|
def get_repl_count(
|
||||||
|
mm_items: list[Image],
|
||||||
|
hf_inputs: BatchFeature,
|
||||||
|
item_idx: int,
|
||||||
|
) -> int:
|
||||||
|
return get_max_llava_image_tokens(ctx)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"image":
|
||||||
|
ModalityProcessingMetadata(prompt_repls=[
|
||||||
|
PromptReplacement(target=[image_token_id],
|
||||||
|
repl_unit=[image_token_id],
|
||||||
|
repl_count=get_repl_count),
|
||||||
|
]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaProcessor(MultiModalProcessor):
|
||||||
|
|
||||||
|
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
|
||||||
|
if getattr(hf_processor, "__is_patched__", False):
|
||||||
|
return # Already patched
|
||||||
|
|
||||||
|
image_processor = hf_processor.image_processor # type: ignore
|
||||||
|
orig_preprocess = image_processor.preprocess
|
||||||
|
|
||||||
|
def preprocess(__self, *args, **kwargs):
|
||||||
|
hf_inputs = orig_preprocess(*args, **kwargs)
|
||||||
|
hf_inputs["is_pixtral"] = torch.tensor(True)
|
||||||
|
return hf_inputs
|
||||||
|
|
||||||
|
image_processor.preprocess = MethodType(preprocess, image_processor)
|
||||||
|
|
||||||
|
hf_processor.__is_patched__ = True # type: ignore
|
||||||
|
|
||||||
|
def _get_hf_processor(self) -> ProcessorMixin:
|
||||||
|
hf_processor = self.ctx.get_hf_processor()
|
||||||
|
|
||||||
|
if isinstance(hf_processor, PixtralProcessor):
|
||||||
|
self._patch_pixtral_processor(hf_processor)
|
||||||
|
|
||||||
|
return hf_processor
|
||||||
|
|
||||||
|
def _get_dummy_mm_kwargs(
|
||||||
|
self,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> MultiModalKwargs:
|
||||||
|
return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)
|
||||||
|
|
||||||
|
|
||||||
class LlavaLikeConfig(Protocol):
|
class LlavaLikeConfig(Protocol):
|
||||||
@ -291,10 +276,11 @@ def init_vision_tower_for_llava(
|
|||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
|
||||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
||||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
|
@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
|
||||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
|
ctx=ctx,
|
||||||
|
metadata=create_metadata_for_llava(ctx),
|
||||||
|
))
|
||||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||||
# BitandBytes specific attributes
|
# BitandBytes specific attributes
|
||||||
bitsandbytes_stacked_params_mapping = {
|
bitsandbytes_stacked_params_mapping = {
|
||||||
@ -367,38 +353,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _validate_image_sizes(self, images: List[torch.Tensor],
|
|
||||||
sizes: List[torch.Tensor]) -> List[torch.Tensor]:
|
|
||||||
if not isinstance(sizes, list):
|
|
||||||
sizes = [sizes]
|
|
||||||
|
|
||||||
total_images = sum(size.numel() // 2 for size in sizes)
|
|
||||||
if total_images != len(images):
|
|
||||||
raise ValueError("Mismatch in number of images. "
|
|
||||||
f"Expected {total_images}, got {len(images)}")
|
|
||||||
img_idx = 0
|
|
||||||
for size in sizes:
|
|
||||||
# Flatten the size tensor to a list of (height, width) pairs
|
|
||||||
size = size.view(-1, 2).tolist()
|
|
||||||
for expected_h, expected_w in size:
|
|
||||||
if img_idx >= len(images):
|
|
||||||
raise ValueError("Ran out of images before sizes. "
|
|
||||||
f"{img_idx} >= {len(images)}")
|
|
||||||
img = images[img_idx]
|
|
||||||
if img.shape[-2:] != (expected_h, expected_w):
|
|
||||||
raise ValueError(
|
|
||||||
"Image size mismatch. Expected "
|
|
||||||
f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
|
|
||||||
if img.shape[-3] != 3:
|
|
||||||
raise ValueError("Image channel mismatch. Expected 3, "
|
|
||||||
f"got {img.shape[-3]}")
|
|
||||||
img_idx += 1
|
|
||||||
return images
|
|
||||||
|
|
||||||
def _parse_and_validate_image_input(
|
def _parse_and_validate_image_input(
|
||||||
self, **kwargs: object) -> Optional[LlavaImageInputs]:
|
self, **kwargs: object) -> Optional[LlavaImageInputs]:
|
||||||
pixel_values = kwargs.pop("pixel_values", None)
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
image_sizes = kwargs.pop("image_sizes", None)
|
is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
|
||||||
image_embeds = kwargs.pop("image_embeds", None)
|
image_embeds = kwargs.pop("image_embeds", None)
|
||||||
|
|
||||||
if pixel_values is None and image_embeds is None:
|
if pixel_values is None and image_embeds is None:
|
||||||
@ -409,9 +367,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
raise ValueError("Incorrect type of pixel values. "
|
raise ValueError("Incorrect type of pixel values. "
|
||||||
f"Got type: {type(pixel_values)}")
|
f"Got type: {type(pixel_values)}")
|
||||||
|
|
||||||
# Case for models like PixtralHF that have dynamic image sizes
|
assert isinstance(is_pixtral, torch.Tensor)
|
||||||
# so we need to produce a list of tensors
|
if is_pixtral.any():
|
||||||
if image_sizes is not None:
|
|
||||||
images = pixel_values
|
images = pixel_values
|
||||||
|
|
||||||
def flatten_to_3d_tensors(item):
|
def flatten_to_3d_tensors(item):
|
||||||
@ -434,7 +391,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
|
|
||||||
return LlavaImagePixelInputs(
|
return LlavaImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
data=self._validate_image_sizes(images, image_sizes),
|
data=images,
|
||||||
)
|
)
|
||||||
|
|
||||||
return LlavaImagePixelInputs(
|
return LlavaImagePixelInputs(
|
||||||
|
|||||||
@ -226,16 +226,16 @@ class MultiModalPlugin(ABC):
|
|||||||
"""
|
"""
|
||||||
# Avoid circular import
|
# Avoid circular import
|
||||||
from vllm.model_executor.model_loader import get_model_architecture
|
from vllm.model_executor.model_loader import get_model_architecture
|
||||||
|
from vllm.model_executor.models import supports_multimodal
|
||||||
|
|
||||||
model_cls, _ = get_model_architecture(model_config)
|
model_cls, _ = get_model_architecture(model_config)
|
||||||
|
|
||||||
if model_cls not in self._input_mappers:
|
if not supports_multimodal(model_cls):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
max_mm_tokens = self._max_mm_tokens.get(model_cls)
|
max_mm_tokens = self._max_mm_tokens.get(model_cls)
|
||||||
if max_mm_tokens is None:
|
if max_mm_tokens is None:
|
||||||
raise KeyError(f"No maximum number of multi-modal tokens is given "
|
return 0
|
||||||
f"for model class {model_cls.__name__} in {self}.")
|
|
||||||
|
|
||||||
if callable(max_mm_tokens):
|
if callable(max_mm_tokens):
|
||||||
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
|
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
|
||||||
@ -326,26 +326,47 @@ class MultiModalPlaceholderMap:
|
|||||||
src_ranges = []
|
src_ranges = []
|
||||||
dest_ranges = []
|
dest_ranges = []
|
||||||
"""
|
"""
|
||||||
if (not seq_group.multi_modal_data
|
seq_mm_data = seq_group.multi_modal_data
|
||||||
or not seq_group.multi_modal_placeholders):
|
seq_mm_placeholders = seq_group.multi_modal_placeholders
|
||||||
return seq_group.multi_modal_data, {}
|
|
||||||
|
|
||||||
mm_data = {**seq_group.multi_modal_data}
|
if not seq_mm_data or not seq_mm_placeholders:
|
||||||
placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
|
return seq_mm_data, {}
|
||||||
|
|
||||||
|
# For merged processor, we directly use mm_kwargs as mm_data
|
||||||
|
if isinstance(seq_mm_data, MultiModalKwargs):
|
||||||
|
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
|
||||||
|
|
||||||
|
for modality, placeholders in seq_mm_placeholders.items():
|
||||||
|
placeholder_map = MultiModalPlaceholderMap()
|
||||||
|
|
||||||
|
if positions:
|
||||||
|
placeholder_map.append_items_from_seq_group(
|
||||||
|
positions,
|
||||||
|
# Dummy, since we don't care about intersecting items
|
||||||
|
[None] * len(placeholders),
|
||||||
|
placeholders,
|
||||||
|
)
|
||||||
|
|
||||||
|
placeholder_maps[modality] = placeholder_map
|
||||||
|
|
||||||
|
return seq_mm_data, placeholder_maps
|
||||||
|
|
||||||
|
mm_data = {**seq_mm_data}
|
||||||
|
placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
|
||||||
MultiModalPlaceholderMap)
|
MultiModalPlaceholderMap)
|
||||||
|
|
||||||
for (
|
for modality, placeholders in seq_mm_placeholders.items():
|
||||||
modality,
|
|
||||||
placeholders,
|
|
||||||
) in seq_group.multi_modal_placeholders.items():
|
|
||||||
mm_items = mm_data.pop(modality)
|
mm_items = mm_data.pop(modality)
|
||||||
if not isinstance(mm_items, list):
|
if not isinstance(mm_items, list):
|
||||||
mm_items = [mm_items]
|
mm_items = [mm_items]
|
||||||
|
|
||||||
if positions:
|
if positions:
|
||||||
intersecting_items = placeholder_maps[
|
intersecting_items = placeholder_maps[modality] \
|
||||||
modality].append_items_from_seq_group(
|
.append_items_from_seq_group(
|
||||||
positions, mm_items, placeholders)
|
positions,
|
||||||
|
mm_items,
|
||||||
|
placeholders,
|
||||||
|
)
|
||||||
|
|
||||||
if intersecting_items:
|
if intersecting_items:
|
||||||
mm_data[modality] = intersecting_items
|
mm_data[modality] = intersecting_items
|
||||||
|
|||||||
@ -3,14 +3,13 @@ from abc import ABC, abstractmethod
|
|||||||
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
|
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from itertools import groupby
|
|
||||||
from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
|
from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import torch
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature, ProcessorMixin
|
||||||
from typing_extensions import TypeAlias, TypedDict
|
from typing_extensions import TypeAlias, TypedDict
|
||||||
|
|
||||||
from vllm.inputs import InputProcessingContext
|
from vllm.inputs import DummyData, InputProcessingContext
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
|
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
|
||||||
|
|
||||||
@ -256,63 +255,6 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
|
|||||||
return multi_data
|
return multi_data
|
||||||
|
|
||||||
|
|
||||||
class _TokenRun(NamedTuple):
|
|
||||||
token_id: int
|
|
||||||
|
|
||||||
start_idx: int
|
|
||||||
length: int
|
|
||||||
|
|
||||||
|
|
||||||
def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
|
|
||||||
"""
|
|
||||||
Yield the starting index and length of each run of tokens that are the same.
|
|
||||||
"""
|
|
||||||
start_idx = 0
|
|
||||||
|
|
||||||
for token_id, it in groupby(token_ids):
|
|
||||||
length = sum(1 for _ in it)
|
|
||||||
yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length)
|
|
||||||
|
|
||||||
start_idx += length
|
|
||||||
|
|
||||||
|
|
||||||
class _PlaceholderInfo(NamedTuple):
|
|
||||||
modality: str
|
|
||||||
offset: int
|
|
||||||
length: int
|
|
||||||
|
|
||||||
def to_range(self) -> PlaceholderRange:
|
|
||||||
return PlaceholderRange(offset=self.offset, length=self.length)
|
|
||||||
|
|
||||||
|
|
||||||
def iter_placeholders(
|
|
||||||
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
|
|
||||||
token_ids: list[int],
|
|
||||||
*,
|
|
||||||
min_placeholder_count: int,
|
|
||||||
) -> Iterable[_PlaceholderInfo]:
|
|
||||||
"""Yield each set of placeholder tokens found in :code:`token_ids`."""
|
|
||||||
placeholder_ids_by_modality = {
|
|
||||||
modality: {
|
|
||||||
token_id
|
|
||||||
for prompt_repl in repls
|
|
||||||
for token_id in prompt_repl.repl_unit.token_ids
|
|
||||||
}
|
|
||||||
for modality, repls in full_groupby_modality(prompt_repls)
|
|
||||||
}
|
|
||||||
|
|
||||||
for run_info in iter_token_runs(token_ids):
|
|
||||||
if run_info.length > min_placeholder_count:
|
|
||||||
for (modality,
|
|
||||||
placeholder_ids) in placeholder_ids_by_modality.items():
|
|
||||||
if run_info.token_id in placeholder_ids:
|
|
||||||
yield _PlaceholderInfo(
|
|
||||||
modality=modality,
|
|
||||||
offset=run_info.start_idx,
|
|
||||||
length=run_info.length,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _TokenMatch(NamedTuple):
|
class _TokenMatch(NamedTuple):
|
||||||
start_idx: int
|
start_idx: int
|
||||||
end_idx: int
|
end_idx: int
|
||||||
@ -353,13 +295,9 @@ class _PromptReplacementMatch(ABC, Generic[_T, _S]):
|
|||||||
def end_idx(self) -> int:
|
def end_idx(self) -> int:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_repl(
|
def repl_unit(self) -> _S:
|
||||||
self,
|
|
||||||
mm_items: list[_T],
|
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
item_idx: int,
|
|
||||||
) -> _S:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@ -380,15 +318,9 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
|
|||||||
def end_idx(self) -> int:
|
def end_idx(self) -> int:
|
||||||
return self.match.end_idx
|
return self.match.end_idx
|
||||||
|
|
||||||
def get_repl(
|
@property
|
||||||
self,
|
def repl_unit(self) -> list[int]:
|
||||||
mm_items: list[_T],
|
return self.prompt_repl.repl_unit.token_ids
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
item_idx: int,
|
|
||||||
) -> list[int]:
|
|
||||||
prompt_repl = self.prompt_repl
|
|
||||||
count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
|
|
||||||
return prompt_repl.repl_unit.token_ids * count
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(repr=False)
|
@dataclass(repr=False)
|
||||||
@ -404,15 +336,26 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
|
|||||||
def end_idx(self) -> int:
|
def end_idx(self) -> int:
|
||||||
return self.match.end()
|
return self.match.end()
|
||||||
|
|
||||||
def get_repl(
|
@property
|
||||||
self,
|
def repl_unit(self) -> str:
|
||||||
mm_items: list[_T],
|
return self.prompt_repl.repl_unit.text
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
item_idx: int,
|
|
||||||
) -> str:
|
class _PlaceholderInfo(NamedTuple):
|
||||||
prompt_repl = self.prompt_repl
|
modality: str
|
||||||
count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
|
start_idx: int
|
||||||
return prompt_repl.repl_unit.text * count
|
unit: list[int]
|
||||||
|
unit_count: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def length(self) -> int:
|
||||||
|
return len(self.unit) * self.unit_count
|
||||||
|
|
||||||
|
def to_range(self) -> PlaceholderRange:
|
||||||
|
return PlaceholderRange(
|
||||||
|
offset=self.start_idx,
|
||||||
|
length=self.length,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def find_token_matches(
|
def find_token_matches(
|
||||||
@ -447,15 +390,17 @@ def _resolve_matches(
|
|||||||
Resolve :code:`matches` to ensure that there are no overlapping matches,
|
Resolve :code:`matches` to ensure that there are no overlapping matches,
|
||||||
and sort them such that earlier matches take priority over later ones.
|
and sort them such that earlier matches take priority over later ones.
|
||||||
"""
|
"""
|
||||||
num_matches_by_idx = np.zeros(len(prompt), dtype=int)
|
seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \
|
||||||
for match in matches:
|
= [None] * len(prompt)
|
||||||
num_matches_by_idx[match.start_idx:match.end_idx] += 1
|
|
||||||
|
|
||||||
duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1)
|
for match in matches:
|
||||||
if len(duplicate_matches_idxs) > 0:
|
for idx in range(match.start_idx, match.end_idx):
|
||||||
raise ValueError("Unable to find a unique replacement "
|
if seen_matches[idx] is not None:
|
||||||
f"at indices={duplicate_matches_idxs} "
|
raise ValueError("Found overlapping matches "
|
||||||
f"of prompt={prompt}")
|
f"({seen_matches[idx]} and {match}) "
|
||||||
|
f"at index={idx} of prompt={prompt}")
|
||||||
|
|
||||||
|
seen_matches[idx] = match
|
||||||
|
|
||||||
return sorted(matches, key=lambda x: x.start_idx)
|
return sorted(matches, key=lambda x: x.start_idx)
|
||||||
|
|
||||||
@ -480,9 +425,12 @@ def _replace_matches(
|
|||||||
|
|
||||||
start_idx = match.start_idx
|
start_idx = match.start_idx
|
||||||
end_idx = match.end_idx
|
end_idx = match.end_idx
|
||||||
repl_ids = match.get_repl(mm_items, hf_inputs, item_idx)
|
repl_unit = match.repl_unit
|
||||||
|
repl_info = match.prompt_repl
|
||||||
|
repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx)
|
||||||
|
|
||||||
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids)
|
out_seqs.append(prompt[prev_end_idx:start_idx] +
|
||||||
|
repl_unit * repl_count)
|
||||||
prev_end_idx = end_idx
|
prev_end_idx = end_idx
|
||||||
next_idx_by_modality[modality] += 1
|
next_idx_by_modality[modality] += 1
|
||||||
|
|
||||||
@ -531,7 +479,57 @@ def replace_text_matches(
|
|||||||
return "".join(texts)
|
return "".join(texts)
|
||||||
|
|
||||||
|
|
||||||
class MultiModalProcessor:
|
def _merge_placeholder_matches(
|
||||||
|
matches: Iterable[_PromptReplacementTokenMatch],
|
||||||
|
) -> Iterable[_PromptReplacementTokenMatch]:
|
||||||
|
current_match = None
|
||||||
|
|
||||||
|
for match in sorted(matches, key=lambda x: x.start_idx):
|
||||||
|
if current_match is None:
|
||||||
|
current_match = match
|
||||||
|
elif (current_match.prompt_repl == match.prompt_repl
|
||||||
|
and current_match.end_idx == match.start_idx):
|
||||||
|
current_match = _PromptReplacementTokenMatch(
|
||||||
|
current_match.prompt_repl,
|
||||||
|
match=_TokenMatch(current_match.start_idx, match.end_idx),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield current_match
|
||||||
|
current_match = match
|
||||||
|
|
||||||
|
if current_match is not None:
|
||||||
|
yield current_match
|
||||||
|
|
||||||
|
|
||||||
|
def iter_placeholders(
|
||||||
|
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
|
||||||
|
prompt: list[int],
|
||||||
|
*,
|
||||||
|
min_unit_count: int = 1,
|
||||||
|
) -> Iterable[_PlaceholderInfo]:
|
||||||
|
"""Yield each set of placeholder tokens found in :code:`token_ids`."""
|
||||||
|
if min_unit_count <= 0:
|
||||||
|
raise ValueError("`min_unit_count` must be a positive integer")
|
||||||
|
|
||||||
|
matches = (_PromptReplacementTokenMatch(prompt_repl, match)
|
||||||
|
for prompt_repl in prompt_repls
|
||||||
|
if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0
|
||||||
|
for match in iter_token_matches(prompt, repl_unit))
|
||||||
|
|
||||||
|
for match in _merge_placeholder_matches(matches):
|
||||||
|
unit = match.repl_unit
|
||||||
|
placeholder = _PlaceholderInfo(
|
||||||
|
modality=match.modality,
|
||||||
|
start_idx=match.start_idx,
|
||||||
|
unit=unit,
|
||||||
|
unit_count=(match.end_idx - match.start_idx) // len(unit),
|
||||||
|
)
|
||||||
|
|
||||||
|
if placeholder.unit_count >= min_unit_count:
|
||||||
|
yield placeholder
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalProcessor(ABC):
|
||||||
"""
|
"""
|
||||||
Helper class to process multi-modal inputs to be used in vLLM.
|
Helper class to process multi-modal inputs to be used in vLLM.
|
||||||
"""
|
"""
|
||||||
@ -546,6 +544,12 @@ class MultiModalProcessor:
|
|||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
|
|
||||||
|
def _get_hf_processor(self) -> ProcessorMixin:
|
||||||
|
return self.ctx.get_hf_processor()
|
||||||
|
|
||||||
|
def _get_tokenizer(self) -> AnyTokenizer:
|
||||||
|
return self.ctx.tokenizer
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -562,13 +566,13 @@ class MultiModalProcessor:
|
|||||||
# To avoid false positives from multi-input when detecting
|
# To avoid false positives from multi-input when detecting
|
||||||
# whether placeholder tokens have been inserted, in case
|
# whether placeholder tokens have been inserted, in case
|
||||||
# the target sequence is a subset of the replacement tokens
|
# the target sequence is a subset of the replacement tokens
|
||||||
min_placeholder_count: int = 16,
|
min_unit_count: int = 16,
|
||||||
) -> list[_PlaceholderInfo]:
|
) -> list[_PlaceholderInfo]:
|
||||||
return list(
|
return list(
|
||||||
iter_placeholders(
|
iter_placeholders(
|
||||||
all_prompt_repls,
|
all_prompt_repls,
|
||||||
new_token_ids,
|
new_token_ids,
|
||||||
min_placeholder_count=min_placeholder_count,
|
min_unit_count=min_unit_count,
|
||||||
))
|
))
|
||||||
|
|
||||||
def _apply_hf_processor(
|
def _apply_hf_processor(
|
||||||
@ -577,19 +581,49 @@ class MultiModalProcessor:
|
|||||||
mm_data: MultiModalDataDict,
|
mm_data: MultiModalDataDict,
|
||||||
mm_processor_kwargs: Mapping[str, object],
|
mm_processor_kwargs: Mapping[str, object],
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
hf_processor = self.ctx.get_hf_processor()
|
hf_processor = self._get_hf_processor()
|
||||||
|
|
||||||
return hf_processor(
|
processor_data = dict[str, Any]()
|
||||||
text=prompt, # type: ignore
|
passthrough_data = dict[str, Any]()
|
||||||
**mm_data,
|
for k, v in mm_data.items():
|
||||||
**mm_processor_kwargs,
|
# TODO: Make a separate modality for embedding inputs
|
||||||
)
|
# to avoid confusion
|
||||||
|
if k in ("image", "video", "audio"):
|
||||||
|
if isinstance(v, torch.Tensor) and v.ndim == 3:
|
||||||
|
# Pass through embedding inputs (single)
|
||||||
|
passthrough_data[f"{k}_embeds"] = [v]
|
||||||
|
elif is_list_of(v, torch.Tensor) and v[0].ndim == 2:
|
||||||
|
# Pass through embedding inputs (multi)
|
||||||
|
passthrough_data[f"{k}_embeds"] = v
|
||||||
|
else:
|
||||||
|
# Map keys to plural form, e.g.: image -> images
|
||||||
|
processor_data[f"{k}s"] = v
|
||||||
|
else:
|
||||||
|
processor_data[k] = v
|
||||||
|
|
||||||
|
try:
|
||||||
|
hf_inputs = hf_processor(
|
||||||
|
text=prompt, # type: ignore
|
||||||
|
**processor_data,
|
||||||
|
**mm_processor_kwargs,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
data = dict(text=prompt, **processor_data)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to apply {type(hf_processor).__name__} "
|
||||||
|
f"on data={data} with kwargs={mm_processor_kwargs}") from exc
|
||||||
|
|
||||||
|
hf_inputs.update(passthrough_data)
|
||||||
|
|
||||||
|
return hf_inputs
|
||||||
|
|
||||||
def _bind_prompt_replacements(
|
def _bind_prompt_replacements(
|
||||||
self,
|
self,
|
||||||
mm_data: MultiModalDataDict,
|
mm_data: MultiModalDataDict,
|
||||||
) -> list[_BoundPromptReplacement[Any]]:
|
) -> list[_BoundPromptReplacement[Any]]:
|
||||||
tokenizer = self.ctx.tokenizer
|
tokenizer = self._get_tokenizer()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
prompt_repl.bind(modality, tokenizer)
|
prompt_repl.bind(modality, tokenizer)
|
||||||
@ -604,7 +638,7 @@ class MultiModalProcessor:
|
|||||||
token_ids: list[int],
|
token_ids: list[int],
|
||||||
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
|
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
|
||||||
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
|
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
|
||||||
tokenizer = self.ctx.tokenizer
|
tokenizer = self._get_tokenizer()
|
||||||
|
|
||||||
mm_items = to_multi_format(mm_data)
|
mm_items = to_multi_format(mm_data)
|
||||||
token_matches = find_token_matches(token_ids, prompt_repls)
|
token_matches = find_token_matches(token_ids, prompt_repls)
|
||||||
@ -620,7 +654,7 @@ class MultiModalProcessor:
|
|||||||
# of the search text in the prompt, we instead perform string
|
# of the search text in the prompt, we instead perform string
|
||||||
# replacement on the decoded token IDs, then encode them back.
|
# replacement on the decoded token IDs, then encode them back.
|
||||||
if all(
|
if all(
|
||||||
len(matches) >= len(mm_data[modality])
|
len(matches) >= len(mm_items[modality])
|
||||||
for modality, matches in full_groupby_modality(token_matches)
|
for modality, matches in full_groupby_modality(token_matches)
|
||||||
): # yapf: disable
|
): # yapf: disable
|
||||||
token_ids = replace_token_matches(
|
token_ids = replace_token_matches(
|
||||||
@ -648,15 +682,6 @@ class MultiModalProcessor:
|
|||||||
|
|
||||||
placeholders = self._find_placeholders(matched_repls, token_ids)
|
placeholders = self._find_placeholders(matched_repls, token_ids)
|
||||||
|
|
||||||
# Sanity check
|
|
||||||
assert len(placeholders) == len(matched_repls), dict(
|
|
||||||
# Log this information for easier debugging
|
|
||||||
text=text,
|
|
||||||
token_ids=token_ids,
|
|
||||||
placeholders=placeholders,
|
|
||||||
matched_repls=matched_repls,
|
|
||||||
)
|
|
||||||
|
|
||||||
return token_ids, text, placeholders
|
return token_ids, text, placeholders
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
@ -678,7 +703,7 @@ class MultiModalProcessor:
|
|||||||
3. Extract information about the placeholder tokens from the
|
3. Extract information about the placeholder tokens from the
|
||||||
processed token IDs.
|
processed token IDs.
|
||||||
"""
|
"""
|
||||||
tokenizer = self.ctx.tokenizer
|
tokenizer = self._get_tokenizer()
|
||||||
|
|
||||||
hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
|
hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
|
||||||
mm_processor_kwargs)
|
mm_processor_kwargs)
|
||||||
@ -717,3 +742,59 @@ class MultiModalProcessor:
|
|||||||
mm_kwargs=mm_kwargs,
|
mm_kwargs=mm_kwargs,
|
||||||
mm_placeholders=mm_placeholders,
|
mm_placeholders=mm_placeholders,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_dummy_mm_kwargs(
|
||||||
|
self,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> MultiModalKwargs:
|
||||||
|
"""
|
||||||
|
Build the input that corresponds to `mm_max_tokens` in
|
||||||
|
:meth:`get_dummy_data`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_dummy_data(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
mm_max_tokens: Mapping[str, int],
|
||||||
|
) -> DummyData:
|
||||||
|
# Avoid circular import
|
||||||
|
from vllm.sequence import SequenceData
|
||||||
|
|
||||||
|
tokenizer = self._get_tokenizer()
|
||||||
|
|
||||||
|
mm_placeholders = dict[str, _PlaceholderInfo]()
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
for modality, max_tokens in mm_max_tokens.items():
|
||||||
|
if max_tokens == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata = self.metadata[modality]
|
||||||
|
repl = metadata.prompt_repls[0].bind(modality, tokenizer)
|
||||||
|
repl_token_ids = repl.repl_unit.token_ids
|
||||||
|
|
||||||
|
placeholders = _PlaceholderInfo(
|
||||||
|
modality=modality,
|
||||||
|
start_idx=offset,
|
||||||
|
unit=repl_token_ids,
|
||||||
|
unit_count=max_tokens // len(repl_token_ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
mm_placeholders[modality] = placeholders
|
||||||
|
offset += placeholders.length
|
||||||
|
|
||||||
|
prompt_token_ids = flatten_2d_lists(
|
||||||
|
[p.unit * p.unit_count for p in mm_placeholders.values()])
|
||||||
|
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
|
||||||
|
|
||||||
|
return DummyData(
|
||||||
|
seq_data=SequenceData.from_seqs(prompt_token_ids),
|
||||||
|
multi_modal_data=self._get_dummy_mm_kwargs(mm_counts),
|
||||||
|
multi_modal_placeholders={
|
||||||
|
modality: [p.to_range()]
|
||||||
|
for modality, p in mm_placeholders.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from .audio import AudioPlugin
|
|||||||
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
|
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
|
||||||
from .image import ImagePlugin
|
from .image import ImagePlugin
|
||||||
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
|
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
|
||||||
from .processing import MultiModalProcessor
|
from .processing import MultiModalProcessingMetadata, MultiModalProcessor
|
||||||
from .video import VideoPlugin
|
from .video import VideoPlugin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -200,6 +200,27 @@ class MultiModalRegistry:
|
|||||||
"""
|
"""
|
||||||
return self.register_max_multimodal_tokens("image", max_mm_tokens)
|
return self.register_max_multimodal_tokens("image", max_mm_tokens)
|
||||||
|
|
||||||
|
def get_max_tokens_by_modality(
|
||||||
|
self,
|
||||||
|
model_config: "ModelConfig",
|
||||||
|
) -> Mapping[str, int]:
|
||||||
|
"""
|
||||||
|
Get the maximum number of tokens from each modality
|
||||||
|
for profiling the memory usage of a model.
|
||||||
|
|
||||||
|
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This should be called after :meth:`init_mm_limits_per_prompt`.
|
||||||
|
"""
|
||||||
|
limits_per_plugin = self._limits_by_model[model_config]
|
||||||
|
|
||||||
|
return {
|
||||||
|
key: (limits_per_plugin[key] *
|
||||||
|
plugin.get_max_multimodal_tokens(model_config))
|
||||||
|
for key, plugin in self._plugins.items()
|
||||||
|
}
|
||||||
|
|
||||||
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
|
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
|
||||||
"""
|
"""
|
||||||
Get the maximum number of multi-modal tokens
|
Get the maximum number of multi-modal tokens
|
||||||
@ -210,11 +231,7 @@ class MultiModalRegistry:
|
|||||||
Note:
|
Note:
|
||||||
This should be called after :meth:`init_mm_limits_per_prompt`.
|
This should be called after :meth:`init_mm_limits_per_prompt`.
|
||||||
"""
|
"""
|
||||||
limits_per_plugin = self._limits_by_model[model_config]
|
return sum(self.get_max_tokens_by_modality(model_config).values())
|
||||||
|
|
||||||
return sum((limits_per_plugin[key] *
|
|
||||||
plugin.get_max_multimodal_tokens(model_config))
|
|
||||||
for key, plugin in self._plugins.items())
|
|
||||||
|
|
||||||
def init_mm_limits_per_prompt(
|
def init_mm_limits_per_prompt(
|
||||||
self,
|
self,
|
||||||
@ -270,7 +287,8 @@ class MultiModalRegistry:
|
|||||||
factory: MultiModalProcessorFactory,
|
factory: MultiModalProcessorFactory,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Register a multi-modal processor to a model class.
|
Register a multi-modal processor to a model class. The processor
|
||||||
|
is constructed lazily, hence a factory method should be passed.
|
||||||
|
|
||||||
When the model receives multi-modal data, the provided function is
|
When the model receives multi-modal data, the provided function is
|
||||||
invoked to transform the data into a dictionary of model inputs.
|
invoked to transform the data into a dictionary of model inputs.
|
||||||
@ -293,6 +311,41 @@ class MultiModalRegistry:
|
|||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
def register_processor_by_metadata(
|
||||||
|
self,
|
||||||
|
metadata_factory: Callable[[InputProcessingContext],
|
||||||
|
MultiModalProcessingMetadata],
|
||||||
|
get_dummy_mm_kwargs: Callable[
|
||||||
|
[InputProcessingContext, Mapping[str, int]], MultiModalKwargs],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Convenience method to register a multi-modal processor to a model class
|
||||||
|
according to a function that constructs its metadata.
|
||||||
|
|
||||||
|
When the model receives multi-modal data, the provided function is
|
||||||
|
invoked to transform the data into a dictionary of model inputs.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :ref:`input_processing_pipeline`
|
||||||
|
- :ref:`enabling_multimodal_inputs`
|
||||||
|
"""
|
||||||
|
|
||||||
|
class ConcreteMultiModalProcessor(MultiModalProcessor):
|
||||||
|
|
||||||
|
def _get_dummy_mm_kwargs(
|
||||||
|
self,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> MultiModalKwargs:
|
||||||
|
return get_dummy_mm_kwargs(self.ctx, mm_counts)
|
||||||
|
|
||||||
|
def factory(ctx: InputProcessingContext):
|
||||||
|
return ConcreteMultiModalProcessor(
|
||||||
|
ctx=ctx,
|
||||||
|
metadata=metadata_factory(ctx),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.register_processor(factory)
|
||||||
|
|
||||||
def has_processor(self, model_config: "ModelConfig") -> bool:
|
def has_processor(self, model_config: "ModelConfig") -> bool:
|
||||||
"""
|
"""
|
||||||
Test whether a multi-modal processor is defined for a specific model.
|
Test whether a multi-modal processor is defined for a specific model.
|
||||||
|
|||||||
@ -12,6 +12,7 @@ class MMInputMapper:
|
|||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||||
):
|
):
|
||||||
|
self.model_config = model_config
|
||||||
self.mm_registry = mm_registry
|
self.mm_registry = mm_registry
|
||||||
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
|
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
|
||||||
model_config)
|
model_config)
|
||||||
|
|||||||
@ -7,7 +7,8 @@ from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
|||||||
from vllm.inputs.parse import is_encoder_decoder_inputs
|
from vllm.inputs.parse import is_encoder_decoder_inputs
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||||
|
MultiModalRegistry)
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
@ -101,10 +102,15 @@ class Processor:
|
|||||||
self.generation_config_fields, eos_token_id)
|
self.generation_config_fields, eos_token_id)
|
||||||
|
|
||||||
# Preprocess multi-modal data
|
# Preprocess multi-modal data
|
||||||
mm_inputs = self.mm_input_mapper.process_inputs(
|
if len(decoder_inputs.multi_modal_data) == 0:
|
||||||
decoder_inputs.multi_modal_data,
|
mm_inputs = None
|
||||||
decoder_inputs.mm_processor_kwargs) if len(
|
elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
|
||||||
decoder_inputs.multi_modal_data) > 0 else None
|
mm_inputs = [decoder_inputs.multi_modal_data]
|
||||||
|
else:
|
||||||
|
mm_inputs = self.mm_input_mapper.process_inputs(
|
||||||
|
decoder_inputs.multi_modal_data,
|
||||||
|
decoder_inputs.mm_processor_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Make Request for Detokenizer.
|
# Make Request for Detokenizer.
|
||||||
detokenizer_request = DetokenizerRequest(
|
detokenizer_request = DetokenizerRequest(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user