[Bugfix] Fix profiling dummy data for Pixtral (#18677)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-05-25 22:05:30 +08:00 committed by GitHub
parent 3a886bd58c
commit 57fd13a707
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 151 additions and 168 deletions

View File

@ -9,15 +9,15 @@ from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage) UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.transformers_utils.tokenizer import (MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
cached_tokenizer_from_config) cached_tokenizer_from_config,
encode_tokens)
from ....multimodal.utils import random_audio, random_image, random_video from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS from ...registry import HF_EXAMPLE_MODELS
@ -28,7 +28,6 @@ def _test_processing_correctness(
hit_rate: float, hit_rate: float,
num_batches: int, num_batches: int,
simplify_rate: float, simplify_rate: float,
ignore_mm_keys: Optional[set[str]] = None,
): ):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip") model_info.check_available_online(on_fail="skip")
@ -99,10 +98,23 @@ def _test_processing_correctness(
} }
mm_counts = {k: len(vs) for k, vs in mm_data.items()} mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len, # Mistral chat outputs tokens directly, rather than text prompts
mm_counts, if isinstance(tokenizer, MistralTokenizer):
).prompt_text images = mm_data.get("image", [])
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=""),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
prompt = res.tokens
else:
prompt = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt
# Drop unnecessary keys and test single -> multi conversion # Drop unnecessary keys and test single -> multi conversion
if rng.rand() < simplify_rate: if rng.rand() < simplify_rate:
@ -112,67 +124,59 @@ def _test_processing_correctness(
elif len(mm_data[k]) == 1: elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0] mm_data[k] = mm_data[k][0]
if isinstance(tokenizer, MistralTokenizer): _test_processing_correctness_one(
_test_processing_correctness_mistral( model_config,
model_config, tokenizer,
tokenizer, prompt,
prompt, mm_data,
mm_data, baseline_processor,
baseline_processor, cached_processor,
cached_processor, batch_idx,
batch_idx, )
ignore_mm_keys=ignore_mm_keys,
)
else:
_test_processing_correctness_hf(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
def _test_processing_correctness_hf( # For some multimodal models, tokenizer will always add bos_token
# at the beginning of the prompt by default, causing hf_processor to output
# incorrect token ids. So we need to use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"mllama": False,
"ovis": False,
"ultravox": False,
"whisper": False,
}
_IGNORE_MM_KEYS = {
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference.
"ultravox": {"audio_features"},
}
def _test_processing_correctness_one(
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: AnyTokenizer,
prompt: str, prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor, baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor, cached_processor: BaseMultiModalProcessor,
batch_idx: int, batch_idx: int,
ignore_mm_keys: Optional[set[str]] = None,
): ):
if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox", model_type = model_config.hf_config.model_type
"whisper"): ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs if isinstance(prompt, str):
# incorrect token ids. So we need use `add_special_tokens=False` here text_prompt = prompt
# to leave bos_token to be added by the processor. token_prompt = encode_tokens(
token_prompt = tokenizer.encode(prompt, add_special_tokens=False) tokenizer,
prompt,
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
)
else: else:
token_prompt = tokenizer.encode(prompt) # Mistral does not support decode_tokens with skip_special_tokens=False
text_prompt = None
baseline_result = baseline_processor.apply( token_prompt = prompt
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
_assert_inputs_equal(
baseline_result,
cached_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
baseline_tokenized_result = baseline_processor.apply( baseline_tokenized_result = baseline_processor.apply(
token_prompt, token_prompt,
@ -180,56 +184,6 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
_assert_inputs_equal(
baseline_result,
baseline_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
_assert_inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
def _test_processing_correctness_mistral(
model_config: ModelConfig,
tokenizer: MistralTokenizer,
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[set[str]] = None,
):
images = mm_data.get("image", [])
if not isinstance(images, list):
images = [images]
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=prompt),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
token_prompt = res.tokens
# Mistral chat outputs tokens directly, rather than text prompts
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply( cached_tokenized_result = cached_processor.apply(
token_prompt, token_prompt,
mm_data=mm_data, mm_data=mm_data,
@ -240,9 +194,44 @@ def _test_processing_correctness_mistral(
baseline_tokenized_result, baseline_tokenized_result,
cached_tokenized_result, cached_tokenized_result,
ignore_mm_keys=ignore_mm_keys, ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})", msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
) )
if text_prompt is not None:
baseline_text_result = baseline_processor.apply(
text_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_text_result = cached_processor.apply(
text_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
_assert_inputs_equal(
baseline_text_result,
cached_text_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
)
_assert_inputs_equal(
baseline_text_result,
baseline_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
f"{token_prompt=}, {mm_data=})",
)
_assert_inputs_equal(
cached_text_result,
cached_tokenized_result,
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
f"{token_prompt=}, {mm_data=})",
)
# yapf: disable # yapf: disable
@pytest.mark.parametrize("model_id", [ @pytest.mark.parametrize("model_id", [
@ -281,6 +270,7 @@ def _test_processing_correctness_mistral(
"AIDC-AI/Ovis2-1B", "AIDC-AI/Ovis2-1B",
"google/paligemma-3b-mix-224", "google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448", "google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-3.5-vision-instruct",
"microsoft/Phi-4-multimodal-instruct", "microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409", "mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b", "mistral-community/pixtral-12b",
@ -303,41 +293,6 @@ def test_processing_correctness(
num_batches: int, num_batches: int,
simplify_rate: float, simplify_rate: float,
): ):
ignore_mm_keys = None
if 'ultravox' in model_id:
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference.
ignore_mm_keys = {"audio_features"}
_test_processing_correctness(
model_id,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
ignore_mm_keys=ignore_mm_keys,
)
# yapf: disable
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness_phi3v(
model_id: str,
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
# HACK - this is an attempted workaround for the following bug
# https://github.com/huggingface/transformers/issues/34307
from transformers import AutoImageProcessor # noqa: F401
from transformers import AutoProcessor # noqa: F401
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
_test_processing_correctness( _test_processing_correctness(
model_id, model_id,
hit_rate=hit_rate, hit_rate=hit_rate,
@ -356,16 +311,10 @@ def _assert_inputs_equal(
if ignore_mm_keys is None: if ignore_mm_keys is None:
ignore_mm_keys = set() ignore_mm_keys = set()
if msg is None: assert "mm_kwargs" in a and "mm_kwargs" in b, msg
assert "mm_kwargs" in a and "mm_kwargs" in b
else:
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
for key in ignore_mm_keys: for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None) a["mm_kwargs"].pop(key, None)
b["mm_kwargs"].pop(key, None) b["mm_kwargs"].pop(key, None)
if msg is None: assert a == b, msg
assert a == b
else:
assert a == b, msg

View File

@ -49,7 +49,7 @@ def test_profiling(
] * max_num_seqs ] * max_num_seqs
mm_kwargs = processor.apply( mm_kwargs = processor.apply(
prompt=dummy_mm_data.prompt_text, prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data, mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(), hf_processor_mm_kwargs=dict(),
)["mm_kwargs"] )["mm_kwargs"]

View File

@ -8,6 +8,8 @@ import pytest
from packaging.version import Version from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.config import TokenizerMode
@dataclass(frozen=True) @dataclass(frozen=True)
class _HfExamplesInfo: class _HfExamplesInfo:
@ -20,7 +22,7 @@ class _HfExamplesInfo:
tokenizer: Optional[str] = None tokenizer: Optional[str] = None
"""Set the tokenizer to load for this architecture.""" """Set the tokenizer to load for this architecture."""
tokenizer_mode: str = "auto" tokenizer_mode: TokenizerMode = "auto"
"""Set the tokenizer type for this architecture.""" """Set the tokenizer type for this architecture."""
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
@ -388,8 +390,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True), trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral", tokenizer_mode="mistral"),
v0_only=True),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL", "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501 extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501
trust_remote_code=True, trust_remote_code=True,
@ -400,7 +401,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B", "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B",
min_transformers_version="4.52"), min_transformers_version="4.52"),
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ", # noqa: E501 "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ", # noqa: E501
min_transformers_version="4.52"), min_transformers_version="4.52"),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501

View File

@ -9,7 +9,9 @@ from typing import Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image from PIL import Image
from transformers import PixtralVisionConfig, TensorType from transformers import PixtralVisionConfig, TensorType
@ -39,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes, BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate, PromptReplacement, PromptUpdate,
PromptUpdateDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer, from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config) cached_tokenizer_from_config)
@ -224,6 +226,28 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
num_images=num_images) num_images=num_images)
} }
def get_dummy_processor_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> ProcessorInputs:
    """Build dummy processor inputs whose prompt is already tokenized.

    The dummy text and the dummy images are packed into a single user
    message and encoded through the tokenizer's
    ``mistral.encode_chat_completion``; the resulting token ids are
    returned as the prompt instead of a text prompt.

    Args:
        seq_len: Forwarded to ``get_dummy_mm_data``.
        mm_counts: Per-modality item counts, forwarded to both
            ``get_dummy_text`` and ``get_dummy_mm_data``.

    Returns:
        A ``ProcessorInputs`` holding the encoded token ids and the dummy
        multimodal data.
    """
    tokenizer = self.info.get_tokenizer()

    text = self.get_dummy_text(mm_counts)
    mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
    images = mm_data.get("image", [])

    # One text chunk first, followed by one image chunk per dummy image.
    chunks = [TextChunk(text=text)]
    chunks.extend(ImageChunk(image=img) for img in images)

    user_message = UserMessage(content=chunks)
    chat_request = ChatCompletionRequest(messages=[user_message])
    encoded = tokenizer.mistral.encode_chat_completion(chat_request)

    return ProcessorInputs(prompt=encoded.tokens, mm_data=mm_data)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
): ):
@ -275,8 +299,12 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
*, *,
return_mm_hashes: bool, return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super( (
)._cached_apply_hf_processor( prompt_ids,
mm_kwargs,
mm_hashes,
_,
) = super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,

View File

@ -3,7 +3,7 @@
from abc import ABC from abc import ABC
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generic, NamedTuple, Optional, TypeVar, cast from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -27,7 +27,7 @@ class ProcessorInputs:
Represents the keyword arguments to Represents the keyword arguments to
{meth}`vllm.multimodal.processing.BaseMultiModalProcessor.apply`. {meth}`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
""" """
prompt_text: str prompt: Union[str, list[int]]
mm_data: MultiModalDataDict mm_data: MultiModalDataDict
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
@ -75,7 +75,12 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
"in an upcoming release.") "in an upcoming release.")
seq_len = self.info.ctx.model_config.max_model_len seq_len = self.info.ctx.model_config.max_model_len
return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text
prompt = self.get_dummy_processor_inputs(seq_len, mm_counts).prompt
if not isinstance(prompt, str):
prompt = self.info.get_tokenizer().decode(prompt)
return prompt
# TODO: @abstractmethod after transition # TODO: @abstractmethod after transition
def get_dummy_mm_data( def get_dummy_mm_data(
@ -101,7 +106,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data) return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data)
def _get_dummy_audios( def _get_dummy_audios(
self, self,
@ -177,7 +182,7 @@ class MultiModalProfiler(Generic[_I]):
seq_len, mm_counts) seq_len, mm_counts)
return self.processor.apply( return self.processor.apply(
prompt=processor_inputs.prompt_text, prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data, mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
) )