diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index fbcea826e6c9b..5e5e7287f39eb 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -768,7 +768,7 @@ See [this page](#generative-models) for more information on how to use generativ * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. * ✅︎ * ✅︎ - * + * ⚠️ - * `GLM4VForCausalLM`^ * GLM-4V * T + I @@ -951,13 +951,10 @@ V0 correctly implements the model's attention pattern: V1 currently uses a simplified attention pattern: - Uses causal attention for all tokens, including image tokens -- Generates reasonable outputs but does not match the original model's attention for text + image inputs +- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": True}` - Will be updated in the future to support the correct behavior -- Does not support `"do_pan_and_scan": True` This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. - -For these reasons, `Gemma3ForConditionalGeneration` is supported only on V0 at the moment. ::: :::{note} diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index fbb7e507b10a5..2e6dde75dc917 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -19,7 +19,8 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo, apply_token_matches, find_mm_placeholders, find_text_matches, find_token_matches, - iter_token_matches) + iter_token_matches, + replace_token_matches) # yapf: enable from vllm.multimodal.profiling import MultiModalProfiler from vllm.transformers_utils.tokenizer import (AnyTokenizer, @@ -89,6 +90,58 @@ def test_iter_token_matches(token_ids, match_ids, expected): assert all(match_len == len(match_ids) for match_len in match_lens) +# yapf: disable +@pytest.mark.parametrize( + ("token_ids", "match_ids", "new_ids", "expected"), + [ + ([], [], [-1], []), + ([], [32000], [-1], []), + ( + [32000, 32000, 32000], + [32000], + [-1], + [-1, -1, -1], + ), + ( + [32000, 32000, 32000], + [32000, 32000], + [-1], + [-1, 32000], + ), + ( + [32000, 32000, 32000], + [32000, 32000, 32000], + [-1], + [-1], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [28747, 32000], + [-1], + [9833, -1, 32000, 32000, 9833, -1, 32000, 918], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [28747, 32000, 32000, 32000], + [-1], + [9833, -1, 9833, 28747, 32000, 32000, 918], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [28747, 0, 32000], + [-1], + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + ), + ], +) +# yapf: enable +def test_replace_token_matches(token_ids, match_ids, new_ids, expected): + result = replace_token_matches(token_ids, match_ids, new_ids) + + # Manually constructed results + assert result == expected + + # yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "expected_by_key"), diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 27b254b9c5c84..62e55d64cf2ca 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -1,34 +1,43 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import (Any, Iterable, Literal, Mapping, Optional, 
Sequence, Set, - Tuple, TypedDict, Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union import torch from torch import nn from transformers import BatchFeature, Gemma3Config, Gemma3Processor from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs +import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) +# yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, encode_tokens) + BaseProcessingInfo, BoundPromptUpdate, + PlaceholderFeaturesInfo, + PromptReplacement, PromptTargetMatch, + PromptUpdate, PromptUpdateDetails, + encode_tokens, find_mm_placeholders, + replace_token_matches) +# yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors +from vllm.utils import flatten_2d_lists from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP, SupportsV0Only) + SupportsMultiModal, SupportsPP) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) +from .vision import scatter_patch_features, select_patch_features logger = init_logger(__name__) @@ -37,13 +46,25 @@ class Gemma3ImagePixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values: torch.Tensor """ - Shape: `(num_crops_total, num_channels, height, width)` + Shape: `(num_patches_total, num_channels, height, width)` - `num_crops_total` is the total number of crops + `num_patches_total` is the total number of patches over each image over each prompt in the batch. """ - num_crops: torch.Tensor - """Shape: `(batch_size * num_images,)`""" + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. 
+ + Shape: `(batch_size, num_images, num_embeds)` + """ + + num_embeds: Union[torch.Tensor, list[torch.Tensor]] + """Shape: `(batch_size, num_images)`""" Gemma3ImageInputs = Gemma3ImagePixelInputs @@ -51,6 +72,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs class Gemma3ProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(Gemma3Config) + def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(Gemma3Processor, **kwargs) @@ -114,6 +138,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): if not do_pan_and_scan: return 0 + if envs.VLLM_USE_V1: + logger.warning_once( + "`do_pan_and_scan=True` has suboptimal results on V1 " + "because of the simplified attention pattern being used.") + # Based on Gemma3ImageProcessor.pan_and_scan if image_width >= image_height: if image_width / image_height < pan_and_scan_min_ratio_to_activate: @@ -154,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): image_width: int, image_height: int, processor: Optional[Gemma3Processor], - ) -> str: + ) -> PromptUpdateDetails: if processor is None: processor = self.get_hf_processor() @@ -175,7 +204,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): f"Here is the original image {image_token} and here are some " f"crops to help you see better {crops_image_tokens}") - return image_text.replace(image_token, processor.full_image_sequence) + repl_full = image_text.replace(image_token, + processor.full_image_sequence) + repl_features = repl_full.strip("\n") + + return PromptUpdateDetails(full=repl_full, features=repl_features) def get_num_image_tokens( self, @@ -193,7 +226,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): image_repl_tokens = encode_tokens( tokenizer, - image_repl, + image_repl.features, add_special_tokens=False, ) return len(image_repl_tokens) @@ -240,12 +273,8 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): num_images=num_images) } - # NOTE: We need to separate the image tokens here because - # encode("\n\n\n\n") != encode("\n\n") * 2, which interferes - # with the detection of prompt updates when the image tokens are - # right next to each other return ProcessorInputs( - prompt_text=" ".join([image_token] * num_images), + prompt_text=image_token * num_images, mm_data=mm_data, ) @@ -278,13 +307,39 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ] hf_processor = self.info.get_hf_processor(**mm_kwargs) + image_repl_features = [ + self.info.get_image_repl(image_width=size.width, + image_height=size.height, + processor=hf_processor).features + for size in image_sizes + ] + + tokenizer = self.info.get_tokenizer() + image_repls_feature_tokens = [ + tokenizer.encode(image_repl, add_special_tokens=False) + for image_repl in image_repl_features + ] + num_embeds = [ + len(image_repl_feature_tokens) + for image_repl_feature_tokens in image_repls_feature_tokens + ] + processed_outputs["num_embeds"] = torch.tensor(num_embeds) + + vocab = tokenizer.get_vocab() + image_token_id = vocab[tokenizer.image_token] + + embed_is_patch = [ + torch.tensor(image_repl_tokens) == image_token_id + for image_repl_tokens in image_repls_feature_tokens + ] + processed_outputs["embed_is_patch"] = embed_is_patch + num_crops = [ self.info.get_num_crops(image_width=size.width, image_height=size.height, processor=hf_processor) for size in image_sizes ] - processed_outputs["num_crops"] = torch.tensor(num_crops) return processed_outputs @@ -300,6 +355,8 @@ class 
Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): pixel_values=MultiModalFieldConfig.flat_from_sizes( "image", num_crops + 1), num_crops=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + num_embeds=MultiModalFieldConfig.batched("image"), ) def _get_prompt_updates( @@ -329,6 +386,91 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ) ] + def _apply_token_matches( + self, + prompt: list[int], + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], + ) -> list[int]: + token_ids = super()._apply_token_matches( + prompt, + mm_matches, + mm_item_counts, + ) + + # "\n\n\n" and "\n\n\n\n" are single tokens + # Since our replacement can insert "\n\n" next to "\n" + # tokens, we have to combine them to be consistent with + # the output of the tokenizer + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 = vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + token_ids = replace_token_matches( + token_ids, + [newline_1, newline_2], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_1], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_2], + [newline_4], + ) + + return token_ids + + def _find_mm_placeholders( + self, + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], + new_token_ids: list[int], + mm_item_counts: Mapping[str, int], + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 = vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + def get_repl_toks(tok: int) -> list[int]: + if tok == newline_3: + return [newline_1, newline_2] + if tok == newline_4: + return [newline_2, newline_2] + + return [tok] + + repl_token_ids = list[int]() + repl_orig_idxs = list[int]() + for orig_idx, orig_tok in enumerate(new_token_ids): + repl_toks = get_repl_toks(orig_tok) + repl_token_ids.extend(repl_toks) + repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) + + repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, + mm_item_counts) + + return { + modality: [ + PlaceholderFeaturesInfo( + modality=p.modality, + item_idx=p.item_idx, + start_idx=repl_orig_idxs[p.start_idx], + tokens=p.tokens, + ) for p in placeholders + ] + for modality, placeholders in repls.items() + } + class Gemma3MultiModalProjector(nn.Module): @@ -374,7 +516,7 @@ class Gemma3MultiModalProjector(nn.Module): info=Gemma3ProcessingInfo, dummy_inputs=Gemma3DummyInputsBuilder) class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA, SupportsV0Only): + SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -415,6 +557,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + @property + def dtype(self): + return next(self.parameters()).dtype + @property def sampler(self): return self.language_model.sampler @@ -438,6 +584,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, self, **kwargs: object) -> Optional[Gemma3ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) num_crops = kwargs.pop("num_crops", None) + 
embed_is_patch = kwargs.pop("embed_is_patch", None) + num_embeds = kwargs.pop("num_embeds", None) image_embeds = kwargs.pop("image_embeds", None) assert image_embeds is None, "Gemma3 does not support image_embeds." if pixel_values is None: @@ -448,16 +596,26 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, f"Got type: {type(pixel_values)}") if not isinstance(num_crops, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_crops values. " + raise ValueError("Incorrect type of num_crops. " f"Got type: {type(num_crops)}") + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + if not isinstance(num_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_embeds. " + f"Got type: {type(num_embeds)}") + pixel_values = flatten_bn(pixel_values, concat=True) num_crops = flatten_bn(num_crops, concat=True) return Gemma3ImagePixelInputs( type="pixel_values", pixel_values=self._validate_pixel_values(pixel_values), - num_crops=num_crops, + num_patches=num_crops + 1, + embed_is_patch=embed_is_patch, + num_embeds=num_embeds, ) def _image_pixels_to_features( @@ -472,36 +630,51 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, def _process_image_input( self, image_input: Gemma3ImageInputs, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, ...]: assert self.vision_tower is not None pixel_values = image_input["pixel_values"] - vision_outputs = self._image_pixels_to_features( + num_patches = image_input["num_patches"] + + image_features = self._image_pixels_to_features( self.vision_tower, pixel_values, ) - return self.multi_modal_projector(vision_outputs) + image_embeds = self.multi_modal_projector(image_features) + + return image_embeds.split(num_patches.tolist()) def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None - vision_embeddings = self._process_image_input(image_input) - return vision_embeddings + + image_features = self._process_image_input(image_input) + + if kwargs.get("v0_path", False): + return image_features + + return flatten_2d_lists( + scatter_patch_features(*args) for args in zip( + image_features, + image_input["num_embeds"], + image_input["embed_is_patch"], + )) def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, ) -> torch.Tensor: - if multimodal_embeddings is None: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - else: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.config.image_token_index, + ) return inputs_embeds def forward(self, @@ -516,6 +689,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
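For context on the V0 branch just below, which calls `prepare_attn_masks` with `mask_dtype=self.dtype`: the docs change above describes Gemma3's mixed attention pattern (causal for text, bidirectional within each image). A minimal standalone sketch of that pattern, not the actual `prepare_attn_masks` implementation, with `image_spans` as an assumed toy input:

import torch


def mixed_attention_mask_sketch(
    seq_len: int,
    image_spans: list[tuple[int, int]],  # (start, end) per image, end exclusive
) -> torch.Tensor:
    # True = attention allowed. Causal everywhere, plus bidirectional
    # attention among the tokens of the same image.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    for start, end in image_spans:
        mask[start:end, start:end] = True
    return mask


# Text at positions 0-1, one image at positions 2-5, text afterwards.
mask = mixed_attention_mask_sketch(8, [(2, 6)])
assert mask[2, 5]      # image token attends to a later token of the same image
assert not mask[1, 2]  # text token still cannot attend to future tokens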
elif inputs_embeds is None: + kwargs.update({"v0_path": True}) vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, @@ -524,8 +698,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, kwargs = self.prepare_attn_masks( input_ids, positions, - mask_dtype=vision_embeddings.dtype, - **kwargs) + mask_dtype=self.dtype, + **kwargs, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 3a8d184528d8b..441ccde046eb9 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -18,7 +18,7 @@ from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig from vllm.inputs import InputProcessingContext -from vllm.jsontree import JSONTree, json_map_leaves +from vllm.jsontree import json_map_leaves from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -27,8 +27,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs, - NestedTensors) + MultiModalInputs, MultiModalKwargs) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -44,7 +43,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vision_encoder_info +from .vision import (get_vision_encoder_info, scatter_patch_features, + select_patch_features) class LlavaImagePixelInputs(TypedDict): @@ -76,7 +76,7 @@ class PixtralHFImagePixelInputs(TypedDict): Shape: `(batch_size, num_images, num_embeds)` """ - num_patches: Union[torch.Tensor, list[torch.Tensor]] + num_embeds: Union[torch.Tensor, list[torch.Tensor]] """Shape: `(batch_size, num_images)`""" @@ -352,15 +352,15 @@ class PixtralHFMultiModalProcessor( image_height=pixel_value.shape[-2], ) for pixel_value in processed_outputs["pixel_values"] ] - num_patches = torch.tensor([(ncols + 1) * nrows - for ncols, nrows in tile_sizes]) + num_embeds = torch.tensor([(ncols + 1) * nrows + for ncols, nrows in tile_sizes]) # Each image may result to masks of different sizes, so we need to - # later use `num_patches` to get per-image masks. + # later use `num_embeds` to get per-image masks. 
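For the `embed_is_patch` construction just below, a small worked example of the layout it encodes, assuming a 3x2 tile grid for illustration (three patch tokens per row, each row followed by one non-patch break token):

import torch

ncols, nrows = 3, 2  # assumed tile grid, for illustration only
embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)
num_embeds = (ncols + 1) * nrows

# patch, patch, patch, break, patch, patch, patch, break
assert embed_is_patch.tolist() == [True, True, True, False] * 2
assert num_embeds == embed_is_patch.numel() == 8
assert int(embed_is_patch.sum()) == ncols * nrows  # only patch slots carry vision features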
embed_is_patch = [ torch.tensor(([True] * ncols + [False]) * nrows) for ncols, nrows in tile_sizes ] - processed_outputs["num_patches"] = num_patches + processed_outputs["num_embeds"] = num_embeds processed_outputs["embed_is_patch"] = embed_is_patch return processed_outputs @@ -372,7 +372,7 @@ class PixtralHFMultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: return dict( pixel_values=MultiModalFieldConfig.batched("image"), - num_patches=MultiModalFieldConfig.batched("image"), + num_embeds=MultiModalFieldConfig.batched("image"), embed_is_patch=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -621,16 +621,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError("Incorrect type of embed_is_patch. " f"Got type: {type(embed_is_patch)}") - num_patches = kwargs.pop("num_patches") - if not isinstance(num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_patches. " - f"Got type: {type(num_patches)}") + num_embeds = kwargs.pop("num_embeds") + if not isinstance(num_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_embeds. " + f"Got type: {type(num_embeds)}") return PixtralHFImagePixelInputs( type="pixel_values_pixtral", pixel_values=flatten_bn(pixel_values), embed_is_patch=embed_is_patch, - num_patches=num_patches, + num_embeds=num_embeds, ) return LlavaImagePixelInputs( @@ -716,33 +716,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): image_embeds = torch.split(image_embeds, feature_sizes) return image_embeds - def _get_mm_embeds( - self, - features: torch.Tensor, # Shape: (num_patch, d) - num_patches: torch.Tensor, # Shape: (num_images,) - embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds) - ) -> tuple[torch.Tensor, ...]: - """Scatter the patch features into a contiguous tensor that corresponds - to the embedding tokens defined by the multimodal processor. - - Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment. - """ - # Insert columns of nan values according to `embed_is_patch`. This work - # ideally should be done in `_process_image_input`, but - # `_process_image_input` is used in both V0 and V1 path. It's safer to - # put the logic here. - # FIXME: Move this logic to `_process_image_input` when v0 is - # deprecated. Merge this function with `Molmo._get_mm_embeds`. 
- num_patches_per_image: list[int] = num_patches.tolist() - - embeds_flat = features.new_full( - (sum(num_patches_per_image), *features.shape[1:]), - fill_value=torch.nan, - ) - embeds_flat[embed_is_patch.view(-1)] = features - - return embeds_flat.split(num_patches_per_image) - def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: image_input = self._parse_and_validate_image_input(**kwargs) @@ -757,9 +730,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): return vision_embeddings return flatten_2d_lists( - self._get_mm_embeds(*args) for args in zip( + scatter_patch_features(*args) for args in zip( vision_embeddings, - image_input["num_patches"], + image_input["num_embeds"], image_input["embed_is_patch"], )) @@ -770,16 +743,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - # Extract the patch tokens - patch_embeddings = json_map_leaves( - lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), - cast(JSONTree[torch.Tensor], multimodal_embeddings), - ) - inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - cast(NestedTensors, patch_embeddings), + select_patch_features(multimodal_embeddings), self.config.image_token_index, ) return inputs_embeds diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index c7f6cf461d523..3f0c644a5a866 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -4,7 +4,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from functools import cached_property, partial -from typing import List, Optional, Set, Tuple, TypedDict, Union, cast +from typing import List, Optional, Set, Tuple, TypedDict, Union import numpy as np import torch @@ -24,7 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather) -from vllm.jsontree import JSONTree, json_map_leaves from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU, SiluAndMul) @@ -42,8 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, - NestedTensors) +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -59,6 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) +from .vision import select_patch_features # TODO: hard-coded for now. Consider making it configurable. 
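The `get_input_embeddings` changes in this patch (LLaVA above, Molmo and Pixtral below) all follow the same pattern: recover the patch rows with `select_patch_features` and write them into the image-token positions. A self-contained sketch of that flow, using a toy `merge_sketch` helper and a made-up token id rather than the real `merge_multimodal_embeddings`:

import torch

IMG = 99  # assumed image placeholder token id


def merge_sketch(input_ids, inputs_embeds, patch_embeds, image_token_id):
    # Simplified stand-in for merge_multimodal_embeddings: overwrite the
    # embeddings at image-token positions with the vision features.
    out = inputs_embeds.clone()
    out[input_ids == image_token_id] = patch_embeds
    return out


input_ids = torch.tensor([1, IMG, IMG, 2])
inputs_embeds = torch.zeros(4, 3)
# Scattered embeddings for one image: two patch rows plus one NaN-padded slot.
scattered = torch.tensor([[1., 1., 1.], [float("nan")] * 3, [2., 2., 2.]])
patch_embeds = scattered[~scattered.isnan().any(dim=-1)]  # what select_patch_features recovers
merged = merge_sketch(input_ids, inputs_embeds, patch_embeds, IMG)
assert merged[1].tolist() == [1., 1., 1.] and merged[2].tolist() == [2., 2., 2.]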
VIT_LAYERS = [-2, -9] @@ -1602,16 +1601,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, if multimodal_embeddings is not None: assert self.img_patch_id is not None - # Extract the patch tokens scattered in _get_mm_embeds - patch_embeddings = json_map_leaves( - lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), - cast(JSONTree[torch.Tensor], multimodal_embeddings), - ) - inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - cast(NestedTensors, patch_embeddings), + select_patch_features(multimodal_embeddings), self.img_patch_id, ) return inputs_embeds diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index dc3402e432149..5da69ce7fa061 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -4,7 +4,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -22,7 +22,6 @@ from transformers.tokenization_utils_base import TextInput from vllm.config import VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.jsontree import JSONTree, json_map_leaves from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -48,7 +47,8 @@ from vllm.utils import flatten_2d_lists from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs +from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs, + scatter_patch_features, select_patch_features) try: from xformers import ops as xops @@ -77,7 +77,7 @@ class PixtralImagePixelInputs(TypedDict): Shape: `(batch_size, num_images, num_embeds)` """ - num_patches: Union[torch.Tensor, list[torch.Tensor]] + num_embeds: Union[torch.Tensor, list[torch.Tensor]] """Shape: `(batch_size, num_images)`""" @@ -153,7 +153,7 @@ class PixtralProcessorAdapter: images_processed = list[torch.Tensor]() images_tokens = list[torch.Tensor]() images_embed_is_patch = list[torch.Tensor]() - images_num_patches = list[int]() + images_num_embeds = list[int]() for image in images: image_inputs = self.image_processor(ImageChunk(image=image)) @@ -163,13 +163,13 @@ class PixtralProcessorAdapter: images_processed.append(image_processed) images_tokens.append(image_tokens) images_embed_is_patch.append(image_tokens == image_token_id) - images_num_patches.append(len(image_tokens)) + images_num_embeds.append(len(image_tokens)) return { "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), "images": images_processed, "embed_is_patch": images_embed_is_patch, - "num_patches": torch.tensor(images_num_patches), + "num_embeds": torch.tensor(images_num_embeds), } @@ -273,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] return dict( images=MultiModalFieldConfig.batched("image"), embed_is_patch=MultiModalFieldConfig.batched("image"), - num_patches=MultiModalFieldConfig.batched("image"), + num_embeds=MultiModalFieldConfig.batched("image"), ) def _get_prompt_updates( 
@@ -394,16 +394,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError("Incorrect type of embed_is_patch. " f"Got type: {type(embed_is_patch)}") - num_patches = kwargs.pop("num_patches") - if not isinstance(num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_patches. " - f"Got type: {type(num_patches)}") + num_embeds = kwargs.pop("num_embeds") + if not isinstance(num_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_embeds. " + f"Got type: {type(num_embeds)}") return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images), embed_is_patch=embed_is_patch, - num_patches=num_patches, + num_embeds=num_embeds, ) def _process_image_input( @@ -433,33 +433,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = torch.split(image_embeds, feature_sizes) return image_embeds - def _get_mm_embeds( - self, - features: torch.Tensor, # Shape: (num_patch, d) - num_patches: torch.Tensor, # Shape: (num_images,) - embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds) - ) -> tuple[torch.Tensor, ...]: - """Scatter the patch features into a contiguous tensor that corresponds - to the embedding tokens defined by the multimodal processor. - - Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment. - """ - # Insert columns of nan values according to `embed_is_patch`. This work - # ideally should be done in `_process_image_input`, but - # `_process_image_input` is used in both V0 and V1 path. It's safer to - # put the logic here. - # FIXME: Move this logic to `_process_image_input` when v0 is - # deprecated. Merge this function with `Molmo._get_mm_embeds`. - num_patches_per_image: list[int] = num_patches.tolist() - - embeds_flat = features.new_full( - (sum(num_patches_per_image), *features.shape[1:]), - fill_value=torch.nan, - ) - embeds_flat[embed_is_patch.view(-1)] = features - - return embeds_flat.split(num_patches_per_image) - def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: image_input = self._parse_and_validate_image_input(**kwargs) @@ -472,9 +445,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, return image_features return flatten_2d_lists( - self._get_mm_embeds(*args) for args in zip( + scatter_patch_features(*args) for args in zip( image_features, - image_input["num_patches"], + image_input["num_embeds"], image_input["embed_is_patch"], )) @@ -485,15 +458,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - # Extract the patch tokens - patch_embeddings = json_map_leaves( - lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), - cast(JSONTree[torch.Tensor], multimodal_embeddings), - ) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - cast(NestedTensors, patch_embeddings), + select_patch_features(multimodal_embeddings), self.vision_args.image_token_id, ) return inputs_embeds diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 9a6fac2eec568..f316e7d0ef57e 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Final, Generic, Optional, Protocol, TypeVar, Union +from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast 
import torch from transformers import PretrainedConfig @@ -9,9 +9,12 @@ from transformers import PretrainedConfig import vllm.envs as envs from vllm.attention.selector import (backend_name_to_enum, get_global_forced_attn_backend) +from vllm.jsontree import JSONTree, json_map_leaves from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform +from .interfaces import MultiModalEmbeddings + logger = init_logger(__name__) _C = TypeVar("_C", bound=PretrainedConfig) @@ -148,3 +151,48 @@ def resolve_visual_encoder_outputs( if post_layer_norm is not None and uses_last_layer: hs_pool[-1] = post_layer_norm(encoder_outputs) return torch.cat(hs_pool, dim=-1) + + +def scatter_patch_features( + features: torch.Tensor, + num_embeds: torch.Tensor, + embed_is_patch: torch.Tensor, +) -> tuple[torch.Tensor, ...]: + """ + Scatter the patch features into a contiguous tensor that corresponds + to the embedding tokens defined by the multimodal processor. + + The rest of the values in the tensor are set to NaN so that they + can be filtered out by :func`select_patch_features`. + + Args: + features: The patch features, concatenated across each image. + Shape: `(num_patch, feature_depth)` + num_embeds: The number of image embeddings for each image. + Shape: `(num_images,)` + embed_is_patch: A boolean mask indicating which image embeddings + correspond to patch tokens for each image. + Shape: `(num_images, num_embeds)` + """ + num_embeds_per_image: list[int] = num_embeds.tolist() + + embeds_flat = features.new_full( + (sum(num_embeds_per_image), features.shape[-1]), + fill_value=torch.nan, + ) + embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2) + + return embeds_flat.split(num_embeds_per_image) + + +def select_patch_features( + multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings: + """ + Given the outputs of :func:`scatter_patch_features`, return only + the values that correspond to patch features. + """ + selected_features = json_map_leaves( + lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), + cast(JSONTree[torch.Tensor], multimodal_embeddings), + ) + return cast(MultiModalEmbeddings, selected_features) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 10c53dfb2c66e..b400e2701ac3a 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -511,8 +511,35 @@ def iter_token_matches( start_idx += 1 +def replace_token_matches( + token_ids: list[int], + match_ids: list[int], + new_ids: list[int], +) -> list[int]: + """ + Replace each occurrence of :code:`match_ids` in :code:`token_ids` + with :code:`new_ids`. + + Note that empty matches are ignored. 
+ """ + out_seqs = list[list[int]]() + prev_end_idx = 0 + + for match in iter_token_matches(token_ids, match_ids): + start_idx = match.start_idx + end_idx = match.end_idx + + out_seqs.append(token_ids[prev_end_idx:start_idx]) + out_seqs.append(new_ids) + prev_end_idx = end_idx + + out_seqs.append(token_ids[prev_end_idx:]) + + return flatten_2d_lists(out_seqs) + + @dataclass(repr=False) -class _PromptTargetMatch(ABC): +class PromptTargetMatch(ABC): _origin: BoundPromptUpdate @property @@ -535,7 +562,7 @@ class _PromptTargetMatch(ABC): @dataclass(repr=False) -class _PromptTargetIndexMatch(_PromptTargetMatch): +class _PromptTargetIndexMatch(PromptTargetMatch): match_idx: int @property @@ -548,7 +575,7 @@ class _PromptTargetIndexMatch(_PromptTargetMatch): @dataclass(repr=False) -class _PromptTargetTokenMatch(_PromptTargetMatch): +class _PromptTargetTokenMatch(PromptTargetMatch): match: _TokenMatch @property @@ -561,7 +588,7 @@ class _PromptTargetTokenMatch(_PromptTargetMatch): @dataclass(repr=False) -class _PromptTargetTextMatch(_PromptTargetMatch): +class _PromptTargetTextMatch(PromptTargetMatch): match: re.Match[str] @property @@ -594,7 +621,7 @@ class PlaceholderFeaturesInfo: def find_token_matches( prompt: list[int], prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[_PromptTargetMatch]: +) -> Sequence[PromptTargetMatch]: """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" def get_matches(update: BoundPromptUpdate): @@ -620,7 +647,7 @@ def find_token_matches( def find_text_matches( prompt: str, prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[_PromptTargetMatch]: +) -> Sequence[PromptTargetMatch]: """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" def get_matches(update: BoundPromptUpdate): @@ -645,15 +672,15 @@ def find_text_matches( def _resolve_matches( prompt: PromptSeq, - mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], -) -> list[_PromptTargetMatch]: + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], +) -> list[PromptTargetMatch]: """ Resolve :code:`mm_matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. 
""" matches = [m for matches in mm_matches.values() for m in matches] - seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt) + seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt) for match in matches: for idx in range(match.start_idx, match.end_idx): @@ -669,7 +696,7 @@ def _resolve_matches( def _apply_matches( prompt: _S, - mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> list[_S]: """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" @@ -718,7 +745,7 @@ def _apply_matches( def apply_token_matches( prompt: list[int], - mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> list[int]: """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" @@ -732,7 +759,7 @@ def apply_token_matches( def apply_text_matches( prompt: str, - mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> str: """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" @@ -1055,14 +1082,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Given the original multi-modal items for this modality and HF-processed data, output the updates to perform. - Notes: - - You should not assume that HF processor always performs prompt - updates: in :meth:`_apply_hf_processor_missing`, this method - is called on text-only and multimodal-only inputs separately, - instead of passing them in the same call. - - The update information returned by this method is also used to - determine the placeholder token positions for each multi-modal - item. + The information returned by this method is used to update token inputs + which bypass the HF processor. It is also used to update the output of + HF processor if the HF process does not apply prompt updates to text + inputs. + + Moreover, this information is critical to determine the token positions + in order to construct :class:`~vllm-multimodal.input.PlaceholderRange` + for each multi-modal item. """ raise NotImplementedError @@ -1357,6 +1384,22 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): it = (update.bind(tokenizer) for update in prompt_updates) return dict(full_groupby_modality(it)) + def _apply_token_matches( + self, + prompt: list[int], + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], + ) -> list[int]: + return apply_token_matches(prompt, mm_matches, mm_item_counts) + + def _apply_text_matches( + self, + prompt: str, + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], + ) -> str: + return apply_text_matches(prompt, mm_matches, mm_item_counts) + def _apply_prompt_updates( self, token_ids: list[int], @@ -1388,7 +1431,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_match_counts.get(modality, 0) >= item_count for modality, item_count in mm_item_counts.items() ): # yapf: disable - token_ids = apply_token_matches( + token_ids = self._apply_token_matches( token_ids, mm_token_matches, mm_item_counts, @@ -1406,7 +1449,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): modality: find_text_matches(text, updates) for modality, updates in mm_prompt_updates.items() } - text = apply_text_matches( + text = self._apply_text_matches( text, mm_text_matches, mm_item_counts,