From 5994430b8433b89e07bf55b266b449c1c3d6d3cd Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Tue, 25 Mar 2025 18:27:57 +0800
Subject: [PATCH] [Misc] Remove redundant `num_embeds` (#15443)

Signed-off-by: DarkLight1337
---
 vllm/model_executor/models/gemma3_mm.py | 16 --------------
 vllm/model_executor/models/internvl.py  | 14 ------------
 vllm/model_executor/models/llava.py     | 16 --------------
 vllm/model_executor/models/pixtral.py   | 14 ------------
 vllm/model_executor/models/vision.py    | 29 +++++++++++++++++++++----
 5 files changed, 25 insertions(+), 64 deletions(-)

diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index d843232ca1b6b..63d3ccbf54bc2 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -63,9 +63,6 @@ class Gemma3ImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 Gemma3ImageInputs = Gemma3ImagePixelInputs
 
@@ -317,11 +314,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             tokenizer.encode(image_repl, add_special_tokens=False)
             for image_repl in image_repl_features
         ]
-        num_embeds = [
-            len(image_repl_feature_tokens)
-            for image_repl_feature_tokens in image_repls_feature_tokens
-        ]
-        processed_outputs["num_embeds"] = torch.tensor(num_embeds)
 
         vocab = tokenizer.get_vocab()
         image_token_id = vocab[tokenizer.image_token]
@@ -354,7 +346,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
                 "image", num_crops + 1),
             num_crops=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -583,7 +574,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         pixel_values = kwargs.pop("pixel_values", None)
         num_crops = kwargs.pop("num_crops", None)
         embed_is_patch = kwargs.pop("embed_is_patch", None)
-        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
@@ -601,10 +591,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")
 
-            if not isinstance(num_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_embeds. "
" - f"Got type: {type(num_embeds)}") - pixel_values = flatten_bn(pixel_values, concat=True) num_crops = flatten_bn(num_crops, concat=True) @@ -613,7 +599,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, pixel_values=self._validate_pixel_values(pixel_values), num_patches=num_crops + 1, embed_is_patch=embed_is_patch, - num_embeds=num_embeds, ) def _image_pixels_to_features( @@ -656,7 +641,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, return flatten_2d_lists( scatter_patch_features(*args) for args in zip( image_features, - image_input["num_embeds"], image_input["embed_is_patch"], )) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e8ec91736d58f..e1aa371610353 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -69,9 +69,6 @@ class InternVLImagePixelInputs(TypedDict): Shape: `(batch_size, num_images, num_embeds)` """ - num_embeds: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size, num_images)`""" - class InternVLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] @@ -426,7 +423,6 @@ class BaseInternVLProcessor(ABC): tokenizer = self.tokenizer image_token_id = self.image_token_id - num_embeds = list[int]() embed_is_patch = list[torch.Tensor]() for pixel_values in pixel_values_lst: @@ -438,11 +434,9 @@ class BaseInternVLProcessor(ABC): add_special_tokens=False) text = [t.replace('', image_repl.full, 1) for t in text] - num_embeds.append(len(feature_tokens)) embed_is_patch.append( torch.tensor(feature_tokens) == image_token_id) - image_inputs["num_embeds"] = torch.tensor(num_embeds) image_inputs["embed_is_patch"] = embed_is_patch text_inputs = self.tokenizer(text) @@ -607,7 +601,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): "image", image_num_patches), image_num_patches=MultiModalFieldConfig.batched("image"), embed_is_patch=MultiModalFieldConfig.batched("image"), - num_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), ) @@ -840,7 +833,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) embed_is_patch = kwargs.pop("embed_is_patch", None) - num_embeds = kwargs.pop("num_embeds", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values_flat is None and image_embeds is None: @@ -873,10 +865,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError("Incorrect type of embed_is_patch. " f"Got type: {type(embed_is_patch)}") - if not isinstance(num_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_embeds. 
" - f"Got type: {type(num_embeds)}") - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True) @@ -886,7 +874,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): pixel_values_flat), num_patches=image_num_patches, embed_is_patch=embed_is_patch, - num_embeds=num_embeds, ) raise AssertionError("This line should be unreachable.") @@ -941,7 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): return flatten_2d_lists( scatter_patch_features(*args) for args in zip( image_features, - image_input["num_embeds"], image_input["embed_is_patch"], )) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 40accfffe4f9d..d1014067d9d7c 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -76,9 +76,6 @@ class PixtralHFImagePixelInputs(TypedDict): Shape: `(batch_size, num_images, num_embeds)` """ - num_embeds: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size, num_images)`""" - class LlavaImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] @@ -358,15 +355,10 @@ class PixtralHFMultiModalProcessor( image_height=pixel_value.shape[-2], ) for pixel_value in processed_outputs["pixel_values"] ] - num_embeds = torch.tensor([(ncols + 1) * nrows - for ncols, nrows in tile_sizes]) - # Each image may result to masks of different sizes, so we need to - # later use `num_embeds` to get per-image masks. embed_is_patch = [ torch.tensor(([True] * ncols + [False]) * nrows) for ncols, nrows in tile_sizes ] - processed_outputs["num_embeds"] = num_embeds processed_outputs["embed_is_patch"] = embed_is_patch return processed_outputs @@ -378,7 +370,6 @@ class PixtralHFMultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: return dict( pixel_values=MultiModalFieldConfig.batched("image"), - num_embeds=MultiModalFieldConfig.batched("image"), embed_is_patch=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -627,16 +618,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError("Incorrect type of embed_is_patch. " f"Got type: {type(embed_is_patch)}") - num_embeds = kwargs.pop("num_embeds") - if not isinstance(num_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_embeds. 
" - f"Got type: {type(num_embeds)}") - return PixtralHFImagePixelInputs( type="pixel_values_pixtral", pixel_values=flatten_bn(pixel_values), embed_is_patch=embed_is_patch, - num_embeds=num_embeds, ) return LlavaImagePixelInputs( @@ -738,7 +723,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): return flatten_2d_lists( scatter_patch_features(*args) for args in zip( vision_embeddings, - image_input["num_embeds"], image_input["embed_is_patch"], )) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 5da69ce7fa061..a3ad360961243 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -77,9 +77,6 @@ class PixtralImagePixelInputs(TypedDict): Shape: `(batch_size, num_images, num_embeds)` """ - num_embeds: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size, num_images)`""" - class PixtralProcessorAdapter: """ @@ -153,7 +150,6 @@ class PixtralProcessorAdapter: images_processed = list[torch.Tensor]() images_tokens = list[torch.Tensor]() images_embed_is_patch = list[torch.Tensor]() - images_num_embeds = list[int]() for image in images: image_inputs = self.image_processor(ImageChunk(image=image)) @@ -163,13 +159,11 @@ class PixtralProcessorAdapter: images_processed.append(image_processed) images_tokens.append(image_tokens) images_embed_is_patch.append(image_tokens == image_token_id) - images_num_embeds.append(len(image_tokens)) return { "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), "images": images_processed, "embed_is_patch": images_embed_is_patch, - "num_embeds": torch.tensor(images_num_embeds), } @@ -273,7 +267,6 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] return dict( images=MultiModalFieldConfig.batched("image"), embed_is_patch=MultiModalFieldConfig.batched("image"), - num_embeds=MultiModalFieldConfig.batched("image"), ) def _get_prompt_updates( @@ -394,16 +387,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError("Incorrect type of embed_is_patch. " f"Got type: {type(embed_is_patch)}") - num_embeds = kwargs.pop("num_embeds") - if not isinstance(num_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_embeds. " - f"Got type: {type(num_embeds)}") - return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images), embed_is_patch=embed_is_patch, - num_embeds=num_embeds, ) def _process_image_input( @@ -447,7 +434,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, return flatten_2d_lists( scatter_patch_features(*args) for args in zip( image_features, - image_input["num_embeds"], image_input["embed_is_patch"], )) diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index f316e7d0ef57e..250b0ee3c2a1b 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -155,7 +155,6 @@ def resolve_visual_encoder_outputs( def scatter_patch_features( features: torch.Tensor, - num_embeds: torch.Tensor, embed_is_patch: torch.Tensor, ) -> tuple[torch.Tensor, ...]: """ @@ -168,13 +167,35 @@ def scatter_patch_features( Args: features: The patch features, concatenated across each image. Shape: `(num_patch, feature_depth)` - num_embeds: The number of image embeddings for each image. - Shape: `(num_images,)` embed_is_patch: A boolean mask indicating which image embeddings correspond to patch tokens for each image. 
             Shape: `(num_images, num_embeds)`
+
+    Note:
+        The original code only considers patch tokens as feature
+        tokens, but our processor considers all image-related tokens
+        as feature tokens because the feature tokens need to be
+        consecutive in `input_ids`.
+
+    Example:
+        A simplified example for one image:
+
+        .. code-block::
+
+            Embedding tokens (from HF processor):
+            [ ]
+
+            embed_is_patch (from HF processor):
+            [ False True True False True True False False ]
+
+            Encoder outputs (from model):
+            [ p1 p2 p3 p4 ]
+
+            The resulting embedding tensor is:
+            [ nan p1 p2 nan p3 p4 nan nan ]
     """
-    num_embeds_per_image: list[int] = num_embeds.tolist()
+    num_images, num_embeds = embed_is_patch.shape
+    num_embeds_per_image = [num_embeds] * num_images
 
     embeds_flat = features.new_full(
         (sum(num_embeds_per_image), features.shape[-1]),
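
Below is a minimal, self-contained sketch (not part of the patch) of how the simplified `scatter_patch_features` is expected to behave after this change: the per-image embedding count is now taken from `embed_is_patch.shape` rather than a separate `num_embeds` tensor. The patch is truncated above, so the NaN-fill, masked assignment, and per-image split at the end are assumptions about the rest of the function; `scatter_patch_features_demo` is a hypothetical stand-in, not the vLLM API.

# Illustrative sketch only, assuming the function tail fills non-patch
# slots with NaN and splits the result per image.
import torch


def scatter_patch_features_demo(
    features: torch.Tensor,        # (num_patch, feature_depth), all images concatenated
    embed_is_patch: torch.Tensor,  # (num_images, num_embeds) boolean mask
) -> tuple[torch.Tensor, ...]:
    # The per-image embedding count comes from the mask itself,
    # which is why the separate `num_embeds` tensor is redundant.
    num_images, num_embeds = embed_is_patch.shape
    num_embeds_per_image = [num_embeds] * num_images

    # Non-patch positions are filled with NaN so they can be filtered out later.
    embeds_flat = features.new_full(
        (sum(num_embeds_per_image), features.shape[-1]), torch.nan)
    embeds_flat[embed_is_patch.view(-1)] = features

    return embeds_flat.split(num_embeds_per_image)


# One image with 8 embedding tokens, 4 of which are patch tokens
# (same mask as the docstring example above).
embed_is_patch = torch.tensor(
    [[False, True, True, False, True, True, False, False]])
features = torch.arange(4, dtype=torch.float32).unsqueeze(-1)  # p1..p4
(per_image, ) = scatter_patch_features_demo(features, embed_is_patch)
print(per_image.squeeze(-1))  # tensor([nan, 0., 1., nan, 2., 3., nan, nan])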