mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 23:55:46 +08:00
[Bugfix] Fix profiling dummy data for Pixtral (#18677)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
3a886bd58c
commit
57fd13a707
@ -9,15 +9,15 @@ from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
|||||||
UserMessage)
|
UserMessage)
|
||||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.inputs import InputProcessingContext
|
from vllm.inputs import InputProcessingContext
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||||
from vllm.multimodal.inputs import MultiModalInputs
|
from vllm.multimodal.inputs import MultiModalInputs
|
||||||
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
|
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
|
||||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||||
cached_tokenizer_from_config)
|
cached_tokenizer_from_config,
|
||||||
|
encode_tokens)
|
||||||
|
|
||||||
from ....multimodal.utils import random_audio, random_image, random_video
|
from ....multimodal.utils import random_audio, random_image, random_video
|
||||||
from ...registry import HF_EXAMPLE_MODELS
|
from ...registry import HF_EXAMPLE_MODELS
|
||||||
@ -28,7 +28,6 @@ def _test_processing_correctness(
|
|||||||
hit_rate: float,
|
hit_rate: float,
|
||||||
num_batches: int,
|
num_batches: int,
|
||||||
simplify_rate: float,
|
simplify_rate: float,
|
||||||
ignore_mm_keys: Optional[set[str]] = None,
|
|
||||||
):
|
):
|
||||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||||
model_info.check_available_online(on_fail="skip")
|
model_info.check_available_online(on_fail="skip")
|
||||||
@ -99,10 +98,23 @@ def _test_processing_correctness(
|
|||||||
}
|
}
|
||||||
|
|
||||||
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
||||||
prompt = dummy_inputs.get_dummy_processor_inputs(
|
|
||||||
model_config.max_model_len,
|
# Mistral chat outputs tokens directly, rather than text prompts
|
||||||
mm_counts,
|
if isinstance(tokenizer, MistralTokenizer):
|
||||||
).prompt_text
|
images = mm_data.get("image", [])
|
||||||
|
request = ChatCompletionRequest(messages=[
|
||||||
|
UserMessage(content=[
|
||||||
|
TextChunk(text=""),
|
||||||
|
*(ImageChunk(image=image) for image in images),
|
||||||
|
]),
|
||||||
|
])
|
||||||
|
res = tokenizer.mistral.encode_chat_completion(request)
|
||||||
|
prompt = res.tokens
|
||||||
|
else:
|
||||||
|
prompt = dummy_inputs.get_dummy_processor_inputs(
|
||||||
|
model_config.max_model_len,
|
||||||
|
mm_counts,
|
||||||
|
).prompt
|
||||||
|
|
||||||
# Drop unnecessary keys and test single -> multi conversion
|
# Drop unnecessary keys and test single -> multi conversion
|
||||||
if rng.rand() < simplify_rate:
|
if rng.rand() < simplify_rate:
|
||||||
@ -112,67 +124,59 @@ def _test_processing_correctness(
|
|||||||
elif len(mm_data[k]) == 1:
|
elif len(mm_data[k]) == 1:
|
||||||
mm_data[k] = mm_data[k][0]
|
mm_data[k] = mm_data[k][0]
|
||||||
|
|
||||||
if isinstance(tokenizer, MistralTokenizer):
|
_test_processing_correctness_one(
|
||||||
_test_processing_correctness_mistral(
|
model_config,
|
||||||
model_config,
|
tokenizer,
|
||||||
tokenizer,
|
prompt,
|
||||||
prompt,
|
mm_data,
|
||||||
mm_data,
|
baseline_processor,
|
||||||
baseline_processor,
|
cached_processor,
|
||||||
cached_processor,
|
batch_idx,
|
||||||
batch_idx,
|
)
|
||||||
ignore_mm_keys=ignore_mm_keys,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_test_processing_correctness_hf(
|
|
||||||
model_config,
|
|
||||||
tokenizer,
|
|
||||||
prompt,
|
|
||||||
mm_data,
|
|
||||||
baseline_processor,
|
|
||||||
cached_processor,
|
|
||||||
batch_idx,
|
|
||||||
ignore_mm_keys=ignore_mm_keys,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _test_processing_correctness_hf(
|
# For some multimodal models, tokenizer will always add bos_token
|
||||||
|
# at the beginning of prompt by default, causing hf_processor outputs
|
||||||
|
# incorrect token ids. So we need use `add_special_tokens=False` here
|
||||||
|
# to leave bos_token to be added by the processor.
|
||||||
|
_ADD_SPECIAL_TOKENS_OVERRIDES = {
|
||||||
|
"mllama": False,
|
||||||
|
"ovis": False,
|
||||||
|
"ultravox": False,
|
||||||
|
"whisper": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
_IGNORE_MM_KEYS = {
|
||||||
|
# In Ultravox, the audio_features can be different depending on padding
|
||||||
|
# The slight difference should not be a problem though, since
|
||||||
|
# attention_mask lets us ignore the difference.
|
||||||
|
"ultravox": {"audio_features"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _test_processing_correctness_one(
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
tokenizer: AnyTokenizer,
|
||||||
prompt: str,
|
prompt: Union[str, list[int]],
|
||||||
mm_data: MultiModalDataDict,
|
mm_data: MultiModalDataDict,
|
||||||
baseline_processor: BaseMultiModalProcessor,
|
baseline_processor: BaseMultiModalProcessor,
|
||||||
cached_processor: BaseMultiModalProcessor,
|
cached_processor: BaseMultiModalProcessor,
|
||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
ignore_mm_keys: Optional[set[str]] = None,
|
|
||||||
):
|
):
|
||||||
if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
|
model_type = model_config.hf_config.model_type
|
||||||
"whisper"):
|
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
|
||||||
# For some multimodal models, tokenizer will always add bos_token
|
|
||||||
# at the beginning of prompt by default, causing hf_processor outputs
|
if isinstance(prompt, str):
|
||||||
# incorrect token ids. So we need use `add_special_tokens=False` here
|
text_prompt = prompt
|
||||||
# to leave bos_token to be added by the processor.
|
token_prompt = encode_tokens(
|
||||||
token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
tokenizer,
|
||||||
|
prompt,
|
||||||
|
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
token_prompt = tokenizer.encode(prompt)
|
# Mistral does not support decode_tokens with skip_special_tokens=False
|
||||||
|
text_prompt = None
|
||||||
baseline_result = baseline_processor.apply(
|
token_prompt = prompt
|
||||||
prompt,
|
|
||||||
mm_data=mm_data,
|
|
||||||
hf_processor_mm_kwargs={},
|
|
||||||
)
|
|
||||||
cached_result = cached_processor.apply(
|
|
||||||
prompt,
|
|
||||||
mm_data=mm_data,
|
|
||||||
hf_processor_mm_kwargs={},
|
|
||||||
)
|
|
||||||
|
|
||||||
_assert_inputs_equal(
|
|
||||||
baseline_result,
|
|
||||||
cached_result,
|
|
||||||
ignore_mm_keys=ignore_mm_keys,
|
|
||||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
|
||||||
)
|
|
||||||
|
|
||||||
baseline_tokenized_result = baseline_processor.apply(
|
baseline_tokenized_result = baseline_processor.apply(
|
||||||
token_prompt,
|
token_prompt,
|
||||||
@ -180,56 +184,6 @@ def _test_processing_correctness_hf(
|
|||||||
hf_processor_mm_kwargs={},
|
hf_processor_mm_kwargs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
_assert_inputs_equal(
|
|
||||||
baseline_result,
|
|
||||||
baseline_tokenized_result,
|
|
||||||
ignore_mm_keys=ignore_mm_keys,
|
|
||||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
|
||||||
)
|
|
||||||
|
|
||||||
cached_tokenized_result = cached_processor.apply(
|
|
||||||
token_prompt,
|
|
||||||
mm_data=mm_data,
|
|
||||||
hf_processor_mm_kwargs={},
|
|
||||||
)
|
|
||||||
|
|
||||||
_assert_inputs_equal(
|
|
||||||
cached_result,
|
|
||||||
cached_tokenized_result,
|
|
||||||
ignore_mm_keys=ignore_mm_keys,
|
|
||||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _test_processing_correctness_mistral(
|
|
||||||
model_config: ModelConfig,
|
|
||||||
tokenizer: MistralTokenizer,
|
|
||||||
prompt: str,
|
|
||||||
mm_data: MultiModalDataDict,
|
|
||||||
baseline_processor: BaseMultiModalProcessor,
|
|
||||||
cached_processor: BaseMultiModalProcessor,
|
|
||||||
batch_idx: int,
|
|
||||||
ignore_mm_keys: Optional[set[str]] = None,
|
|
||||||
):
|
|
||||||
images = mm_data.get("image", [])
|
|
||||||
if not isinstance(images, list):
|
|
||||||
images = [images]
|
|
||||||
|
|
||||||
request = ChatCompletionRequest(messages=[
|
|
||||||
UserMessage(content=[
|
|
||||||
TextChunk(text=prompt),
|
|
||||||
*(ImageChunk(image=image) for image in images),
|
|
||||||
]),
|
|
||||||
])
|
|
||||||
res = tokenizer.mistral.encode_chat_completion(request)
|
|
||||||
token_prompt = res.tokens
|
|
||||||
|
|
||||||
# Mistral chat outputs tokens directly, rather than text prompts
|
|
||||||
baseline_tokenized_result = baseline_processor.apply(
|
|
||||||
token_prompt,
|
|
||||||
mm_data=mm_data,
|
|
||||||
hf_processor_mm_kwargs={},
|
|
||||||
)
|
|
||||||
cached_tokenized_result = cached_processor.apply(
|
cached_tokenized_result = cached_processor.apply(
|
||||||
token_prompt,
|
token_prompt,
|
||||||
mm_data=mm_data,
|
mm_data=mm_data,
|
||||||
@ -240,9 +194,44 @@ def _test_processing_correctness_mistral(
|
|||||||
baseline_tokenized_result,
|
baseline_tokenized_result,
|
||||||
cached_tokenized_result,
|
cached_tokenized_result,
|
||||||
ignore_mm_keys=ignore_mm_keys,
|
ignore_mm_keys=ignore_mm_keys,
|
||||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if text_prompt is not None:
|
||||||
|
baseline_text_result = baseline_processor.apply(
|
||||||
|
text_prompt,
|
||||||
|
mm_data=mm_data,
|
||||||
|
hf_processor_mm_kwargs={},
|
||||||
|
)
|
||||||
|
cached_text_result = cached_processor.apply(
|
||||||
|
text_prompt,
|
||||||
|
mm_data=mm_data,
|
||||||
|
hf_processor_mm_kwargs={},
|
||||||
|
)
|
||||||
|
|
||||||
|
_assert_inputs_equal(
|
||||||
|
baseline_text_result,
|
||||||
|
cached_text_result,
|
||||||
|
ignore_mm_keys=ignore_mm_keys,
|
||||||
|
msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
|
||||||
|
)
|
||||||
|
|
||||||
|
_assert_inputs_equal(
|
||||||
|
baseline_text_result,
|
||||||
|
baseline_tokenized_result,
|
||||||
|
ignore_mm_keys=ignore_mm_keys,
|
||||||
|
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
|
||||||
|
f"{token_prompt=}, {mm_data=})",
|
||||||
|
)
|
||||||
|
|
||||||
|
_assert_inputs_equal(
|
||||||
|
cached_text_result,
|
||||||
|
cached_tokenized_result,
|
||||||
|
ignore_mm_keys=ignore_mm_keys,
|
||||||
|
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
|
||||||
|
f"{token_prompt=}, {mm_data=})",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@pytest.mark.parametrize("model_id", [
|
@pytest.mark.parametrize("model_id", [
|
||||||
@ -281,6 +270,7 @@ def _test_processing_correctness_mistral(
|
|||||||
"AIDC-AI/Ovis2-1B",
|
"AIDC-AI/Ovis2-1B",
|
||||||
"google/paligemma-3b-mix-224",
|
"google/paligemma-3b-mix-224",
|
||||||
"google/paligemma2-3b-ft-docci-448",
|
"google/paligemma2-3b-ft-docci-448",
|
||||||
|
"microsoft/Phi-3.5-vision-instruct",
|
||||||
"microsoft/Phi-4-multimodal-instruct",
|
"microsoft/Phi-4-multimodal-instruct",
|
||||||
"mistralai/Pixtral-12B-2409",
|
"mistralai/Pixtral-12B-2409",
|
||||||
"mistral-community/pixtral-12b",
|
"mistral-community/pixtral-12b",
|
||||||
@ -303,41 +293,6 @@ def test_processing_correctness(
|
|||||||
num_batches: int,
|
num_batches: int,
|
||||||
simplify_rate: float,
|
simplify_rate: float,
|
||||||
):
|
):
|
||||||
ignore_mm_keys = None
|
|
||||||
if 'ultravox' in model_id:
|
|
||||||
# In Ultravox, the audio_features can be different depending on padding
|
|
||||||
# The slight difference should not be a problem though, since
|
|
||||||
# attention_mask lets us ignore the difference.
|
|
||||||
ignore_mm_keys = {"audio_features"}
|
|
||||||
|
|
||||||
_test_processing_correctness(
|
|
||||||
model_id,
|
|
||||||
hit_rate=hit_rate,
|
|
||||||
num_batches=num_batches,
|
|
||||||
simplify_rate=simplify_rate,
|
|
||||||
ignore_mm_keys=ignore_mm_keys,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# yapf: disable
|
|
||||||
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
|
|
||||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
|
||||||
@pytest.mark.parametrize("num_batches", [32])
|
|
||||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
|
||||||
# yapf: enable
|
|
||||||
def test_processing_correctness_phi3v(
|
|
||||||
model_id: str,
|
|
||||||
hit_rate: float,
|
|
||||||
num_batches: int,
|
|
||||||
simplify_rate: float,
|
|
||||||
):
|
|
||||||
# HACK - this is an attempted workaround for the following bug
|
|
||||||
# https://github.com/huggingface/transformers/issues/34307
|
|
||||||
from transformers import AutoImageProcessor # noqa: F401
|
|
||||||
from transformers import AutoProcessor # noqa: F401
|
|
||||||
|
|
||||||
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
|
|
||||||
|
|
||||||
_test_processing_correctness(
|
_test_processing_correctness(
|
||||||
model_id,
|
model_id,
|
||||||
hit_rate=hit_rate,
|
hit_rate=hit_rate,
|
||||||
@ -356,16 +311,10 @@ def _assert_inputs_equal(
|
|||||||
if ignore_mm_keys is None:
|
if ignore_mm_keys is None:
|
||||||
ignore_mm_keys = set()
|
ignore_mm_keys = set()
|
||||||
|
|
||||||
if msg is None:
|
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
|
||||||
assert "mm_kwargs" in a and "mm_kwargs" in b
|
|
||||||
else:
|
|
||||||
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
|
|
||||||
|
|
||||||
for key in ignore_mm_keys:
|
for key in ignore_mm_keys:
|
||||||
a["mm_kwargs"].pop(key, None)
|
a["mm_kwargs"].pop(key, None)
|
||||||
b["mm_kwargs"].pop(key, None)
|
b["mm_kwargs"].pop(key, None)
|
||||||
|
|
||||||
if msg is None:
|
assert a == b, msg
|
||||||
assert a == b
|
|
||||||
else:
|
|
||||||
assert a == b, msg
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ def test_profiling(
|
|||||||
] * max_num_seqs
|
] * max_num_seqs
|
||||||
|
|
||||||
mm_kwargs = processor.apply(
|
mm_kwargs = processor.apply(
|
||||||
prompt=dummy_mm_data.prompt_text,
|
prompt=dummy_mm_data.prompt,
|
||||||
mm_data=dummy_mm_data.mm_data,
|
mm_data=dummy_mm_data.mm_data,
|
||||||
hf_processor_mm_kwargs=dict(),
|
hf_processor_mm_kwargs=dict(),
|
||||||
)["mm_kwargs"]
|
)["mm_kwargs"]
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import pytest
|
|||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||||
|
|
||||||
|
from vllm.config import TokenizerMode
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class _HfExamplesInfo:
|
class _HfExamplesInfo:
|
||||||
@ -20,7 +22,7 @@ class _HfExamplesInfo:
|
|||||||
tokenizer: Optional[str] = None
|
tokenizer: Optional[str] = None
|
||||||
"""Set the tokenizer to load for this architecture."""
|
"""Set the tokenizer to load for this architecture."""
|
||||||
|
|
||||||
tokenizer_mode: str = "auto"
|
tokenizer_mode: TokenizerMode = "auto"
|
||||||
"""Set the tokenizer type for this architecture."""
|
"""Set the tokenizer type for this architecture."""
|
||||||
|
|
||||||
speculative_model: Optional[str] = None
|
speculative_model: Optional[str] = None
|
||||||
@ -388,8 +390,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
|
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
|
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
|
||||||
tokenizer_mode="mistral",
|
tokenizer_mode="mistral"),
|
||||||
v0_only=True),
|
|
||||||
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
|
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
|
||||||
extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501
|
extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
@ -400,7 +401,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B",
|
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B",
|
||||||
min_transformers_version="4.52"),
|
min_transformers_version="4.52"),
|
||||||
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ", # noqa: E501
|
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ", # noqa: E501
|
||||||
min_transformers_version="4.52"),
|
min_transformers_version="4.52"),
|
||||||
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
|
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
|
||||||
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
|
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
|
||||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
||||||
|
|||||||
@ -9,7 +9,9 @@ from typing import Literal, Optional, TypedDict, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from mistral_common.protocol.instruct.messages import ImageChunk
|
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
||||||
|
UserMessage)
|
||||||
|
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||||
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
|
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import PixtralVisionConfig, TensorType
|
from transformers import PixtralVisionConfig, TensorType
|
||||||
@ -39,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
BaseProcessingInfo, MultiModalHashes,
|
BaseProcessingInfo, MultiModalHashes,
|
||||||
PromptReplacement, PromptUpdate,
|
PromptReplacement, PromptUpdate,
|
||||||
PromptUpdateDetails)
|
PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||||
cached_tokenizer_from_config)
|
cached_tokenizer_from_config)
|
||||||
@ -224,6 +226,28 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
|||||||
num_images=num_images)
|
num_images=num_images)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_dummy_processor_inputs(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> ProcessorInputs:
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
|
||||||
|
dummy_text = self.get_dummy_text(mm_counts)
|
||||||
|
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
|
||||||
|
dummy_images = dummy_mm_data.get("image", [])
|
||||||
|
|
||||||
|
request = ChatCompletionRequest(messages=[
|
||||||
|
UserMessage(content=[
|
||||||
|
TextChunk(text=dummy_text),
|
||||||
|
*(ImageChunk(image=image) for image in dummy_images),
|
||||||
|
]),
|
||||||
|
])
|
||||||
|
res = tokenizer.mistral.encode_chat_completion(request)
|
||||||
|
dummy_tokens = res.tokens
|
||||||
|
|
||||||
|
return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)
|
||||||
|
|
||||||
|
|
||||||
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||||
):
|
):
|
||||||
@ -275,8 +299,12 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
|||||||
*,
|
*,
|
||||||
return_mm_hashes: bool,
|
return_mm_hashes: bool,
|
||||||
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
|
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
|
||||||
prompt_ids, mm_kwargs, mm_hashes, _ = super(
|
(
|
||||||
)._cached_apply_hf_processor(
|
prompt_ids,
|
||||||
|
mm_kwargs,
|
||||||
|
mm_hashes,
|
||||||
|
_,
|
||||||
|
) = super()._cached_apply_hf_processor(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
mm_data_items=mm_data_items,
|
mm_data_items=mm_data_items,
|
||||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Generic, NamedTuple, Optional, TypeVar, cast
|
from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@ -27,7 +27,7 @@ class ProcessorInputs:
|
|||||||
Represents the keyword arguments to
|
Represents the keyword arguments to
|
||||||
{meth}`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
|
{meth}`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
|
||||||
"""
|
"""
|
||||||
prompt_text: str
|
prompt: Union[str, list[int]]
|
||||||
mm_data: MultiModalDataDict
|
mm_data: MultiModalDataDict
|
||||||
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||||
|
|
||||||
@ -75,7 +75,12 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
|||||||
"in an upcoming release.")
|
"in an upcoming release.")
|
||||||
|
|
||||||
seq_len = self.info.ctx.model_config.max_model_len
|
seq_len = self.info.ctx.model_config.max_model_len
|
||||||
return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text
|
|
||||||
|
prompt = self.get_dummy_processor_inputs(seq_len, mm_counts).prompt
|
||||||
|
if not isinstance(prompt, str):
|
||||||
|
prompt = self.info.get_tokenizer().decode(prompt)
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
# TODO: @abstractmethod after transition
|
# TODO: @abstractmethod after transition
|
||||||
def get_dummy_mm_data(
|
def get_dummy_mm_data(
|
||||||
@ -101,7 +106,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
|||||||
dummy_text = self.get_dummy_text(mm_counts)
|
dummy_text = self.get_dummy_text(mm_counts)
|
||||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
|
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
|
||||||
|
|
||||||
return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data)
|
return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data)
|
||||||
|
|
||||||
def _get_dummy_audios(
|
def _get_dummy_audios(
|
||||||
self,
|
self,
|
||||||
@ -177,7 +182,7 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
seq_len, mm_counts)
|
seq_len, mm_counts)
|
||||||
|
|
||||||
return self.processor.apply(
|
return self.processor.apply(
|
||||||
prompt=processor_inputs.prompt_text,
|
prompt=processor_inputs.prompt,
|
||||||
mm_data=processor_inputs.mm_data,
|
mm_data=processor_inputs.mm_data,
|
||||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user