[Bugfix] Fix profiling dummy data for Pixtral (#18677)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung, 2025-05-25 22:05:30 +08:00 (committed by GitHub)
parent 3a886bd58c
commit 57fd13a707
5 changed files with 151 additions and 168 deletions
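
The heart of the fix: ProcessorInputs now carries a prompt that may be either text or pre-tokenized ids, so Pixtral's dummy-input builder can return the Mistral tokenizer's chat-encoded tokens directly instead of a text prompt. A minimal sketch of the reshaped dataclass, with MultiModalDataDict stubbed out (the real class lives in vllm.multimodal.inputs) and hypothetical example values:

from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Union

MultiModalDataDict = dict  # stand-in for vllm.multimodal.inputs.MultiModalDataDict

@dataclass
class ProcessorInputs:
    """Keyword arguments to BaseMultiModalProcessor.apply (sketch)."""
    prompt: Union[str, list[int]]  # was `prompt_text: str` before this commit
    mm_data: MultiModalDataDict
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)

# Text-based builders keep working unchanged:
text_inputs = ProcessorInputs(prompt="<image> describe the image", mm_data={})
# ...while a Mistral-style builder can now hand over token ids directly
# (the ids below are hypothetical, for illustration only):
token_inputs = ProcessorInputs(prompt=[1, 3, 1583, 4, 2], mm_data={})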

View File

@@ -9,15 +9,15 @@ from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
cached_tokenizer_from_config,
encode_tokens)
from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
@@ -28,7 +28,6 @@ def _test_processing_correctness(
hit_rate: float,
num_batches: int,
simplify_rate: float,
ignore_mm_keys: Optional[set[str]] = None,
):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip")
@@ -99,10 +98,23 @@ def _test_processing_correctness(
}
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text
# Mistral chat outputs tokens directly, rather than text prompts
if isinstance(tokenizer, MistralTokenizer):
images = mm_data.get("image", [])
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=""),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
prompt = res.tokens
else:
prompt = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt
# Drop unnecessary keys and test single -> multi conversion
if rng.rand() < simplify_rate:
@@ -112,67 +124,59 @@ def _test_processing_correctness(
elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0]
if isinstance(tokenizer, MistralTokenizer):
_test_processing_correctness_mistral(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
else:
_test_processing_correctness_hf(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
_test_processing_correctness_one(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
)
def _test_processing_correctness_hf(
# For some multimodal models, the tokenizer always adds bos_token at the
# beginning of the prompt by default, causing hf_processor to output
# incorrect token ids. So we need to use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"mllama": False,
"ovis": False,
"ultravox": False,
"whisper": False,
}
_IGNORE_MM_KEYS = {
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference.
"ultravox": {"audio_features"},
}
def _test_processing_correctness_one(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prompt: str,
tokenizer: AnyTokenizer,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[set[str]] = None,
):
if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
"whisper"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
model_type = model_config.hf_config.model_type
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
if isinstance(prompt, str):
text_prompt = prompt
token_prompt = encode_tokens(
tokenizer,
prompt,
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
)
else:
token_prompt = tokenizer.encode(prompt)
baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
_assert_inputs_equal(
baseline_result,
cached_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
# Mistral does not support decode_tokens with skip_special_tokens=False
text_prompt = None
token_prompt = prompt
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
@@ -180,56 +184,6 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={},
)
_assert_inputs_equal(
baseline_result,
baseline_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
_assert_inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
def _test_processing_correctness_mistral(
model_config: ModelConfig,
tokenizer: MistralTokenizer,
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[set[str]] = None,
):
images = mm_data.get("image", [])
if not isinstance(images, list):
images = [images]
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=prompt),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
token_prompt = res.tokens
# Mistral chat outputs tokens directly, rather than text prompts
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
@@ -240,9 +194,44 @@ def _test_processing_correctness_mistral(
baseline_tokenized_result,
cached_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
)
if text_prompt is not None:
baseline_text_result = baseline_processor.apply(
text_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_text_result = cached_processor.apply(
text_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
_assert_inputs_equal(
baseline_text_result,
cached_text_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
)
_assert_inputs_equal(
baseline_text_result,
baseline_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
f"{token_prompt=}, {mm_data=})",
)
_assert_inputs_equal(
cached_text_result,
cached_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
f"{token_prompt=}, {mm_data=})",
)
# yapf: disable
@pytest.mark.parametrize("model_id", [
@@ -281,6 +270,7 @@ def _test_processing_correctness_mistral(
"AIDC-AI/Ovis2-1B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-3.5-vision-instruct",
"microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
@@ -303,41 +293,6 @@ def test_processing_correctness(
num_batches: int,
simplify_rate: float,
):
ignore_mm_keys = None
if 'ultravox' in model_id:
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference.
ignore_mm_keys = {"audio_features"}
_test_processing_correctness(
model_id,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
ignore_mm_keys=ignore_mm_keys,
)
# yapf: disable
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness_phi3v(
model_id: str,
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
# HACK - this is an attempted workaround for the following bug
# https://github.com/huggingface/transformers/issues/34307
from transformers import AutoImageProcessor # noqa: F401
from transformers import AutoProcessor # noqa: F401
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
_test_processing_correctness(
model_id,
hit_rate=hit_rate,
@@ -356,16 +311,10 @@ def _assert_inputs_equal(
if ignore_mm_keys is None:
ignore_mm_keys = set()
if msg is None:
assert "mm_kwargs" in a and "mm_kwargs" in b
else:
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None)
b["mm_kwargs"].pop(key, None)
if msg is None:
assert a == b
else:
assert a == b, msg
assert a == b, msg
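
The two per-tokenizer helpers above collapse into a single _test_processing_correctness_one that accepts either a text or a token prompt. The sketch below restates the branching as a standalone helper; the name _split_prompt and the bare `tokenizer` parameter are illustrative, not part of the diff:

from typing import Optional, Union

def _split_prompt(
    prompt: Union[str, list[int]],
    tokenizer,
    add_special_tokens: Optional[bool] = None,
) -> tuple[Optional[str], list[int]]:
    # Text prompts are kept as-is and also tokenized; `add_special_tokens`
    # is only overridden for models whose HF processor inserts bos_token itself.
    if isinstance(prompt, str):
        if add_special_tokens is None:
            return prompt, tokenizer.encode(prompt)
        return prompt, tokenizer.encode(prompt,
                                        add_special_tokens=add_special_tokens)
    # Token prompts (e.g. from MistralTokenizer) are used directly, since they
    # cannot be decoded back to an equivalent text prompt.
    return None, prompt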

View File

@@ -49,7 +49,7 @@ def test_profiling(
] * max_num_seqs
mm_kwargs = processor.apply(
prompt=dummy_mm_data.prompt_text,
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]

View File

@@ -8,6 +8,8 @@ import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.config import TokenizerMode
@dataclass(frozen=True)
class _HfExamplesInfo:
@@ -20,7 +22,7 @@ class _HfExamplesInfo:
tokenizer: Optional[str] = None
"""Set the tokenizer to load for this architecture."""
tokenizer_mode: str = "auto"
tokenizer_mode: TokenizerMode = "auto"
"""Set the tokenizer type for this architecture."""
speculative_model: Optional[str] = None
@@ -388,8 +390,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral",
v0_only=True),
tokenizer_mode="mistral"),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501
trust_remote_code=True,
@@ -400,7 +401,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B",
min_transformers_version="4.52"),
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ", # noqa: E501
min_transformers_version="4.52"),
min_transformers_version="4.52"),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501

View File

@@ -9,7 +9,9 @@ from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import PixtralVisionConfig, TensorType
@@ -39,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
@@ -224,6 +226,28 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
num_images=num_images)
}
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
dummy_images = dummy_mm_data.get("image", [])
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=dummy_text),
*(ImageChunk(image=image) for image in dummy_images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens
return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
):
@@ -275,8 +299,12 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
(
prompt_ids,
mm_kwargs,
mm_hashes,
_,
) = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
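
PixtralDummyInputsBuilder now overrides get_dummy_processor_inputs to encode a dummy chat request with the Mistral tokenizer and return the resulting token ids as the prompt, so profiling goes through the same tokenization path as real Mistral-format requests. A rough standalone sketch of that pattern; the helper name, image size, and the bare `mistral_tokenizer` argument are illustrative, while `encode_chat_completion(...).tokens` follows the diff:

from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                        UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image

def build_dummy_token_prompt(mistral_tokenizer, num_images: int,
                             text: str = "",
                             size: tuple[int, int] = (64, 64)) -> list[int]:
    # Mirror of the new PixtralDummyInputsBuilder.get_dummy_processor_inputs:
    # encode a dummy text chunk plus the dummy images as a chat request and
    # use the resulting token ids (which include the image placeholder tokens).
    images = [Image.new("RGB", size, color=(0, 0, 0)) for _ in range(num_images)]
    request = ChatCompletionRequest(messages=[
        UserMessage(content=[
            TextChunk(text=text),
            *(ImageChunk(image=image) for image in images),
        ]),
    ])
    return mistral_tokenizer.encode_chat_completion(request).tokens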

View File

@@ -3,7 +3,7 @@
from abc import ABC
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, NamedTuple, Optional, TypeVar, cast
from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
import numpy as np
import numpy.typing as npt
@@ -27,7 +27,7 @@ class ProcessorInputs:
Represents the keyword arguments to
{meth}`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
"""
prompt_text: str
prompt: Union[str, list[int]]
mm_data: MultiModalDataDict
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
@@ -75,7 +75,12 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
"in an upcoming release.")
seq_len = self.info.ctx.model_config.max_model_len
return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text
prompt = self.get_dummy_processor_inputs(seq_len, mm_counts).prompt
if not isinstance(prompt, str):
prompt = self.info.get_tokenizer().decode(prompt)
return prompt
# TODO: @abstractmethod after transition
def get_dummy_mm_data(
@@ -101,7 +106,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data)
return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data)
def _get_dummy_audios(
self,
@@ -177,7 +182,7 @@ class MultiModalProfiler(Generic[_I]):
seq_len, mm_counts)
return self.processor.apply(
prompt=processor_inputs.prompt_text,
prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
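
Finally, the deprecated get_dummy_text path in profiling.py has to tolerate token-id prompts: when the builder returns ids, they are decoded back to text with the tokenizer. A small sketch of that fallback (the helper name and the minimal `tokenizer.decode` interface are assumptions):

from typing import Union

def _prompt_as_text(prompt: Union[str, list[int]], tokenizer) -> str:
    # Fallback used by the deprecated get_dummy_text path: token-id prompts
    # are decoded so that legacy callers still receive a string.
    if isinstance(prompt, str):
        return prompt
    return tokenizer.decode(prompt)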