mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:16:06 +08:00
[Bugfix] Re-enable Gemma3 for V1 (#14980)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
05ccd0aa35
commit
61f412187d
@ -768,7 +768,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
*
|
||||
* ⚠️
|
||||
- * `GLM4VForCausalLM`<sup>^</sup>
|
||||
* GLM-4V
|
||||
* T + I
|
||||
@ -951,13 +951,10 @@ V0 correctly implements the model's attention pattern:
|
||||
|
||||
V1 currently uses a simplified attention pattern:
|
||||
- Uses causal attention for all tokens, including image tokens
|
||||
- Generates reasonable outputs but does not match the original model's attention for text + image inputs
|
||||
- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": True}`
|
||||
- Will be updated in the future to support the correct behavior
|
||||
- Does not support `"do_pan_and_scan": True`
|
||||
|
||||
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
|
||||
|
||||
For these reasons, `Gemma3ForConditionalGeneration` is supported only on V0 at the moment.
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
|
||||
@ -19,7 +19,8 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
|
||||
apply_token_matches,
|
||||
find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
iter_token_matches)
|
||||
iter_token_matches,
|
||||
replace_token_matches)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import MultiModalProfiler
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
@ -89,6 +90,58 @@ def test_iter_token_matches(token_ids, match_ids, expected):
|
||||
assert all(match_len == len(match_ids) for match_len in match_lens)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("token_ids", "match_ids", "new_ids", "expected"),
|
||||
[
|
||||
([], [], [-1], []),
|
||||
([], [32000], [-1], []),
|
||||
(
|
||||
[32000, 32000, 32000],
|
||||
[32000],
|
||||
[-1],
|
||||
[-1, -1, -1],
|
||||
),
|
||||
(
|
||||
[32000, 32000, 32000],
|
||||
[32000, 32000],
|
||||
[-1],
|
||||
[-1, 32000],
|
||||
),
|
||||
(
|
||||
[32000, 32000, 32000],
|
||||
[32000, 32000, 32000],
|
||||
[-1],
|
||||
[-1],
|
||||
),
|
||||
(
|
||||
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
[28747, 32000],
|
||||
[-1],
|
||||
[9833, -1, 32000, 32000, 9833, -1, 32000, 918],
|
||||
),
|
||||
(
|
||||
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
[28747, 32000, 32000, 32000],
|
||||
[-1],
|
||||
[9833, -1, 9833, 28747, 32000, 32000, 918],
|
||||
),
|
||||
(
|
||||
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
[28747, 0, 32000],
|
||||
[-1],
|
||||
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
),
|
||||
],
|
||||
)
|
||||
# yapf: enable
|
||||
def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
|
||||
result = replace_token_matches(token_ids, match_ids, new_ids)
|
||||
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("prompt", "target_by_key", "expected_by_key"),
|
||||
|
||||
@ -1,34 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import math
|
||||
from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
|
||||
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, encode_tokens)
|
||||
BaseProcessingInfo, BoundPromptUpdate,
|
||||
PlaceholderFeaturesInfo,
|
||||
PromptReplacement, PromptTargetMatch,
|
||||
PromptUpdate, PromptUpdateDetails,
|
||||
encode_tokens, find_mm_placeholders,
|
||||
replace_token_matches)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP, SupportsV0Only)
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -37,13 +46,25 @@ class Gemma3ImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: torch.Tensor
|
||||
"""
|
||||
Shape: `(num_crops_total, num_channels, height, width)`
|
||||
Shape: `(num_patches_total, num_channels, height, width)`
|
||||
|
||||
`num_crops_total` is the total number of crops
|
||||
`num_patches_total` is the total number of patches
|
||||
over each image over each prompt in the batch.
|
||||
"""
|
||||
num_crops: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images,)`"""
|
||||
|
||||
num_patches: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
Gemma3ImageInputs = Gemma3ImagePixelInputs
|
||||
@ -51,6 +72,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
|
||||
|
||||
class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(Gemma3Config)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object):
|
||||
return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
|
||||
|
||||
@ -114,6 +138,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
if not do_pan_and_scan:
|
||||
return 0
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
logger.warning_once(
|
||||
"`do_pan_and_scan=True` has suboptimal results on V1 "
|
||||
"because of the simplified attention pattern being used.")
|
||||
|
||||
# Based on Gemma3ImageProcessor.pan_and_scan
|
||||
if image_width >= image_height:
|
||||
if image_width / image_height < pan_and_scan_min_ratio_to_activate:
|
||||
@ -154,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[Gemma3Processor],
|
||||
) -> str:
|
||||
) -> PromptUpdateDetails:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
@ -175,7 +204,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
f"Here is the original image {image_token} and here are some "
|
||||
f"crops to help you see better {crops_image_tokens}")
|
||||
|
||||
return image_text.replace(image_token, processor.full_image_sequence)
|
||||
repl_full = image_text.replace(image_token,
|
||||
processor.full_image_sequence)
|
||||
repl_features = repl_full.strip("\n")
|
||||
|
||||
return PromptUpdateDetails(full=repl_full, features=repl_features)
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
@ -193,7 +226,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
image_repl_tokens = encode_tokens(
|
||||
tokenizer,
|
||||
image_repl,
|
||||
image_repl.features,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
return len(image_repl_tokens)
|
||||
@ -240,12 +273,8 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
# NOTE: We need to separate the image tokens here because
|
||||
# encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
|
||||
# with the detection of prompt updates when the image tokens are
|
||||
# right next to each other
|
||||
return ProcessorInputs(
|
||||
prompt_text=" ".join([image_token] * num_images),
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
@ -278,13 +307,39 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
]
|
||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
|
||||
image_repl_features = [
|
||||
self.info.get_image_repl(image_width=size.width,
|
||||
image_height=size.height,
|
||||
processor=hf_processor).features
|
||||
for size in image_sizes
|
||||
]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
image_repls_feature_tokens = [
|
||||
tokenizer.encode(image_repl, add_special_tokens=False)
|
||||
for image_repl in image_repl_features
|
||||
]
|
||||
num_embeds = [
|
||||
len(image_repl_feature_tokens)
|
||||
for image_repl_feature_tokens in image_repls_feature_tokens
|
||||
]
|
||||
processed_outputs["num_embeds"] = torch.tensor(num_embeds)
|
||||
|
||||
vocab = tokenizer.get_vocab()
|
||||
image_token_id = vocab[tokenizer.image_token]
|
||||
|
||||
embed_is_patch = [
|
||||
torch.tensor(image_repl_tokens) == image_token_id
|
||||
for image_repl_tokens in image_repls_feature_tokens
|
||||
]
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
num_crops = [
|
||||
self.info.get_num_crops(image_width=size.width,
|
||||
image_height=size.height,
|
||||
processor=hf_processor)
|
||||
for size in image_sizes
|
||||
]
|
||||
|
||||
processed_outputs["num_crops"] = torch.tensor(num_crops)
|
||||
|
||||
return processed_outputs
|
||||
@ -300,6 +355,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops + 1),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@ -329,6 +386,91 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
)
|
||||
]
|
||||
|
||||
def _apply_token_matches(
|
||||
self,
|
||||
prompt: list[int],
|
||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> list[int]:
|
||||
token_ids = super()._apply_token_matches(
|
||||
prompt,
|
||||
mm_matches,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
# "\n\n\n" and "\n\n\n\n" are single tokens
|
||||
# Since our replacement can insert "\n\n" next to "\n"
|
||||
# tokens, we have to combine them to be consistent with
|
||||
# the output of the tokenizer
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
newline_1 = vocab["\n"]
|
||||
newline_2 = vocab["\n\n"]
|
||||
newline_3 = vocab["\n\n\n"]
|
||||
newline_4 = vocab["\n\n\n\n"]
|
||||
|
||||
token_ids = replace_token_matches(
|
||||
token_ids,
|
||||
[newline_1, newline_2],
|
||||
[newline_3],
|
||||
)
|
||||
token_ids = replace_token_matches(
|
||||
token_ids,
|
||||
[newline_2, newline_1],
|
||||
[newline_3],
|
||||
)
|
||||
token_ids = replace_token_matches(
|
||||
token_ids,
|
||||
[newline_2, newline_2],
|
||||
[newline_4],
|
||||
)
|
||||
|
||||
return token_ids
|
||||
|
||||
def _find_mm_placeholders(
|
||||
self,
|
||||
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
|
||||
new_token_ids: list[int],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
|
||||
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
newline_1 = vocab["\n"]
|
||||
newline_2 = vocab["\n\n"]
|
||||
newline_3 = vocab["\n\n\n"]
|
||||
newline_4 = vocab["\n\n\n\n"]
|
||||
|
||||
def get_repl_toks(tok: int) -> list[int]:
|
||||
if tok == newline_3:
|
||||
return [newline_1, newline_2]
|
||||
if tok == newline_4:
|
||||
return [newline_2, newline_2]
|
||||
|
||||
return [tok]
|
||||
|
||||
repl_token_ids = list[int]()
|
||||
repl_orig_idxs = list[int]()
|
||||
for orig_idx, orig_tok in enumerate(new_token_ids):
|
||||
repl_toks = get_repl_toks(orig_tok)
|
||||
repl_token_ids.extend(repl_toks)
|
||||
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
|
||||
|
||||
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
|
||||
mm_item_counts)
|
||||
|
||||
return {
|
||||
modality: [
|
||||
PlaceholderFeaturesInfo(
|
||||
modality=p.modality,
|
||||
item_idx=p.item_idx,
|
||||
start_idx=repl_orig_idxs[p.start_idx],
|
||||
tokens=p.tokens,
|
||||
) for p in placeholders
|
||||
]
|
||||
for modality, placeholders in repls.items()
|
||||
}
|
||||
|
||||
|
||||
class Gemma3MultiModalProjector(nn.Module):
|
||||
|
||||
@ -374,7 +516,7 @@ class Gemma3MultiModalProjector(nn.Module):
|
||||
info=Gemma3ProcessingInfo,
|
||||
dummy_inputs=Gemma3DummyInputsBuilder)
|
||||
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsLoRA, SupportsV0Only):
|
||||
SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -415,6 +557,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
@property
|
||||
def sampler(self):
|
||||
return self.language_model.sampler
|
||||
@ -438,6 +584,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
num_crops = kwargs.pop("num_crops", None)
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
num_embeds = kwargs.pop("num_embeds", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
assert image_embeds is None, "Gemma3 does not support image_embeds."
|
||||
if pixel_values is None:
|
||||
@ -448,16 +596,26 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if not isinstance(num_crops, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_crops values. "
|
||||
raise ValueError("Incorrect type of num_crops. "
|
||||
f"Got type: {type(num_crops)}")
|
||||
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
num_crops = flatten_bn(num_crops, concat=True)
|
||||
|
||||
return Gemma3ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=self._validate_pixel_values(pixel_values),
|
||||
num_crops=num_crops,
|
||||
num_patches=num_crops + 1,
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
def _image_pixels_to_features(
|
||||
@ -472,36 +630,51 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: Gemma3ImageInputs,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = image_input["pixel_values"]
|
||||
vision_outputs = self._image_pixels_to_features(
|
||||
num_patches = image_input["num_patches"]
|
||||
|
||||
image_features = self._image_pixels_to_features(
|
||||
self.vision_tower,
|
||||
pixel_values,
|
||||
)
|
||||
return self.multi_modal_projector(vision_outputs)
|
||||
image_embeds = self.multi_modal_projector(image_features)
|
||||
|
||||
return image_embeds.split(num_patches.tolist())
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
if kwargs.get("v0_path", False):
|
||||
return image_features
|
||||
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is None:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
else:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.config.image_token_index)
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(self,
|
||||
@ -516,6 +689,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
kwargs.update({"v0_path": True})
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
@ -524,8 +698,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
kwargs = self.prepare_attn_masks(
|
||||
input_ids,
|
||||
positions,
|
||||
mask_dtype=vision_embeddings.dtype,
|
||||
**kwargs)
|
||||
mask_dtype=self.dtype,
|
||||
**kwargs,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
|
||||
@ -18,7 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.jsontree import json_map_leaves
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -27,8 +27,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
MultiModalInputs, MultiModalKwargs)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@ -44,7 +43,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import get_vision_encoder_info
|
||||
from .vision import (get_vision_encoder_info, scatter_patch_features,
|
||||
select_patch_features)
|
||||
|
||||
|
||||
class LlavaImagePixelInputs(TypedDict):
|
||||
@ -76,7 +76,7 @@ class PixtralHFImagePixelInputs(TypedDict):
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_patches: Union[torch.Tensor, list[torch.Tensor]]
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
@ -352,15 +352,15 @@ class PixtralHFMultiModalProcessor(
|
||||
image_height=pixel_value.shape[-2],
|
||||
) for pixel_value in processed_outputs["pixel_values"]
|
||||
]
|
||||
num_patches = torch.tensor([(ncols + 1) * nrows
|
||||
for ncols, nrows in tile_sizes])
|
||||
num_embeds = torch.tensor([(ncols + 1) * nrows
|
||||
for ncols, nrows in tile_sizes])
|
||||
# Each image may result to masks of different sizes, so we need to
|
||||
# later use `num_patches` to get per-image masks.
|
||||
# later use `num_embeds` to get per-image masks.
|
||||
embed_is_patch = [
|
||||
torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
for ncols, nrows in tile_sizes
|
||||
]
|
||||
processed_outputs["num_patches"] = num_patches
|
||||
processed_outputs["num_embeds"] = num_embeds
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
@ -372,7 +372,7 @@ class PixtralHFMultiModalProcessor(
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
num_patches=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
@ -621,16 +621,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
num_patches = kwargs.pop("num_patches")
|
||||
if not isinstance(num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_patches. "
|
||||
f"Got type: {type(num_patches)}")
|
||||
num_embeds = kwargs.pop("num_embeds")
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
return PixtralHFImagePixelInputs(
|
||||
type="pixel_values_pixtral",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_patches=num_patches,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
@ -716,33 +716,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
image_embeds = torch.split(image_embeds, feature_sizes)
|
||||
return image_embeds
|
||||
|
||||
def _get_mm_embeds(
|
||||
self,
|
||||
features: torch.Tensor, # Shape: (num_patch, d)
|
||||
num_patches: torch.Tensor, # Shape: (num_images,)
|
||||
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
|
||||
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
|
||||
"""
|
||||
# Insert columns of nan values according to `embed_is_patch`. This work
|
||||
# ideally should be done in `_process_image_input`, but
|
||||
# `_process_image_input` is used in both V0 and V1 path. It's safer to
|
||||
# put the logic here.
|
||||
# FIXME: Move this logic to `_process_image_input` when v0 is
|
||||
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
|
||||
num_patches_per_image: list[int] = num_patches.tolist()
|
||||
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_patches_per_image), *features.shape[1:]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embeds_flat[embed_is_patch.view(-1)] = features
|
||||
|
||||
return embeds_flat.split(num_patches_per_image)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
@ -757,9 +730,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return vision_embeddings
|
||||
|
||||
return flatten_2d_lists(
|
||||
self._get_mm_embeds(*args) for args in zip(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
vision_embeddings,
|
||||
image_input["num_patches"],
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
@ -770,16 +743,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
# Extract the patch tokens
|
||||
patch_embeddings = json_map_leaves(
|
||||
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
|
||||
cast(JSONTree[torch.Tensor], multimodal_embeddings),
|
||||
)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
cast(NestedTensors, patch_embeddings),
|
||||
select_patch_features(multimodal_embeddings),
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
@ -4,7 +4,7 @@ import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, partial
|
||||
from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
|
||||
from typing import List, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -24,7 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
|
||||
SiluAndMul)
|
||||
@ -42,8 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@ -59,6 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import select_patch_features
|
||||
|
||||
# TODO: hard-coded for now. Consider making it configurable.
|
||||
VIT_LAYERS = [-2, -9]
|
||||
@ -1602,16 +1601,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
if multimodal_embeddings is not None:
|
||||
assert self.img_patch_id is not None
|
||||
|
||||
# Extract the patch tokens scattered in _get_mm_embeds
|
||||
patch_embeddings = json_map_leaves(
|
||||
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
|
||||
cast(JSONTree[torch.Tensor], multimodal_embeddings),
|
||||
)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
cast(NestedTensors, patch_embeddings),
|
||||
select_patch_features(multimodal_embeddings),
|
||||
self.img_patch_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
@ -4,7 +4,7 @@ import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property
|
||||
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast
|
||||
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -22,7 +22,6 @@ from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
@ -48,7 +47,8 @@ from vllm.utils import flatten_2d_lists
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
||||
from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs,
|
||||
scatter_patch_features, select_patch_features)
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
@ -77,7 +77,7 @@ class PixtralImagePixelInputs(TypedDict):
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_patches: Union[torch.Tensor, list[torch.Tensor]]
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
@ -153,7 +153,7 @@ class PixtralProcessorAdapter:
|
||||
images_processed = list[torch.Tensor]()
|
||||
images_tokens = list[torch.Tensor]()
|
||||
images_embed_is_patch = list[torch.Tensor]()
|
||||
images_num_patches = list[int]()
|
||||
images_num_embeds = list[int]()
|
||||
|
||||
for image in images:
|
||||
image_inputs = self.image_processor(ImageChunk(image=image))
|
||||
@ -163,13 +163,13 @@ class PixtralProcessorAdapter:
|
||||
images_processed.append(image_processed)
|
||||
images_tokens.append(image_tokens)
|
||||
images_embed_is_patch.append(image_tokens == image_token_id)
|
||||
images_num_patches.append(len(image_tokens))
|
||||
images_num_embeds.append(len(image_tokens))
|
||||
|
||||
return {
|
||||
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
|
||||
"images": images_processed,
|
||||
"embed_is_patch": images_embed_is_patch,
|
||||
"num_patches": torch.tensor(images_num_patches),
|
||||
"num_embeds": torch.tensor(images_num_embeds),
|
||||
}
|
||||
|
||||
|
||||
@ -273,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
return dict(
|
||||
images=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_patches=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@ -394,16 +394,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
num_patches = kwargs.pop("num_patches")
|
||||
if not isinstance(num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_patches. "
|
||||
f"Got type: {type(num_patches)}")
|
||||
num_embeds = kwargs.pop("num_embeds")
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
return PixtralImagePixelInputs(
|
||||
type="pixel_values",
|
||||
images=flatten_bn(images),
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_patches=num_patches,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
def _process_image_input(
|
||||
@ -433,33 +433,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_embeds = torch.split(image_embeds, feature_sizes)
|
||||
return image_embeds
|
||||
|
||||
def _get_mm_embeds(
|
||||
self,
|
||||
features: torch.Tensor, # Shape: (num_patch, d)
|
||||
num_patches: torch.Tensor, # Shape: (num_images,)
|
||||
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
|
||||
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
|
||||
"""
|
||||
# Insert columns of nan values according to `embed_is_patch`. This work
|
||||
# ideally should be done in `_process_image_input`, but
|
||||
# `_process_image_input` is used in both V0 and V1 path. It's safer to
|
||||
# put the logic here.
|
||||
# FIXME: Move this logic to `_process_image_input` when v0 is
|
||||
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
|
||||
num_patches_per_image: list[int] = num_patches.tolist()
|
||||
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_patches_per_image), *features.shape[1:]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embeds_flat[embed_is_patch.view(-1)] = features
|
||||
|
||||
return embeds_flat.split(num_patches_per_image)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
@ -472,9 +445,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return image_features
|
||||
|
||||
return flatten_2d_lists(
|
||||
self._get_mm_embeds(*args) for args in zip(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["num_patches"],
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
@ -485,15 +458,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
# Extract the patch tokens
|
||||
patch_embeddings = json_map_leaves(
|
||||
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
|
||||
cast(JSONTree[torch.Tensor], multimodal_embeddings),
|
||||
)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
cast(NestedTensors, patch_embeddings),
|
||||
select_patch_features(multimodal_embeddings),
|
||||
self.vision_args.image_token_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
@ -9,9 +9,12 @@ from transformers import PretrainedConfig
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.selector import (backend_name_to_enum,
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
|
||||
from .interfaces import MultiModalEmbeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_C = TypeVar("_C", bound=PretrainedConfig)
|
||||
@ -148,3 +151,48 @@ def resolve_visual_encoder_outputs(
|
||||
if post_layer_norm is not None and uses_last_layer:
|
||||
hs_pool[-1] = post_layer_norm(encoder_outputs)
|
||||
return torch.cat(hs_pool, dim=-1)
|
||||
|
||||
|
||||
def scatter_patch_features(
|
||||
features: torch.Tensor,
|
||||
num_embeds: torch.Tensor,
|
||||
embed_is_patch: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
|
||||
The rest of the values in the tensor are set to NaN so that they
|
||||
can be filtered out by :func`select_patch_features`.
|
||||
|
||||
Args:
|
||||
features: The patch features, concatenated across each image.
|
||||
Shape: `(num_patch, feature_depth)`
|
||||
num_embeds: The number of image embeddings for each image.
|
||||
Shape: `(num_images,)`
|
||||
embed_is_patch: A boolean mask indicating which image embeddings
|
||||
correspond to patch tokens for each image.
|
||||
Shape: `(num_images, num_embeds)`
|
||||
"""
|
||||
num_embeds_per_image: list[int] = num_embeds.tolist()
|
||||
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_embeds_per_image), features.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
|
||||
|
||||
return embeds_flat.split(num_embeds_per_image)
|
||||
|
||||
|
||||
def select_patch_features(
|
||||
multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
|
||||
"""
|
||||
Given the outputs of :func:`scatter_patch_features`, return only
|
||||
the values that correspond to patch features.
|
||||
"""
|
||||
selected_features = json_map_leaves(
|
||||
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
|
||||
cast(JSONTree[torch.Tensor], multimodal_embeddings),
|
||||
)
|
||||
return cast(MultiModalEmbeddings, selected_features)
|
||||
|
||||
@ -511,8 +511,35 @@ def iter_token_matches(
|
||||
start_idx += 1
|
||||
|
||||
|
||||
def replace_token_matches(
|
||||
token_ids: list[int],
|
||||
match_ids: list[int],
|
||||
new_ids: list[int],
|
||||
) -> list[int]:
|
||||
"""
|
||||
Replace each occurrence of :code:`match_ids` in :code:`token_ids`
|
||||
with :code:`new_ids`.
|
||||
|
||||
Note that empty matches are ignored.
|
||||
"""
|
||||
out_seqs = list[list[int]]()
|
||||
prev_end_idx = 0
|
||||
|
||||
for match in iter_token_matches(token_ids, match_ids):
|
||||
start_idx = match.start_idx
|
||||
end_idx = match.end_idx
|
||||
|
||||
out_seqs.append(token_ids[prev_end_idx:start_idx])
|
||||
out_seqs.append(new_ids)
|
||||
prev_end_idx = end_idx
|
||||
|
||||
out_seqs.append(token_ids[prev_end_idx:])
|
||||
|
||||
return flatten_2d_lists(out_seqs)
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class _PromptTargetMatch(ABC):
|
||||
class PromptTargetMatch(ABC):
|
||||
_origin: BoundPromptUpdate
|
||||
|
||||
@property
|
||||
@ -535,7 +562,7 @@ class _PromptTargetMatch(ABC):
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class _PromptTargetIndexMatch(_PromptTargetMatch):
|
||||
class _PromptTargetIndexMatch(PromptTargetMatch):
|
||||
match_idx: int
|
||||
|
||||
@property
|
||||
@ -548,7 +575,7 @@ class _PromptTargetIndexMatch(_PromptTargetMatch):
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class _PromptTargetTokenMatch(_PromptTargetMatch):
|
||||
class _PromptTargetTokenMatch(PromptTargetMatch):
|
||||
match: _TokenMatch
|
||||
|
||||
@property
|
||||
@ -561,7 +588,7 @@ class _PromptTargetTokenMatch(_PromptTargetMatch):
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class _PromptTargetTextMatch(_PromptTargetMatch):
|
||||
class _PromptTargetTextMatch(PromptTargetMatch):
|
||||
match: re.Match[str]
|
||||
|
||||
@property
|
||||
@ -594,7 +621,7 @@ class PlaceholderFeaturesInfo:
|
||||
def find_token_matches(
|
||||
prompt: list[int],
|
||||
prompt_updates: Sequence[BoundPromptUpdate],
|
||||
) -> Sequence[_PromptTargetMatch]:
|
||||
) -> Sequence[PromptTargetMatch]:
|
||||
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
|
||||
|
||||
def get_matches(update: BoundPromptUpdate):
|
||||
@ -620,7 +647,7 @@ def find_token_matches(
|
||||
def find_text_matches(
|
||||
prompt: str,
|
||||
prompt_updates: Sequence[BoundPromptUpdate],
|
||||
) -> Sequence[_PromptTargetMatch]:
|
||||
) -> Sequence[PromptTargetMatch]:
|
||||
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
|
||||
|
||||
def get_matches(update: BoundPromptUpdate):
|
||||
@ -645,15 +672,15 @@ def find_text_matches(
|
||||
|
||||
def _resolve_matches(
|
||||
prompt: PromptSeq,
|
||||
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
|
||||
) -> list[_PromptTargetMatch]:
|
||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
||||
) -> list[PromptTargetMatch]:
|
||||
"""
|
||||
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
|
||||
and sort them such that earlier matches take priority over later ones.
|
||||
"""
|
||||
matches = [m for matches in mm_matches.values() for m in matches]
|
||||
|
||||
seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt)
|
||||
seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)
|
||||
|
||||
for match in matches:
|
||||
for idx in range(match.start_idx, match.end_idx):
|
||||
@ -669,7 +696,7 @@ def _resolve_matches(
|
||||
|
||||
def _apply_matches(
|
||||
prompt: _S,
|
||||
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
|
||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> list[_S]:
|
||||
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
|
||||
@ -718,7 +745,7 @@ def _apply_matches(
|
||||
|
||||
def apply_token_matches(
|
||||
prompt: list[int],
|
||||
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
|
||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> list[int]:
|
||||
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
|
||||
@ -732,7 +759,7 @@ def apply_token_matches(
|
||||
|
||||
def apply_text_matches(
|
||||
prompt: str,
|
||||
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
|
||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> str:
|
||||
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
|
||||
@ -1055,14 +1082,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
Given the original multi-modal items for this modality
|
||||
and HF-processed data, output the updates to perform.
|
||||
|
||||
Notes:
|
||||
- You should not assume that HF processor always performs prompt
|
||||
updates: in :meth:`_apply_hf_processor_missing`, this method
|
||||
is called on text-only and multimodal-only inputs separately,
|
||||
instead of passing them in the same call.
|
||||
- The update information returned by this method is also used to
|
||||
determine the placeholder token positions for each multi-modal
|
||||
item.
|
||||
The information returned by this method is used to update token inputs
|
||||
which bypass the HF processor. It is also used to update the output of
|
||||
HF processor if the HF process does not apply prompt updates to text
|
||||
inputs.
|
||||
|
||||
Moreover, this information is critical to determine the token positions
|
||||
in order to construct :class:`~vllm-multimodal.input.PlaceholderRange`
|
||||
for each multi-modal item.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1357,6 +1384,22 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
it = (update.bind(tokenizer) for update in prompt_updates)
|
||||
return dict(full_groupby_modality(it))
|
||||
|
||||
def _apply_token_matches(
|
||||
self,
|
||||
prompt: list[int],
|
||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> list[int]:
|
||||
return apply_token_matches(prompt, mm_matches, mm_item_counts)
|
||||
|
||||
def _apply_text_matches(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> str:
|
||||
return apply_text_matches(prompt, mm_matches, mm_item_counts)
|
||||
|
||||
def _apply_prompt_updates(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
@ -1388,7 +1431,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_match_counts.get(modality, 0) >= item_count
|
||||
for modality, item_count in mm_item_counts.items()
|
||||
): # yapf: disable
|
||||
token_ids = apply_token_matches(
|
||||
token_ids = self._apply_token_matches(
|
||||
token_ids,
|
||||
mm_token_matches,
|
||||
mm_item_counts,
|
||||
@ -1406,7 +1449,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
modality: find_text_matches(text, updates)
|
||||
for modality, updates in mm_prompt_updates.items()
|
||||
}
|
||||
text = apply_text_matches(
|
||||
text = self._apply_text_matches(
|
||||
text,
|
||||
mm_text_matches,
|
||||
mm_item_counts,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user