[Model] Update Paligemma multimodal processing with PromptUpdate (#14015)
Signed-off-by: Kyle Huang <kylhuang@nvidia.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

parent: ed6ea06577
commit: 1769928079
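This commit migrates PaliGemma from the legacy `INPUT_REGISTRY` hooks to the PromptUpdate-based multimodal processor, so the `<image>` placeholder run, the `<bos>` token, and the trailing newline are assembled by `PaliGemmaMultiModalProcessor`. The user-facing interface is unchanged; a minimal offline-inference sketch follows (the image path and task prompt are illustrative placeholders, not part of this commit):

```python
# Minimal usage sketch; the model name comes from the docs table below,
# the image path and task prompt are illustrative placeholders.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="google/paligemma-3b-mix-224")

outputs = llm.generate(
    {
        # Plain task prompt: the registered processor inserts the image
        # tokens, <bos>, and the trailing newline PaliGemma expects.
        "prompt": "caption en",
        "multi_modal_data": {"image": Image.open("demo.jpg")},
    },
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```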
@@ -842,13 +842,13 @@ See [this page](#generative-models) for more information on how to use generative models.
   *
   * ✅︎
   * ✅︎
-- * `PaliGemmaForConditionalGeneration`\*
-  * PaliGemma, PaliGemma 2
+- * `PaliGemmaForConditionalGeneration`
+  * PaliGemma (see note), PaliGemma 2 (see note)
   * T + I<sup>E</sup>
   * `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
   *
   * ✅︎
-  *
+  * ✅︎
 - * `Phi3VForCausalLM`
   * Phi-3-Vision, Phi-3.5-Vision
   * T + I<sup>E+</sup>
@@ -116,9 +116,8 @@ VLM_TEST_SETTINGS = {
             "pixel_values"
         ),
         vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
-        dtype=("half" if current_platform.is_cpu() or current_platform.is_rocm()
-               else ("half", "float")),
-        marks=[pytest.mark.core_model],
+        dtype="bfloat16",
+        marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")],  # noqa: E501
     ),
     # TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
     # once we upgraded to transformers>=4.49.0.
@@ -175,6 +175,8 @@ def _test_processing_correctness(
     "Qwen/Qwen2-Audio-7B-Instruct",
     "fixie-ai/ultravox-v0_4",
     "openai/whisper-large-v3",
+    "google/paligemma-3b-mix-224",
+    "google/paligemma2-3b-ft-docci-448",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
@@ -5,22 +5,26 @@ from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
 
 import torch
 from torch import nn
-from transformers import PaliGemmaConfig
+from transformers import BatchFeature, PaliGemmaConfig
 
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputs, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptIndexTargets,
+                                        PromptInsertion, PromptReplacement,
+                                        PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
-from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
-from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
-                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
+from .interfaces import SupportsMultiModal, SupportsPP
+from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
 
@@ -46,79 +50,6 @@ PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                              PaliGemmaImageEmbeddingInputs]
 
 
-def get_max_paligemma_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-
-    return get_max_siglip_image_tokens(vision_config)
-
-
-def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
-                             mm_counts: Mapping[str, int]):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-    num_images = mm_counts["image"]
-
-    seq_data, ranges = dummy_seq_data_for_siglip(
-        vision_config,
-        seq_len,
-        num_images,
-        image_token_id=hf_config.image_token_index,
-    )
-
-    mm_data = dummy_image_for_siglip(vision_config, num_images)
-    return DummyData(seq_data, mm_data, ranges)
-
-
-def input_processor_for_paligemma(ctx: InputContext,
-                                  inputs: DecoderOnlyInputs):
-
-    """
-    The correct prompt format needs to be:
-    '<image>' * image_feature_size + '<bos>' + prompt + '\n'
-
-    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
-    """ # noqa
-
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-
-    tokenizer = cached_tokenizer_from_config(model_config)
-    image_feature_size = hf_config.text_config.num_image_tokens
-    image_token_str = tokenizer.decode(hf_config.image_token_index)
-    bos_token = tokenizer.decode(hf_config.bos_token_id)
-    image_token_str_pad = image_token_str * image_feature_size
-    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
-
-    orig_prompt = inputs.get("prompt")
-    orig_prompt_ids = inputs.get("prompt_token_ids")
-
-    if orig_prompt is not None and image_token_str in orig_prompt:
-        logger.warning(
-            "The image token '%s' was detected in the prompt and "
-            "will be removed. Please follow the proper prompt format"
-            " documented on HuggingFace.", image_token_str)
-        orig_prompt = orig_prompt.replace(image_token_str, "")
-        orig_prompt_ids.remove(hf_config.image_token_index)
-
-    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
-
-    # The PaliGemma 2 tokenizer does not include a starting BOS token
-    if orig_prompt_ids[0] != hf_config.bos_token_id:
-        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
-
-    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
-
-
 class PaliGemmaMultiModalProjector(nn.Module):
 
     def __init__(self, vision_hidden_size: int, projection_dim: int):
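The removed `input_processor_for_paligemma` docstring documents the prompt layout that the new processor still has to produce: `'<image>' * image_feature_size + '<bos>' + prompt + '\n'`. A short illustration of that expansion (the 256-token count is an assumption for the 224-pixel checkpoints, not a value taken from this diff):

```python
# Illustration only: how the documented PaliGemma prompt format expands.
image_feature_size = 256   # hf_config.text_config.num_image_tokens (assumed)
image_token_str = "<image>"
bos_token = "<bos>"
prompt = "caption en"      # placeholder user prompt

full_prompt = image_token_str * image_feature_size + bos_token + prompt + "\n"
# -> "<image><image>...<image><bos>caption en\n" (256 <image> placeholders)
```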
@@ -131,12 +62,140 @@ class PaliGemmaMultiModalProjector(nn.Module):
         return hidden_states
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
+class PaliGemmaProcessingInfo(BaseProcessingInfo):
+
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(PaliGemmaConfig)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        return get_max_siglip_image_tokens(vision_config)
+
+
+class PaliGemmaDummyInputsBuilder(
+        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        max_image_size = vision_config.image_size
+
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
+
+class PaliGemmaMultiModalProcessor(
+        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        tokenizer = self.info.get_tokenizer()
+        if not mm_data:
+            prompt_ids = tokenizer.encode(prompt)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        hf_config = self.info.get_hf_config()
+        image_token_id = hf_config.image_token_index
+
+        tokenizer = self.info.get_tokenizer()
+        num_image_tokens = self.info.get_num_image_tokens()
+        image_tokens = [image_token_id] * num_image_tokens
+
+        bos_token_id = tokenizer.bos_token_id
+        assert isinstance(bos_token_id, int)
+
+        # PaliGemma 1 and 2 have different tokenizer.add_bos_token settings:
+        # insert <image>*n + <bos> after <bos> for PaliGemma 1, and
+        # insert <image>*n + <bos> at the start of the prompt for PaliGemma 2
+        return [
+            PromptInsertion(
+                modality="image",
+                target=PromptIndexTargets.prefix(
+                    [bos_token_id] if tokenizer.add_bos_token else []),
+                insertion=PromptUpdateDetails(
+                    full=image_tokens + [bos_token_id],
+                    features=image_tokens,
+                ),
+            )
+        ]
+
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputs:
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
+
+        tokenizer = self.info.get_tokenizer()
+        newline_prompt = "\n"
+        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
+        # Force a trailing newline to match PaliGemma's prompt format;
+        # this step can NOT be replaced by the current PromptUpdate methods
+        if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
+            prompt_token_ids.append(newline_token_id)
+            mm_inputs["prompt_token_ids"] = prompt_token_ids
+            mm_inputs["prompt"] += newline_prompt
+
+        return mm_inputs
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    PaliGemmaMultiModalProcessor,
+    info=PaliGemmaProcessingInfo,
+    dummy_inputs=PaliGemmaDummyInputsBuilder)
 class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                        SupportsPP, SupportsV0Only):
+                                        SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
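Taken together, the `PromptInsertion` returned by `_get_prompt_updates` and the newline handling in `apply` reproduce the documented layout at the token-id level. A sketch for the PaliGemma 2 case, where `tokenizer.add_bos_token` is False and the insertion targets the very start of the prompt (all ids except the newline's 108 are assumed placeholder values, not taken from this diff):

```python
# Illustration only: ids other than 108 are assumed placeholder values.
image_token_id = 257152    # hf_config.image_token_index (assumed)
bos_token_id = 2           # tokenizer.bos_token_id (assumed)
num_image_tokens = 256     # info.get_num_image_tokens() for 224px inputs (assumed)
newline_token_id = 108     # tokenizer.encode("\n")[-1], see apply() above

prompt_ids = [4100, 1171]  # placeholder ids for the user text (no <bos>)

# PromptInsertion at PromptIndexTargets.prefix([]) prepends the image run
# plus <bos>; apply() then appends the newline PaliGemma expects.
after_insertion = [image_token_id] * num_image_tokens + [bos_token_id] + prompt_ids
final_ids = after_insertion + [newline_token_id]
```

For PaliGemma 1, where the tokenizer already emits a leading `<bos>`, the same insertion is anchored after that token instead, per the comment in `_get_prompt_updates`.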