From 61c6a5a79664882a8ab1c9af3ff78677911516dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?= <54138269+Flechman@users.noreply.github.com>
Date: Sat, 15 Mar 2025 14:28:27 +0100
Subject: [PATCH] [VLM] Merged multi-modal processor for Pixtral (#12211)

Signed-off-by: remi
Signed-off-by: DarkLight1337
Co-authored-by: DarkLight1337
---
 examples/offline_inference/pixtral.py         |  24 +-
 .../multimodal/processing/test_common.py      | 190 ++++--
 vllm/model_executor/models/llava.py           | 126 ++--
 vllm/model_executor/models/molmo.py           |   6 +-
 vllm/model_executor/models/paligemma.py       |   9 +-
 vllm/model_executor/models/pixtral.py         | 588 +++++++++++-------
 vllm/multimodal/processing.py                 |  14 +-
 vllm/transformers_utils/tokenizer.py          |  19 +-
 vllm/utils.py                                 |   2 +-
 9 files changed, 620 insertions(+), 358 deletions(-)

diff --git a/examples/offline_inference/pixtral.py b/examples/offline_inference/pixtral.py
index 760de114508cd..03e6eea891088 100644
--- a/examples/offline_inference/pixtral.py
+++ b/examples/offline_inference/pixtral.py
@@ -43,12 +43,18 @@ from vllm.sampling_params import SamplingParams
 # python demo.py advanced
 
 
-def run_simple_demo():
+def run_simple_demo(args: argparse.Namespace):
     model_name = "mistralai/Pixtral-12B-2409"
     sampling_params = SamplingParams(max_tokens=8192)
 
-    # Lower max_num_seqs or max_model_len on low-VRAM GPUs.
-    llm = LLM(model=model_name, tokenizer_mode="mistral")
+    # Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
+    llm = LLM(
+        model=model_name,
+        tokenizer_mode="mistral",
+        max_model_len=4096,
+        max_num_seqs=2,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )
 
     prompt = "Describe this image in one sentence."
     image_url = "https://picsum.photos/id/237/200/300"
@@ -76,7 +82,7 @@ def run_simple_demo():
     print(outputs[0].outputs[0].text)
 
 
-def run_advanced_demo():
+def run_advanced_demo(args: argparse.Namespace):
     model_name = "mistralai/Pixtral-12B-2409"
     max_img_per_msg = 5
     max_tokens_per_img = 4096
@@ -87,6 +93,7 @@ def run_advanced_demo():
         tokenizer_mode="mistral",
         limit_mm_per_prompt={"image": max_img_per_msg},
         max_model_len=max_img_per_msg * max_tokens_per_img,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
 
     prompt = "Describe the following image."
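Both demos ultimately submit the prompt and image(s) through vLLM's chat API; the bodies of those calls sit outside these hunks. As a rough sketch of that path (hedged: the variable names follow the demo code above, and `LLM.chat` with OpenAI-style message dicts is vLLM's standard multi-modal entry point):

    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    }]
    # Generate a description of the referenced image.
    outputs = llm.chat(messages, sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)
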
@@ -153,14 +160,19 @@ def main():
         help="Specify the demo mode: 'simple' or 'advanced'",
     )
 
+    parser.add_argument(
+        '--disable-mm-preprocessor-cache',
+        action='store_true',
+        help='If True, disables caching of multi-modal preprocessor/mapper.')
+
     args = parser.parse_args()
 
     if args.mode == "simple":
         print("Running simple demo...")
-        run_simple_demo()
+        run_simple_demo(args)
     elif args.mode == "advanced":
         print("Running advanced demo...")
-        run_advanced_demo()
+        run_advanced_demo(args)
 
 
 if __name__ == "__main__":
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 0e0d3711357e4..f761190a8d097 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -2,17 +2,23 @@
 
 import copy
 from functools import partial
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
 import pytest
+from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
+                                                       UserMessage)
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import ProcessingCache
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
+from vllm.multimodal.inputs import MultiModalInputs
+from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
+from vllm.transformers_utils.tokenizer import (MistralTokenizer,
+                                               cached_tokenizer_from_config)
 
 from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS
@@ -85,14 +91,6 @@ def _test_processing_correctness(
         partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
-    tokenizer_encode_kwargs = {}
-    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
-        # For some multimodal models, tokenizer will always add bos_token
-        # at the beginning of prompt by default, causing hf_processor outputs
-        # incorrect token ids. So we need use `add_special_tokens=False` here
-        # to leave bos_token to be added by the processor.
-        tokenizer_encode_kwargs = {"add_special_tokens": False}
-
     for batch_idx in range(num_batches):
         mm_data = {
             k:
@@ -115,43 +113,131 @@
             elif len(mm_data[k]) == 1:
                 mm_data[k] = mm_data[k][0]
 
-        baseline_result = baseline_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-        cached_result = cached_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
+        if isinstance(tokenizer, MistralTokenizer):
+            _test_processing_correctness_mistral(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+        else:
+            _test_processing_correctness_hf(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
 
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        baseline_tokenized_result = baseline_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
+
+def _test_processing_correctness_hf(
+    model_config: ModelConfig,
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
+        # For some multimodal models, the tokenizer always adds bos_token at
+        # the beginning of the prompt by default, causing hf_processor to
+        # output incorrect token ids. So we use `add_special_tokens=False`
+        # here to leave bos_token to be added by the processor.
+ token_prompt = tokenizer.encode(prompt, add_special_tokens=False) + else: + token_prompt = tokenizer.encode(prompt) - assert _drop_mm_kwargs_keys( - baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys( - baseline_tokenized_result, ignore_mm_keys), ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + baseline_result = baseline_processor.apply( + prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + cached_result = cached_processor.apply( + prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) - cached_tokenized_result = cached_processor.apply( - tokenizer.encode(prompt, **tokenizer_encode_kwargs), - mm_data=mm_data, - hf_processor_mm_kwargs={}, - ) + assert _inputs_equal( + baseline_result, + cached_result, + ignore_mm_keys, + ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" - assert _drop_mm_kwargs_keys( - cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys( - cached_tokenized_result, ignore_mm_keys), ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + baseline_tokenized_result = baseline_processor.apply( + token_prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + assert _inputs_equal( + baseline_result, + baseline_tokenized_result, + ignore_mm_keys, + ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" + + cached_tokenized_result = cached_processor.apply( + token_prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + assert _inputs_equal( + cached_result, + cached_tokenized_result, + ignore_mm_keys, + ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" + + +def _test_processing_correctness_mistral( + model_config: ModelConfig, + tokenizer: MistralTokenizer, + prompt: str, + mm_data: MultiModalDataDict, + baseline_processor: BaseMultiModalProcessor, + cached_processor: BaseMultiModalProcessor, + batch_idx: int, + ignore_mm_keys: Optional[list[str]] = None, +): + images = mm_data.get("image", []) + if not isinstance(images, list): + images = [images] + + request = ChatCompletionRequest(messages=[ + UserMessage(content=[ + TextChunk(text=prompt), + *(ImageChunk(image=image) for image in images), + ]), + ]) + res = tokenizer.mistral.encode_chat_completion(request) + token_prompt = res.tokens + + # Mistral chat outputs tokens directly, rather than text prompts + baseline_tokenized_result = baseline_processor.apply( + token_prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + cached_tokenized_result = cached_processor.apply( + token_prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + assert _inputs_equal( + baseline_tokenized_result, + cached_tokenized_result, + ignore_mm_keys, + ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" # yapf: disable @@ -173,6 +259,7 @@ def _test_processing_correctness( "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "meta-llama/Llama-3.2-11B-Vision-Instruct", "TIGER-Lab/Mantis-8B-siglip-llama3", + "mistralai/Pixtral-12B-2409", "mistral-community/pixtral-12b", "openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-V-2_6", @@ -241,8 +328,19 @@ def test_processing_correctness_phi3v( ) -def _drop_mm_kwargs_keys(result: dict, - ignore_mm_keys: Optional[list[str]] = None) -> dict: +def _inputs_equal( + a: MultiModalInputs, + b: MultiModalInputs, + ignore_mm_keys: Optional[list[str]] = None, +): + return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys( + b, ignore_mm_keys) + + +def _drop_mm_kwargs_keys( + result: MultiModalInputs, + ignore_mm_keys: Optional[list[str]] = None, +) -> MultiModalInputs: """Drop specified keys from result['mm_kwargs']. 
This is mainly to avoid doing exact match of audio_features in ultravox. diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 42bf6a5b2979a..3a8d184528d8b 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict): in which case the data is passed as a list instead of a batched tensor. """ - feat_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image features correspond - to patch tokens. - - Shape: `(batch_size, num_crops, num_patch)` - """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] """ A boolean mask indicating which image embeddings correspond to patch tokens. - Shape: `(batch_size, num_embeds)` + Shape: `(batch_size, num_images, num_embeds)` """ - num_crops: Union[torch.Tensor, list[torch.Tensor]] + num_patches: Union[torch.Tensor, list[torch.Tensor]] """Shape: `(batch_size, num_images)`""" @@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor( image_height=pixel_value.shape[-2], ) for pixel_value in processed_outputs["pixel_values"] ] - num_crops = torch.tensor([(ncols + 1) * nrows - for ncols, nrows in tile_sizes]) + num_patches = torch.tensor([(ncols + 1) * nrows + for ncols, nrows in tile_sizes]) # Each image may result to masks of different sizes, so we need to - # flatten the list and later use `num_crops` to get per-image masks. - embed_is_patch = torch.tensor( - flatten_2d_lists([([True] * ncols + [False]) * nrows - for ncols, nrows in tile_sizes])) - processed_outputs["num_crops"] = num_crops + # later use `num_patches` to get per-image masks. + embed_is_patch = [ + torch.tensor(([True] * ncols + [False]) * nrows) + for ncols, nrows in tile_sizes + ] + processed_outputs["num_patches"] = num_patches processed_outputs["embed_is_patch"] = embed_is_patch - processed_outputs["feat_is_patch"] = embed_is_patch return processed_outputs @@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1) return dict( - feat_is_patch=MultiModalFieldConfig.flat_from_sizes( - "image", num_crops), - embed_is_patch=MultiModalFieldConfig.flat_from_sizes( - "image", num_crops), - num_crops=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"), + num_patches=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): f"Got type: {type(pixel_values)}") if self.config.vision_config.model_type == "pixtral": - feat_is_patch = kwargs.pop("feat_is_patch") - if not isinstance(feat_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of feat_is_patch. " - f"Got type: {type(feat_is_patch)}") - embed_is_patch = kwargs.pop("embed_is_patch") if not isinstance(embed_is_patch, (torch.Tensor, list)): raise ValueError("Incorrect type of embed_is_patch. " f"Got type: {type(embed_is_patch)}") - num_crops = kwargs.pop("num_crops") - if not isinstance(num_crops, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_crops. " - f"Got type: {type(num_crops)}") + num_patches = kwargs.pop("num_patches") + if not isinstance(num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_patches. 
" + f"Got type: {type(num_patches)}") return PixtralHFImagePixelInputs( type="pixel_values_pixtral", pixel_values=flatten_bn(pixel_values), - feat_is_patch=feat_is_patch, embed_is_patch=embed_is_patch, - num_crops=num_crops, + num_patches=num_patches, ) return LlavaImagePixelInputs( @@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel], pixel_values: Union[torch.Tensor, list[torch.Tensor]], - ) -> torch.Tensor: - + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower image_features = vision_tower(pixel_values) - return self._select_image_features( - image_features, - strategy=self.config.vision_feature_select_strategy, + def select_features(leaf: torch.Tensor): + return self._select_image_features( + leaf, + strategy=self.config.vision_feature_select_strategy, + ) + + return cast( + Union[torch.Tensor, tuple[torch.Tensor, ...]], + json_map_leaves(select_features, image_features), ) def _process_image_pixels( self, inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: assert self.vision_tower is not None pixel_values = inputs["pixel_values"] @@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def _get_mm_embeds( self, - features: torch.Tensor, # Shape: (num_crop, num_patch, d) - feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch) - num_crops: torch.Tensor, # Shape: (num_images,) - embed_is_patch: torch.Tensor, # Shape: (num_embeds,) - ) -> list[torch.Tensor]: + features: torch.Tensor, # Shape: (num_patch, d) + num_patches: torch.Tensor, # Shape: (num_images,) + embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds) + ) -> tuple[torch.Tensor, ...]: """Scatter the patch features into a contiguous tensor that corresponds to the embedding tokens defined by the multimodal processor. Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment. """ - - # Insert columns of nan values according to `feat_is_patch`. This work + # Insert columns of nan values according to `embed_is_patch`. This work # ideally should be done in `_process_image_input`, but # `_process_image_input` is used in both V0 and V1 path. It's safer to # put the logic here. # FIXME: Move this logic to `_process_image_input` when v0 is # deprecated. Merge this function with `Molmo._get_mm_embeds`. 
- feat_is_patch = feat_is_patch.view(-1) - embed_is_patch = embed_is_patch.view(-1) - expanded_embedding = torch.full( - (sum(num_crops), *features.shape[1:]), - torch.nan, - dtype=features.dtype).to(features.device) - expanded_embedding[feat_is_patch] = features + num_patches_per_image: list[int] = num_patches.tolist() - num_crops_per_image = num_crops.tolist() - feats_per_image = expanded_embedding.split(num_crops_per_image) - f_is_patch_per_image = feat_is_patch.split(num_crops_per_image) + embeds_flat = features.new_full( + (sum(num_patches_per_image), *features.shape[1:]), + fill_value=torch.nan, + ) + embeds_flat[embed_is_patch.view(-1)] = features - embed_dim = expanded_embedding.shape[-1] - num_embeds = embed_is_patch.shape[0] - - embeds_in_batch = list[torch.Tensor]() - for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image): - embeds = feats.new_full((num_embeds, embed_dim), torch.nan) - embeds[embed_is_patch] = feats[f_is_patch] - embeds_in_batch.append(embeds) - - return embeds_in_batch + return embeds_flat.split(num_patches_per_image) def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: @@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # The path is used for pixtral (V0 only) and llava (V0/V1) return vision_embeddings - nested_emb = [ + return flatten_2d_lists( self._get_mm_embeds(*args) for args in zip( - vision_embeddings, image_input["feat_is_patch"], - image_input["num_crops"], image_input["embed_is_patch"]) - ] - return flatten_2d_lists(nested_emb) + vision_embeddings, + image_input["num_patches"], + image_input["embed_is_patch"], + )) def get_input_embeddings( self, @@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ) inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, cast(NestedTensors, - patch_embeddings), - self.config.image_token_index) + input_ids, + inputs_embeds, + cast(NestedTensors, patch_embeddings), + self.config.image_token_index, + ) return inputs_embeds def forward( diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index e709b08815eaf..c7f6cf461d523 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1585,15 +1585,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, image_features = self._process_image_input(image_input) - nested_embeds = [ + return flatten_2d_lists( self._get_mm_embeds(*args) for args in zip( image_features, image_input["feat_is_patch"], image_input["num_crops"], image_input["embed_is_patch"], - ) - ] - return flatten_2d_lists(nested_embeds) + )) def get_input_embeddings( self, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 88a6226d21448..8a773607ce4ed 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - -from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, Set, Tuple, TypedDict, Union import torch from torch import nn @@ -17,7 +16,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, - 
PromptInsertion, PromptReplacement, + PromptInsertion, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor( mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 2e71390623fdf..fff630056e405 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,26 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 import math -from collections.abc import Iterable, Mapping +from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import List, Optional, Set, Tuple, Union +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast import torch import torch.nn as nn import torch.nn.functional as F from mistral_common.protocol.instruct.messages import ImageChunk +from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image -from transformers import PixtralVisionConfig +from transformers import PixtralVisionConfig, TensorType +from transformers.image_utils import ImageInput from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens as _get_pixtral_hf_num_image_tokens) from transformers.models.pixtral.modeling_pixtral import ( PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) +from transformers.tokenization_utils_base import TextInput from vllm.config import VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.jsontree import JSONTree, json_map_leaves from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -31,13 +33,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import PlaceholderRange -from vllm.multimodal.utils import consecutive_placeholder_ranges -from vllm.sequence import IntermediateTensors, SequenceData -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import (MistralTokenizer, + cached_tokenizer_from_config) +from vllm.utils import flatten_2d_lists from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (init_vllm_registered_model, maybe_prefix, +from 
.utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs @@ -48,132 +57,275 @@ except ImportError: USE_XFORMERS_OPS = False -def get_max_pixtral_image_tokens(ctx: InputContext): - tokenizer = cached_tokenizer_from_config(ctx.model_config) - mm_encoder = tokenizer.instruct.mm_encoder +class PixtralImagePixelInputs(TypedDict): + type: Literal["pixel_values"] - image_config = mm_encoder.mm_config if hasattr( - mm_encoder, "mm_config") else mm_encoder.image_config - - max_image_size = image_config.max_image_size - image_patch_size = image_config.image_patch_size - - return ((max_image_size // image_patch_size)**2) - - -def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - tokenizer = cached_tokenizer_from_config(ctx.model_config) - - mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - image_token_id = mm_encoder.special_ids.img - - mm_config = ctx.get_mm_config() - num_images = mm_config.get_limit_per_prompt("image") - - # dummy size - size = 256 - image = Image.new("RGB", (size, size), color=0) - - encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image)) - image_feature_size = len(encoding.tokens) - num_image_tokens = image_feature_size * num_images - seq_data = SequenceData.from_prompt_token_counts( - (image_token_id, num_image_tokens), - (0, seq_len - num_image_tokens), - ) - - mm_data = {"image": num_images * [image]} - mm_placeholders = { - "image": - consecutive_placeholder_ranges(num_items=num_images, - item_size=image_feature_size) - } - return DummyData(seq_data, mm_data, mm_placeholders) - - -def input_mapper_for_pixtral(ctx: InputContext, - data: object) -> MultiModalKwargs: - """Maps the input data to its MultiModalKwargs (if any). - - Args: - ctx: Context of the loaded model. - data: data potentially containing PIL images to be processed - and mapped to `images`. - - Returns: - MultiModalKwargs containing the stacked normalized images tensor or - image embeddings. + images: Union[torch.Tensor, list[torch.Tensor]] """ - tokenizer = cached_tokenizer_from_config(ctx.model_config) + Shape: `(batch_size * num_images, num_channels, image_width, image_height)` - data_list = data if isinstance(data, list) else [data] + The result of stacking :attr:`ImageEncoding.tokens` from each prompt. + """ - images = [] - image_tokens_list = [] - for image_data in data_list: - image = ImageChunk(image=image_data) - encoding = tokenizer.instruct.mm_encoder(image) - image = torch.from_numpy(encoding.image).to(dtype=torch.float16) - images.append(image) - image_tokens_list.append(encoding.tokens) + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size, num_images, num_embeds)` + """ - image_tokens = torch.tensor([ - token_id for image_tokens in image_tokens_list - for token_id in image_tokens - ]) - return MultiModalKwargs({"images": images, "image_tokens": image_tokens}) + num_patches: Union[torch.Tensor, list[torch.Tensor]] + """Shape: `(batch_size, num_images)`""" -def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs +class PixtralProcessorAdapter: + """ + Provide a HF-compatible interface for + :class:`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`. 
+ """ - prompt_token_ids = inputs.get("prompt_token_ids") - prompt = inputs.get("prompt") - tokenizer = cached_tokenizer_from_config(ctx.model_config) + def __init__(self, tokenizer: MistralTokenizer) -> None: + super().__init__() - mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - image_token_id = mm_encoder.special_ids.img - image_break_id = mm_encoder.special_ids.img_break - image_end_id = mm_encoder.special_ids.img_end + self.tokenizer = tokenizer - if image_token_id not in inputs['prompt_token_ids']: - raise ValueError( - f"You've passed {inputs=} without {image_token_id=}" - " Make sure to process your input via mistral_common's" - " tokenizer or pass a chat completion request. For more" - " For more info, see: " - "https://github.com/vllm-project/vllm/issues/8411.") + @property + def image_processor(self) -> ImageEncoder: + image_encoder = self.tokenizer.instruct.mm_encoder + assert isinstance(image_encoder, ImageEncoder) + return image_encoder - # Get precise tracking of placeholder positions - placeholder_ranges = [] - curr_offset = -1 - curr_length = 0 - for i in range(len(prompt_token_ids)): - if prompt_token_ids[i] in (image_token_id, image_break_id): - if curr_offset < 0: - curr_offset = i - curr_length += 1 - elif prompt_token_ids[i] == image_end_id: - curr_length += 1 - placeholder_ranges.append( - PlaceholderRange(offset=curr_offset, length=curr_length)) - curr_offset = -1 - curr_length = 0 - else: - pass - return token_inputs(prompt=prompt, - prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": placeholder_ranges}) + @cached_property + def image_break_id(self) -> int: + return self.image_processor.special_ids.img_break + + @cached_property + def image_token_id(self) -> int: + return self.image_processor.special_ids.img + + @cached_property + def image_end_id(self) -> int: + return self.image_processor.special_ids.img_end + + @cached_property + def image_size(self) -> int: + return self.image_processor.mm_config.max_image_size + + @cached_property + def patch_size(self) -> int: + return self.image_processor.mm_config.image_patch_size + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> Mapping[str, NestedTensors]: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + if not images: + input_ids = self.tokenizer(text).input_ids + + return {"input_ids": torch.tensor(input_ids)} + + # Allow dummy text, which is used for profiling as well as token inputs + if any(len(t) > 0 for t in text): + raise ValueError( + "You've passed text inputs instead of token inputs. " + "Make sure to process your input via `mistral_common`'s " + "tokenizer or pass a chat completion request. 
" + "For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.") + + image_token_id = self.image_token_id + + images_processed = list[torch.Tensor]() + images_tokens = list[torch.Tensor]() + images_embed_is_patch = list[torch.Tensor]() + images_num_patches = list[int]() + + for image in images: + image_inputs = self.image_processor(ImageChunk(image=image)) + + image_processed = torch.tensor(image_inputs.image) + image_tokens = torch.tensor(image_inputs.tokens) + + images_processed.append(image_processed) + images_tokens.append(image_tokens) + images_embed_is_patch.append(image_tokens == image_token_id) + images_num_patches.append(len(image_tokens)) + + return { + "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), + "images": images_processed, + "embed_is_patch": images_embed_is_patch, + "num_patches": torch.tensor(images_num_patches), + } -@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) -@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral) +class PixtralProcessingInfo(BaseProcessingInfo): + + def get_tokenizer(self) -> MistralTokenizer: + tokenizer = cached_tokenizer_from_config(self.ctx.model_config) + if not isinstance(tokenizer, MistralTokenizer): + raise ValueError("This model requires `--tokenizer-mode mistral`") + + return tokenizer + + def get_hf_processor(self) -> PixtralProcessorAdapter: + return PixtralProcessorAdapter(self.get_tokenizer()) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_vision_config( + self, + processor: Optional[PixtralProcessorAdapter] = None, + ): + if processor is None: + processor = self.get_hf_processor() + + return PixtralVisionConfig( + image_size=processor.image_size, + patch_size=processor.patch_size, + ) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[PixtralProcessorAdapter] = None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + ncols, nrows = processor.image_processor._image_to_num_tokens( + Image.new("RGB", (image_width, image_height))) + + return (ncols + 1) * nrows + + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_hf_processor().image_processor + max_image_size = image_processor.mm_config.max_image_size + + return ImageSize(width=max_image_size, height=max_image_size) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] + ): + + def 
_get_mm_fields_config( + self, + hf_inputs: Mapping[str, NestedTensors], + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + images=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + num_patches=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + image_break_id = processor.image_break_id + image_token_id = processor.image_token_id + image_end_id = processor.image_end_id + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = processor.image_processor._image_to_num_tokens( + Image.new("RGB", (image_size.width, image_size.height))) + + tokens = ([image_token_id] * ncols + [image_break_id]) * nrows + tokens[-1] = image_end_id + + return tokens + + return [ + PromptReplacement( + modality="image", + target="", # Never match the prompt (see below note) + replacement=get_replacement, + ), + ] + + def _cached_apply_hf_processor( + self, + prompt: Union[str, list[int]], + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs, bool]: + prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor( + prompt=prompt, + mm_data_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + # NOTE: The tokens are already inserted by the chat template + return prompt_ids, mm_kwargs, True + + +@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor, + info=PixtralProcessingInfo, + dummy_inputs=PixtralDummyInputsBuilder) class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): @@ -191,13 +343,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, if key in dataclass_fields } - if not ("image_break_token_id" in vision_args - and "image_end_token_id" in vision_args): - raise ValueError( - "'image_break_token_id' and 'image_end_token_id' not found " - "in the vision_encoder arguments. Please download the latest " - "version of 'params.json' from the model repository.") - self.vision_args = VisionEncoderArgs(**vision_args) # init MistralForCausalLM @@ -221,36 +366,92 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, return get_sampler() + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[PixtralImagePixelInputs]: + images = kwargs.pop("images", None) + if images is None: + return None + + if not isinstance(images, (torch.Tensor, list)): + raise ValueError("Incorrect type of images. " + f"Got type: {type(images)}") + + embed_is_patch = kwargs.pop("embed_is_patch") + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + num_patches = kwargs.pop("num_patches") + if not isinstance(num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_patches. 
" + f"Got type: {type(num_patches)}") + + return PixtralImagePixelInputs( + type="pixel_values", + images=flatten_bn(images), + embed_is_patch=embed_is_patch, + num_patches=num_patches, + ) + + def _process_image_input( + self, + image_input: PixtralImagePixelInputs, + ) -> tuple[torch.Tensor, ...]: + images = image_input["images"] + + image_features = self.vision_encoder(images) + feature_sizes = [ + image_feature.shape[0] for image_feature in image_features + ] + + image_embeds = self.vision_language_adapter(torch.cat(image_features)) + image_embeds = torch.split(image_embeds, feature_sizes) + return image_embeds + + def _get_mm_embeds( + self, + features: torch.Tensor, # Shape: (num_patch, d) + num_patches: torch.Tensor, # Shape: (num_images,) + embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds) + ) -> tuple[torch.Tensor, ...]: + """Scatter the patch features into a contiguous tensor that corresponds + to the embedding tokens defined by the multimodal processor. + + Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment. + """ + # Insert columns of nan values according to `embed_is_patch`. This work + # ideally should be done in `_process_image_input`, but + # `_process_image_input` is used in both V0 and V1 path. It's safer to + # put the logic here. + # FIXME: Move this logic to `_process_image_input` when v0 is + # deprecated. Merge this function with `Molmo._get_mm_embeds`. + num_patches_per_image: list[int] = num_patches.tolist() + + embeds_flat = features.new_full( + (sum(num_patches_per_image), *features.shape[1:]), + fill_value=torch.nan, + ) + embeds_flat[embed_is_patch.view(-1)] = features + + return embeds_flat.split(num_patches_per_image) + def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - image_input, image_tokens = self._parse_and_validate_image_input( - **kwargs) + image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None - vision_embeddings = self._process_image_input(image_input) + image_features = self._process_image_input(image_input) - # NOTE: We patch the outputs of the vision encoder with embeddings - # from `[IMG_BREAK]` and `[IMG_END]` tokens. - image_embeds = self.language_model.get_input_embeddings(image_tokens) - image_token_mask = image_tokens == self.vision_args.image_token_id - image_embeds[image_token_mask] = vision_embeddings + if kwargs.get("v0_path", False): + return image_features - # NOTE: Image embeddings are split into separate tensors for each image - # by the indices of `[IMG_END]` token. 
- image_end_mask = image_tokens == self.vision_args.image_end_token_id - split_indices = torch.where(image_end_mask)[0] + 1 - if len(split_indices) <= 1: - # Do not split, return as tensor of shape [1, fs, hs] - return image_embeds.unsqueeze(0) - - # If the last split index is the last index in image_tokens, we - # ignore it to avoid empty split tensor - if split_indices[-1] == len(image_tokens): - split_indices = split_indices[:-1] - - image_embeds = image_embeds.tensor_split(split_indices.cpu()) - return image_embeds + return flatten_2d_lists( + self._get_mm_embeds(*args) for args in zip( + image_features, + image_input["num_patches"], + image_input["embed_is_patch"], + )) def get_input_embeddings( self, @@ -259,12 +460,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: + # Extract the patch tokens + patch_embeddings = json_map_leaves( + lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), + cast(JSONTree[torch.Tensor], multimodal_embeddings), + ) inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, [ - self.vision_args.image_token_id, - self.vision_args.image_break_token_id, - self.vision_args.image_end_token_id, - ]) + input_ids, + inputs_embeds, + cast(NestedTensors, patch_embeddings), + self.vision_args.image_token_id, + ) return inputs_embeds def forward( @@ -275,14 +481,14 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - """Run forward pass for pixtral. - """ + """Run forward pass for pixtral.""" if intermediate_tensors is not None: inputs_embeds = None # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + kwargs.update({"v0_path": True}) vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) @@ -295,47 +501,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, return hidden_states - def _parse_and_validate_image_input( - self, - images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], - torch.Tensor]] = None, - image_tokens: Optional[torch.Tensor] = None, - ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: - if images is None: - return None, None - - if isinstance(images, torch.Tensor): - # if passed as batch take all images - N, B, C, W, H = images.shape - images = images.reshape(N * B, C, W, H) - images = [images[i] for i in range(images.size(0))] - elif isinstance(images, list): - # if passed as list flatten lists of tensors - flatten_images = [] - for imgs_per_req in images: - imgs_per_req = [ - imgs_per_req[i] for i in range(imgs_per_req.size(0)) - ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req - - flatten_images.extend(imgs_per_req) - - images = flatten_images - - if isinstance(image_tokens, torch.Tensor): - # image_tokens are batched - image_tokens = image_tokens.flatten() - elif isinstance(image_tokens, list): - # image_tokens are of different lengths thus passed as a list - image_tokens = torch.cat(image_tokens) - - assert image_tokens.dim() == 1 - - return images, image_tokens - - def _process_image_input(self, - image_input: List[torch.Tensor]) -> torch.Tensor: - return self.vision_language_adapter(self.vision_encoder(image_input)) - def compute_logits( self, hidden_states: torch.Tensor, @@ -400,8 +565,6 @@ class VisionEncoderArgs: num_attention_heads: int rope_theta: float # for rope-2D image_token_id: int - image_break_token_id: int - image_end_token_id: int adapter_bias: bool = True @@ -637,9 +800,13 @@ class VisionTransformer(nn.Module): self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images ] + patch_embeds = [ + p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list + ] + embed_sizes = [p.shape[1] for p in patch_embeds] + # flatten to a single sequence - patch_embeds = torch.cat( - [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) + patch_embeds = torch.cat(patch_embeds, dim=1) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings @@ -655,8 +822,8 @@ class VisionTransformer(nn.Module): "with the Mistral format") out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) - # remove batch dimension of the single sequence - return out.squeeze(0) + # squeeze dim 0 and split into separate tensors for each image + return torch.split(out.squeeze(0), embed_sizes) class VisionLanguageAdapter(nn.Module): @@ -978,9 +1145,9 @@ class PixtralHFVisionModel(nn.Module): def forward( self, - pixel_values: List[torch.Tensor], + pixel_values: list[torch.Tensor], feature_sample_layers: Optional[list[int]] = None, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, ...]: """ Args: pixel_values: Each image to be processed will be a separate tensor @@ -1039,8 +1206,7 @@ class PixtralHFVisionModel(nn.Module): self.config.num_hidden_layers) # squeeze dim 0 and split into separate tensors for each image - out = torch.split(torch.squeeze(out), embed_sizes) - return out + return torch.split(out.squeeze(0), embed_sizes) # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 diff --git 
a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index cdbbed27a5218..10c53dfb2c66e 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -77,7 +77,9 @@ class PromptIndexTargets:
         else:
             if isinstance(prefix, str):
                 # Make both `list[int]`
-                prefix = encode_tokens(tokenizer, prefix)
+                prefix = encode_tokens(tokenizer,
+                                       prefix,
+                                       add_special_tokens=False)
 
             match_idx = len(prefix)
             return match_idx if prompt[:match_idx] == prefix else None
@@ -318,7 +320,7 @@ def _cached_encode(
     tokenizer: AnyTokenizer,
     text: str,
     *,
-    add_special_tokens: bool = False,
+    add_special_tokens: Optional[bool] = None,
 ) -> list[int]:
     return encode_tokens(tokenizer,
                          text,
@@ -330,7 +332,7 @@ def _cached_decode(
     tokenizer: AnyTokenizer,
     token_ids: tuple[int, ...],
    *,
-    skip_special_tokens: bool = False,
+    skip_special_tokens: Optional[bool] = None,
 ) -> str:
     return decode_tokens(tokenizer,
                          list(token_ids),
@@ -395,7 +397,9 @@ class _BoundPromptSequence:
     def token_ids(self) -> list[int]:
         if self._token_ids is None:
             assert self._text is not None
-            self._token_ids = _cached_encode(self.tokenizer, self._text)
+            self._token_ids = _cached_encode(self.tokenizer,
+                                             self._text,
+                                             add_special_tokens=False)
 
         return self._token_ids
 
@@ -1046,7 +1050,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
-    ) -> list[PromptUpdate]:
+    ) -> Sequence[PromptUpdate]:
         """
         Given the original multi-modal items for this modality
         and HF-processed data, output the updates to perform.
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 2c34f2f5d44d5..1bfb50328338f 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -34,13 +34,20 @@ def decode_tokens(
     tokenizer: AnyTokenizer,
     token_ids: list[int],
     *,
-    skip_special_tokens: bool = False,
+    skip_special_tokens: Optional[bool] = None,
 ) -> str:
     """
     Backend-agnostic equivalent of HF's
-    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
+    :code:`tokenizer.decode(token_ids, ...)`.
+
+    :code:`skip_special_tokens=None` means to use the backend's default
+    settings.
     """
-    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+    if skip_special_tokens is not None:
+        return tokenizer.decode(token_ids,
+                                skip_special_tokens=skip_special_tokens)
+
+    return tokenizer.decode(token_ids)
 
 
 def encode_tokens(
@@ -51,10 +58,14 @@
 ) -> list[int]:
     """
     Backend-agnostic equivalent of HF's
-    :code:`tokenizer.encode(text, add_special_tokens=...)`.
+    :code:`tokenizer.encode(text, ...)`.
+
+    :code:`add_special_tokens=None` means to use the backend's default
+    settings.
     """
     if add_special_tokens is not None:
         return tokenizer.encode(text, add_special_tokens=add_special_tokens)
+
     return tokenizer.encode(text)
diff --git a/vllm/utils.py b/vllm/utils.py
index 632b3666e959c..79787303af5bc 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -845,7 +845,7 @@ def is_list_of(
     assert_never(check)
 
 
-def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
+def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
     """Flatten a list of lists to a single list."""
     return [item for sublist in lists for item in sublist]
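Taken together, the new `_get_mm_embeds` helpers in Llava and Pixtral plus the nan-filtering in `get_input_embeddings` implement a single scatter-and-split idiom: pad the vision features out to every image-embedding position, split them per image, and later drop the padding again. A minimal self-contained sketch of that idiom (toy shapes and values; only `torch` is assumed, none of the model classes):

    import torch

    # One image with a 3x2 grid of patches: each row contributes 3 patch
    # tokens plus one [IMG_BREAK] slot, so (3 + 1) * 2 = 8 embedding
    # positions, of which 6 are real patches.
    num_patches = torch.tensor([8])
    embed_is_patch = torch.tensor(
        [[True, True, True, False, True, True, True, False]])
    features = torch.randn(6, 4)  # (num_patch_tokens, hidden_dim)

    # Scatter the patch features into a nan-padded tensor covering every
    # embedding position, then split it back into per-image tensors.
    num_patches_per_image = num_patches.tolist()
    embeds_flat = features.new_full(
        (sum(num_patches_per_image), *features.shape[1:]),
        fill_value=torch.nan,
    )
    embeds_flat[embed_is_patch.view(-1)] = features
    per_image_embeds = embeds_flat.split(num_patches_per_image)

    # get_input_embeddings() later keeps only the non-nan (patch) rows,
    # mirroring `x[~x.isnan()].view(-1, *x.shape[1:])` in the patch.
    patch_only = [e[~e.isnan().any(dim=-1)] for e in per_image_embeds]
    assert patch_only[0].shape == (6, 4)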