[Bugfix] Re-enable Gemma3 for V1 (#14980)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-03-19 14:58:22 +08:00 committed by GitHub
parent 05ccd0aa35
commit 61f412187d
8 changed files with 419 additions and 175 deletions

View File

@ -768,7 +768,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
* ✅︎
* ✅︎
*
* ⚠️
- * `GLM4VForCausalLM`<sup>^</sup>
* GLM-4V
* T + I
@ -951,13 +951,10 @@ V0 correctly implements the model's attention pattern:
V1 currently uses a simplified attention pattern:
- Uses causal attention for all tokens, including image tokens
- Generates reasonable outputs but does not match the original model's attention for text + image inputs
- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": True}`
- Will be updated in the future to support the correct behavior
- Does not support `"do_pan_and_scan": True`
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
For these reasons, `Gemma3ForConditionalGeneration` is supported only on V0 at the moment.
:::
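For illustration, a minimal sketch (not vLLM's implementation) of the mixed mask the note above describes: causal attention everywhere, plus bidirectional attention among the tokens of the same image. The helper below is purely hypothetical.

```python
import torch

def mixed_attention_mask(is_image: torch.Tensor) -> torch.Tensor:
    """Toy mask builder: True means attention is allowed.

    is_image: bool tensor of shape (seq_len,) marking image-token positions.
    Text tokens attend causally; tokens inside the same contiguous image
    block additionally attend to each other bidirectionally.
    """
    seq_len = is_image.numel()
    idx = torch.arange(seq_len)
    causal = idx[:, None] >= idx[None, :]

    # Give every contiguous run of image tokens its own block id.
    # (Adjacent images with no text in between would merge in this toy version.)
    block_id = torch.cumsum((~is_image).long(), dim=0)
    same_image = ((block_id[:, None] == block_id[None, :])
                  & is_image[:, None] & is_image[None, :])

    return causal | same_image

# A text token, a 3-token image, then another text token.
print(mixed_attention_mask(torch.tensor([False, True, True, True, False])).int())
```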
:::{note}

View File

@ -19,7 +19,8 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
apply_token_matches,
find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches)
iter_token_matches,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
@ -89,6 +90,58 @@ def test_iter_token_matches(token_ids, match_ids, expected):
assert all(match_len == len(match_ids) for match_len in match_lens)
# yapf: disable
@pytest.mark.parametrize(
("token_ids", "match_ids", "new_ids", "expected"),
[
([], [], [-1], []),
([], [32000], [-1], []),
(
[32000, 32000, 32000],
[32000],
[-1],
[-1, -1, -1],
),
(
[32000, 32000, 32000],
[32000, 32000],
[-1],
[-1, 32000],
),
(
[32000, 32000, 32000],
[32000, 32000, 32000],
[-1],
[-1],
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[28747, 32000],
[-1],
[9833, -1, 32000, 32000, 9833, -1, 32000, 918],
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[28747, 32000, 32000, 32000],
[-1],
[9833, -1, 9833, 28747, 32000, 32000, 918],
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[28747, 0, 32000],
[-1],
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
),
],
)
# yapf: enable
def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
result = replace_token_matches(token_ids, match_ids, new_ids)
# Manually constructed results
assert result == expected
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "expected_by_key"),

View File

@ -1,34 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
import math
from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
Tuple, TypedDict, Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, encode_tokens)
BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
encode_tokens, find_mm_placeholders,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsV0Only)
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
@ -37,13 +46,25 @@ class Gemma3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""
Shape: `(num_crops_total, num_channels, height, width)`
Shape: `(num_patches_total, num_channels, height, width)`
`num_crops_total` is the total number of crops
`num_patches_total` is the total number of patches
over each image over each prompt in the batch.
"""
num_crops: torch.Tensor
"""Shape: `(batch_size * num_images,)`"""
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs
@ -51,6 +72,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Gemma3Config)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
@ -114,6 +138,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if not do_pan_and_scan:
return 0
if envs.VLLM_USE_V1:
logger.warning_once(
"`do_pan_and_scan=True` has suboptimal results on V1 "
"because of the simplified attention pattern being used.")
# Based on Gemma3ImageProcessor.pan_and_scan
if image_width >= image_height:
if image_width / image_height < pan_and_scan_min_ratio_to_activate:
@ -154,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
processor: Optional[Gemma3Processor],
) -> str:
) -> PromptUpdateDetails:
if processor is None:
processor = self.get_hf_processor()
@ -175,7 +204,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
f"Here is the original image {image_token} and here are some "
f"crops to help you see better {crops_image_tokens}")
return image_text.replace(image_token, processor.full_image_sequence)
repl_full = image_text.replace(image_token,
processor.full_image_sequence)
repl_features = repl_full.strip("\n")
return PromptUpdateDetails(full=repl_full, features=repl_features)
def get_num_image_tokens(
self,
@ -193,7 +226,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_repl_tokens = encode_tokens(
tokenizer,
image_repl,
image_repl.features,
add_special_tokens=False,
)
return len(image_repl_tokens)
@ -240,12 +273,8 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
num_images=num_images)
}
# NOTE: We need to separate the image tokens here because
# encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
# with the detection of prompt updates when the image tokens are
# right next to each other
return ProcessorInputs(
prompt_text=" ".join([image_token] * num_images),
prompt_text=image_token * num_images,
mm_data=mm_data,
)
@ -278,13 +307,39 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor).features
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
num_embeds = [
len(image_repl_feature_tokens)
for image_repl_feature_tokens in image_repls_feature_tokens
]
processed_outputs["num_embeds"] = torch.tensor(num_embeds)
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_crops = [
self.info.get_num_crops(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
processed_outputs["num_crops"] = torch.tensor(num_crops)
return processed_outputs
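The derivation of the new `embed_is_patch` and `num_embeds` outputs can be shown standalone; the token ids below are made up for illustration only.

```python
import torch

image_token_id = 262144  # hypothetical id of the image placeholder token

# Tokenized replacement text for one image: a few wrapper tokens around the
# repeated image placeholder token (ids are illustrative, not Gemma3's).
image_repl_feature_tokens = [1041, 262144, 262144, 262144, 108]

num_embeds = len(image_repl_feature_tokens)
embed_is_patch = torch.tensor(image_repl_feature_tokens) == image_token_id

print(num_embeds)      # 5
print(embed_is_patch)  # tensor([False,  True,  True,  True, False])
```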
@ -300,6 +355,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@ -329,6 +386,91 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
)
]
def _apply_token_matches(
self,
prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
token_ids = super()._apply_token_matches(
prompt,
mm_matches,
mm_item_counts,
)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
# tokens, we have to combine them to be consistent with
# the output of the tokenizer
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
token_ids = replace_token_matches(
token_ids,
[newline_1, newline_2],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_1],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_2],
[newline_4],
)
return token_ids
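A standalone example of the normalization above, calling `replace_token_matches` directly; the newline token ids are hypothetical.

```python
from vllm.multimodal.processing import replace_token_matches

# Hypothetical vocab ids: "\n" -> 10, "\n\n" -> 11, "\n\n\n" -> 12, "\n\n\n\n" -> 13
newline_1, newline_2, newline_3, newline_4 = 10, 11, 12, 13

token_ids = [5, newline_1, newline_2, 7, newline_2, newline_2, 9]
token_ids = replace_token_matches(token_ids, [newline_1, newline_2], [newline_3])
token_ids = replace_token_matches(token_ids, [newline_2, newline_1], [newline_3])
token_ids = replace_token_matches(token_ids, [newline_2, newline_2], [newline_4])

assert token_ids == [5, newline_3, 7, newline_4, 9]
```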
def _find_mm_placeholders(
self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
def get_repl_toks(tok: int) -> list[int]:
if tok == newline_3:
return [newline_1, newline_2]
if tok == newline_4:
return [newline_2, newline_2]
return [tok]
repl_token_ids = list[int]()
repl_orig_idxs = list[int]()
for orig_idx, orig_tok in enumerate(new_token_ids):
repl_toks = get_repl_toks(orig_tok)
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
mm_item_counts)
return {
modality: [
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens,
) for p in placeholders
]
for modality, placeholders in repls.items()
}
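The index bookkeeping above can be traced on the same hypothetical newline ids: each combined newline token is expanded so the "\n\n" part of an image replacement becomes visible again, and each expanded position remembers the original token index it came from.

```python
newline_1, newline_2, newline_3, newline_4 = 10, 11, 12, 13

def get_repl_toks(tok: int) -> list[int]:
    if tok == newline_3:
        return [newline_1, newline_2]
    if tok == newline_4:
        return [newline_2, newline_2]
    return [tok]

new_token_ids = [5, newline_3, 7, newline_4, 9]

repl_token_ids: list[int] = []
repl_orig_idxs: list[int] = []  # expanded position -> original position
for orig_idx, orig_tok in enumerate(new_token_ids):
    repl_toks = get_repl_toks(orig_tok)
    repl_token_ids.extend(repl_toks)
    repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))

assert repl_token_ids == [5, newline_1, newline_2, 7, newline_2, newline_2, 9]
assert repl_orig_idxs == [0, 1, 1, 2, 3, 3, 4]
```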
class Gemma3MultiModalProjector(nn.Module):
@ -374,7 +516,7 @@ class Gemma3MultiModalProjector(nn.Module):
info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA, SupportsV0Only):
SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -415,6 +557,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def dtype(self):
return next(self.parameters()).dtype
@property
def sampler(self):
return self.language_model.sampler
@ -438,6 +584,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
@ -448,16 +596,26 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
f"Got type: {type(pixel_values)}")
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops values. "
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)
return Gemma3ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_crops=num_crops,
num_patches=num_crops + 1,
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
def _image_pixels_to_features(
@ -472,36 +630,51 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
) -> torch.Tensor:
) -> tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
vision_outputs = self._image_pixels_to_features(
num_patches = image_input["num_patches"]
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
return self.multi_modal_projector(vision_outputs)
image_embeds = self.multi_modal_projector(image_features)
return image_embeds.split(num_patches.tolist())
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
image_features = self._process_image_input(image_input)
if kwargs.get("v0_path", False):
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
if multimodal_embeddings is None:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
else:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds
def forward(self,
@ -516,6 +689,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
@ -524,8 +698,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
kwargs = self.prepare_attn_masks(
input_ids,
positions,
mask_dtype=vision_embeddings.dtype,
**kwargs)
mask_dtype=self.dtype,
**kwargs,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -18,7 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@ -27,8 +27,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors)
MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@ -44,7 +43,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
from .vision import (get_vision_encoder_info, scatter_patch_features,
select_patch_features)
class LlavaImagePixelInputs(TypedDict):
@ -76,7 +76,7 @@ class PixtralHFImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_patches: Union[torch.Tensor, list[torch.Tensor]]
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
@ -352,15 +352,15 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
num_patches = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
num_embeds = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result in masks of different sizes, so we need to
# later use `num_patches` to get per-image masks.
# later use `num_embeds` to get per-image masks.
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["num_patches"] = num_patches
processed_outputs["num_embeds"] = num_embeds
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
@ -372,7 +372,7 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@ -621,16 +621,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
num_embeds = kwargs.pop("num_embeds")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
num_patches=num_patches,
num_embeds=num_embeds,
)
return LlavaImagePixelInputs(
@ -716,33 +716,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_patch, d)
num_patches: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
) -> tuple[torch.Tensor, ...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""
# Insert columns of nan values according to `embed_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
num_patches_per_image: list[int] = num_patches.tolist()
embeds_flat = features.new_full(
(sum(num_patches_per_image), *features.shape[1:]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features
return embeds_flat.split(num_patches_per_image)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
@ -757,9 +730,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return vision_embeddings
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["num_patches"],
image_input["num_embeds"],
image_input["embed_is_patch"],
))
@ -770,16 +743,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
# Extract the patch tokens
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds

View File

@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
from typing import List, Optional, Set, Tuple, TypedDict, Union
import numpy as np
import torch
@ -24,7 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
SiluAndMul)
@ -42,8 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@ -59,6 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
from .vision import select_patch_features
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
@ -1602,16 +1601,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
if multimodal_embeddings is not None:
assert self.img_patch_id is not None
# Extract the patch tokens scattered in _get_mm_embeds
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
select_patch_features(multimodal_embeddings),
self.img_patch_id,
)
return inputs_embeds

View File

@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -22,7 +22,6 @@ from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -48,7 +47,8 @@ from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs,
scatter_patch_features, select_patch_features)
try:
from xformers import ops as xops
@ -77,7 +77,7 @@ class PixtralImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_patches: Union[torch.Tensor, list[torch.Tensor]]
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
@ -153,7 +153,7 @@ class PixtralProcessorAdapter:
images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
images_embed_is_patch = list[torch.Tensor]()
images_num_patches = list[int]()
images_num_embeds = list[int]()
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
@ -163,13 +163,13 @@ class PixtralProcessorAdapter:
images_processed.append(image_processed)
images_tokens.append(image_tokens)
images_embed_is_patch.append(image_tokens == image_token_id)
images_num_patches.append(len(image_tokens))
images_num_embeds.append(len(image_tokens))
return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
"embed_is_patch": images_embed_is_patch,
"num_patches": torch.tensor(images_num_patches),
"num_embeds": torch.tensor(images_num_embeds),
}
@ -273,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
return dict(
images=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@ -394,16 +394,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
num_embeds = kwargs.pop("num_embeds")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
embed_is_patch=embed_is_patch,
num_patches=num_patches,
num_embeds=num_embeds,
)
def _process_image_input(
@ -433,33 +433,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_patch, d)
num_patches: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
) -> tuple[torch.Tensor, ...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""
# Insert columns of nan values according to `embed_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
num_patches_per_image: list[int] = num_patches.tolist()
embeds_flat = features.new_full(
(sum(num_patches_per_image), *features.shape[1:]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features
return embeds_flat.split(num_patches_per_image)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
@ -472,9 +445,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return image_features
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_patches"],
image_input["num_embeds"],
image_input["embed_is_patch"],
))
@ -485,15 +458,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
# Extract the patch tokens
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
select_patch_features(multimodal_embeddings),
self.vision_args.image_token_id,
)
return inputs_embeds

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
import torch
from transformers import PretrainedConfig
@ -9,9 +9,12 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from .interfaces import MultiModalEmbeddings
logger = init_logger(__name__)
_C = TypeVar("_C", bound=PretrainedConfig)
@ -148,3 +151,48 @@ def resolve_visual_encoder_outputs(
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1)
def scatter_patch_features(
features: torch.Tensor,
num_embeds: torch.Tensor,
embed_is_patch: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
The rest of the values in the tensor are set to NaN so that they
can be filtered out by :func:`select_patch_features`.
Args:
features: The patch features, concatenated across each image.
Shape: `(num_patch, feature_depth)`
num_embeds: The number of image embeddings for each image.
Shape: `(num_images,)`
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
"""
num_embeds_per_image: list[int] = num_embeds.tolist()
embeds_flat = features.new_full(
(sum(num_embeds_per_image), features.shape[-1]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
return embeds_flat.split(num_embeds_per_image)
def select_patch_features(
multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
"""
Given the outputs of :func:`scatter_patch_features`, return only
the values that correspond to patch features.
"""
selected_features = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
return cast(MultiModalEmbeddings, selected_features)
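A short usage sketch of the two new helpers, with toy shapes; the import path mirrors the `.vision` module used by the model files above.

```python
import torch

from vllm.model_executor.models.vision import (scatter_patch_features,
                                                select_patch_features)

# One image whose replacement occupies 5 embedding slots,
# of which only 3 correspond to patch tokens (toy sizes).
features = torch.arange(3 * 4, dtype=torch.float32).reshape(3, 4)  # (num_patch, d)
num_embeds = torch.tensor([5])                                      # (num_images,)
embed_is_patch = torch.tensor([[False, True, True, True, False]])   # (num_images, num_embeds)

scattered = scatter_patch_features(features, num_embeds, embed_is_patch)
# scattered[0] has shape (5, 4); the two non-patch rows are filled with NaN.

recovered = select_patch_features(scattered)
# recovered[0] equals `features` again; the NaN rows have been dropped.
```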

View File

@ -511,8 +511,35 @@ def iter_token_matches(
start_idx += 1
def replace_token_matches(
token_ids: list[int],
match_ids: list[int],
new_ids: list[int],
) -> list[int]:
"""
Replace each occurrence of :code:`match_ids` in :code:`token_ids`
with :code:`new_ids`.
Note that empty matches are ignored.
"""
out_seqs = list[list[int]]()
prev_end_idx = 0
for match in iter_token_matches(token_ids, match_ids):
start_idx = match.start_idx
end_idx = match.end_idx
out_seqs.append(token_ids[prev_end_idx:start_idx])
out_seqs.append(new_ids)
prev_end_idx = end_idx
out_seqs.append(token_ids[prev_end_idx:])
return flatten_2d_lists(out_seqs)
@dataclass(repr=False)
class _PromptTargetMatch(ABC):
class PromptTargetMatch(ABC):
_origin: BoundPromptUpdate
@property
@ -535,7 +562,7 @@ class _PromptTargetMatch(ABC):
@dataclass(repr=False)
class _PromptTargetIndexMatch(_PromptTargetMatch):
class _PromptTargetIndexMatch(PromptTargetMatch):
match_idx: int
@property
@ -548,7 +575,7 @@ class _PromptTargetIndexMatch(_PromptTargetMatch):
@dataclass(repr=False)
class _PromptTargetTokenMatch(_PromptTargetMatch):
class _PromptTargetTokenMatch(PromptTargetMatch):
match: _TokenMatch
@property
@ -561,7 +588,7 @@ class _PromptTargetTokenMatch(_PromptTargetMatch):
@dataclass(repr=False)
class _PromptTargetTextMatch(_PromptTargetMatch):
class _PromptTargetTextMatch(PromptTargetMatch):
match: re.Match[str]
@property
@ -594,7 +621,7 @@ class PlaceholderFeaturesInfo:
def find_token_matches(
prompt: list[int],
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
) -> Sequence[PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def get_matches(update: BoundPromptUpdate):
@ -620,7 +647,7 @@ def find_token_matches(
def find_text_matches(
prompt: str,
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
) -> Sequence[PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def get_matches(update: BoundPromptUpdate):
@ -645,15 +672,15 @@ def find_text_matches(
def _resolve_matches(
prompt: PromptSeq,
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
) -> list[_PromptTargetMatch]:
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
) -> list[PromptTargetMatch]:
"""
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
matches = [m for matches in mm_matches.values() for m in matches]
seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt)
seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)
for match in matches:
for idx in range(match.start_idx, match.end_idx):
@ -669,7 +696,7 @@ def _resolve_matches(
def _apply_matches(
prompt: _S,
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[_S]:
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@ -718,7 +745,7 @@ def _apply_matches(
def apply_token_matches(
prompt: list[int],
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@ -732,7 +759,7 @@ def apply_token_matches(
def apply_text_matches(
prompt: str,
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> str:
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@ -1055,14 +1082,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform.
Notes:
- You should not assume that HF processor always performs prompt
updates: in :meth:`_apply_hf_processor_missing`, this method
is called on text-only and multimodal-only inputs separately,
instead of passing them in the same call.
- The update information returned by this method is also used to
determine the placeholder token positions for each multi-modal
item.
The information returned by this method is used to update token inputs
which bypass the HF processor. It is also used to update the output of
HF processor if the HF processor does not apply prompt updates to text
inputs.
Moreover, this information is critical to determine the token positions
in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
for each multi-modal item.
"""
raise NotImplementedError
@ -1357,6 +1384,22 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
it = (update.bind(tokenizer) for update in prompt_updates)
return dict(full_groupby_modality(it))
def _apply_token_matches(
self,
prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
return apply_token_matches(prompt, mm_matches, mm_item_counts)
def _apply_text_matches(
self,
prompt: str,
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> str:
return apply_text_matches(prompt, mm_matches, mm_item_counts)
def _apply_prompt_updates(
self,
token_ids: list[int],
@ -1388,7 +1431,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_match_counts.get(modality, 0) >= item_count
for modality, item_count in mm_item_counts.items()
): # yapf: disable
token_ids = apply_token_matches(
token_ids = self._apply_token_matches(
token_ids,
mm_token_matches,
mm_item_counts,
@ -1406,7 +1449,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
modality: find_text_matches(text, updates)
for modality, updates in mm_prompt_updates.items()
}
text = apply_text_matches(
text = self._apply_text_matches(
text,
mm_text_matches,
mm_item_counts,