mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 08:25:01 +08:00)
[Model] Update Paligemma multimodal processing with PromptUpdate (#14015)
Signed-off-by: Kyle Huang <kylhuang@nvidia.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent ed6ea06577
commit 1769928079
@@ -842,13 +842,13 @@ See [this page](#generative-models) for more information on how to use generative models.
  *
  * ✅︎
  * ✅︎
- * `PaliGemmaForConditionalGeneration`\*
  * PaliGemma, PaliGemma 2
- * `PaliGemmaForConditionalGeneration`
  * PaliGemma (see note), PaliGemma 2 (see note)
  * T + I<sup>E</sup>
  * `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
  *
  * ✅︎
  *
  * ✅︎
- * `Phi3VForCausalLM`
  * Phi-3-Vision, Phi-3.5-Vision
  * T + I<sup>E+</sup>
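The updated table row advertises text-plus-image input for the listed PaliGemma checkpoints. As a point of reference, a minimal offline-inference sketch against one of those checkpoints could look like the following; the model name, image path, and task prompt ("caption en") are illustrative and not part of this commit.

```python
# Hedged sketch: single-image inference with a PaliGemma checkpoint through
# vLLM's offline LLM API. Assumes vLLM is installed and "example.jpg" exists.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="google/paligemma-3b-mix-224")
image = Image.open("example.jpg").convert("RGB")

outputs = llm.generate(
    {
        "prompt": "caption en",                # PaliGemma task prompts are plain text
        "multi_modal_data": {"image": image},  # one image per request
    },
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)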
@@ -116,9 +116,8 @@ VLM_TEST_SETTINGS = {
            "pixel_values"
        ),
        vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
        dtype=("half" if current_platform.is_cpu() or current_platform.is_rocm()
               else ("half", "float")),
        marks=[pytest.mark.core_model],
        dtype="bfloat16",
        marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501
    ),
    # TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
    # once we upgraded to transformers>=4.49.0.
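The new skip marker cites the missing PrefixLM attention mask. For readers unfamiliar with the term, the sketch below (purely illustrative, not vLLM code) builds such a mask: the prefix (image tokens, BOS, and task prompt) is attended to bidirectionally, while generated suffix tokens stay causal. `prefix_len` and `total_len` are made-up values.

```python
# Hedged illustration of a PrefixLM attention mask (True = attention allowed).
import torch


def prefix_lm_mask(prefix_len: int, total_len: int) -> torch.Tensor:
    # Start from a causal (lower-triangular) mask.
    mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    # Every position may additionally attend to the whole prefix,
    # which makes the prefix bidirectional within itself.
    mask[:, :prefix_len] = True
    return mask


print(prefix_lm_mask(prefix_len=3, total_len=5).int())
```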
@@ -175,6 +175,8 @@ def _test_processing_correctness(
    "Qwen/Qwen2-Audio-7B-Instruct",
    "fixie-ai/ultravox-v0_4",
    "openai/whisper-large-v3",
    "google/paligemma-3b-mix-224",
    "google/paligemma2-3b-ft-docci-448",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@@ -5,22 +5,26 @@ from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,

import torch
from torch import nn
from transformers import PaliGemmaConfig
from transformers import BatchFeature, PaliGemmaConfig

from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputs, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptIndexTargets,
                                        PromptInsertion, PromptReplacement,
                                        PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)
@@ -46,79 +50,6 @@ PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                             PaliGemmaImageEmbeddingInputs]


def get_max_paligemma_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    return get_max_siglip_image_tokens(vision_config)


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data, ranges = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config, num_images)
    return DummyData(seq_data, mm_data, ranges)


def input_processor_for_paligemma(ctx: InputContext,
                                  inputs: DecoderOnlyInputs):

    """
    The correct prompt format needs to be:
    '<image>' * image_feature_size + '<bos>' + prompt + '\n'

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
    """ # noqa

    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)

    tokenizer = cached_tokenizer_from_config(model_config)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_str = tokenizer.decode(hf_config.image_token_index)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size

    orig_prompt = inputs.get("prompt")
    orig_prompt_ids = inputs.get("prompt_token_ids")

    if orig_prompt is not None and image_token_str in orig_prompt:
        logger.warning(
            "The image token '%s' was detected in the prompt and "
            "will be removed. Please follow the proper prompt format"
            " documented on HuggingFace.", image_token_str)
        orig_prompt = orig_prompt.replace(image_token_str, "")
        orig_prompt_ids.remove(hf_config.image_token_index)

    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"

    # The PaliGemma 2 tokenizer does not include a starting BOS token
    if orig_prompt_ids[0] != hf_config.bos_token_id:
        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids

    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data)


class PaliGemmaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, projection_dim: int):
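The docstring of the removed `input_processor_for_paligemma` spells out the prompt layout the legacy processor had to produce. A minimal sketch of that layout follows; the 256 placeholder tokens are an assumption for the 224-pixel, patch-14 SigLIP variant, and the user prompt is illustrative.

```python
# Hedged sketch of the legacy prompt layout:
#   '<image>' * image_feature_size + '<bos>' + prompt + '\n'
# The value 256 is an assumption for the 224px / patch-14 SigLIP tower;
# larger checkpoints (e.g. 448px) use more image placeholder tokens.
image_feature_size = (224 // 14) ** 2  # 256
image_token = "<image>"
bos_token = "<bos>"
user_prompt = "caption en"

full_prompt = image_token * image_feature_size + bos_token + user_prompt + "\n"

assert full_prompt.count(image_token) == image_feature_size
assert full_prompt.endswith("\n")
```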
@@ -131,12 +62,140 @@ class PaliGemmaMultiModalProjector(nn.Module):
        return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config(PaliGemmaConfig)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        return get_max_siglip_image_tokens(vision_config)


class PaliGemmaDummyInputsBuilder(
        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        hf_config = self.info.get_hf_config()
        vision_config = hf_config.vision_config
        max_image_size = vision_config.image_size

        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=max_image_size,
                                   height=max_image_size,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text="",
            mm_data=mm_data,
        )


class PaliGemmaMultiModalProcessor(
        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        tokenizer = self.info.get_tokenizer()
        if not mm_data:
            prompt_ids = tokenizer.encode(prompt)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        tokenizer = self.info.get_tokenizer()
        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [image_token_id] * num_image_tokens

        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        # PaliGemma 1 and PaliGemma 2 have different tokenizer.add_bos_token
        # Insert <image>*n + <bos> after <bos> for PaliGemma 1
        # Insert <image>*n + <bos> for PaliGemma 2
        return [
            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix(
                    [bos_token_id] if tokenizer.add_bos_token else []),
                insertion=PromptUpdateDetails(
                    full=image_tokens + [bos_token_id],
                    features=image_tokens,
                ),
            )
        ]

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
        prompt_token_ids = mm_inputs["prompt_token_ids"]

        tokenizer = self.info.get_tokenizer()
        newline_prompt = "\n"
        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
        # Force a newline at the end of the prompt to match PaliGemma's format
        # This step can NOT be replaced by the current PromptUpdate methods
        if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
            prompt_token_ids.append(newline_token_id)
            mm_inputs["prompt_token_ids"] = prompt_token_ids
            mm_inputs["prompt"] += newline_prompt

        return mm_inputs


@MULTIMODAL_REGISTRY.register_processor(
    PaliGemmaMultiModalProcessor,
    info=PaliGemmaProcessingInfo,
    dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP, SupportsV0Only):
                                        SupportsPP):

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",