Allow Gemma3 to take image embeddings (#28483)

Signed-off-by: tingtinggithub <streamttt@gmail.com>
This commit is contained in:
tingtinggithub 2025-11-15 04:18:08 -08:00 committed by GitHub
parent f36292dbee
commit cb15ee28db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 69 additions and 29 deletions

View File

@ -669,7 +669,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ | | `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal, TypeAlias
import torch import torch
from torch import nn from torch import nn
@ -20,7 +20,12 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
) )
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
)
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
num_patches: Annotated[torch.Tensor, TensorShape("bn")] num_patches: Annotated[torch.Tensor, TensorShape("bn")]
Gemma3ImageInputs = Gemma3ImagePixelInputs class Gemma3ImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"] = "image_embeds"
image_embeds: Annotated[
torch.Tensor,
TensorShape("ni", "nf", "hs"),
]
Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
class Gemma3ProcessingInfo(BaseProcessingInfo): class Gemma3ProcessingInfo(BaseProcessingInfo):
@ -178,8 +191,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_image_repl( def get_image_repl(
self, self,
*, *,
image_width: int, image_width: int | None,
image_height: int, image_height: int | None,
num_crops: int | None = None,
processor: Gemma3Processor | None, processor: Gemma3Processor | None,
) -> PromptUpdateDetails[str]: ) -> PromptUpdateDetails[str]:
if processor is None: if processor is None:
@ -187,11 +201,13 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
boi_token = processor.boi_token boi_token = processor.boi_token
num_crops = self.get_num_crops( if num_crops is None:
image_width=image_width, assert image_width is not None and image_height is not None
image_height=image_height, num_crops = self.get_num_crops(
processor=processor, image_width=image_width,
) image_height=image_height,
processor=processor,
)
if num_crops == 0: if num_crops == 0:
image_text = boi_token image_text = boi_token
@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
return dict( return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"), num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_token = hf_processor.boi_token image_token = hf_processor.boi_token
def get_replacement_gemma3(item_idx: int): def get_replacement_gemma3(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems)
)
if isinstance(images, ImageEmbeddingItems):
# For image embedding inputs, only the no-crops case is supported,
# since it is not supported by the HF processor anyway
return self.info.get_image_repl(
image_width=None,
image_height=None,
num_crops=0,
processor=hf_processor,
)
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
return self.info.get_image_repl( return self.info.get_image_repl(
@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration(
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None) num_patches = kwargs.pop("num_patches", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
return None
image_size = self.config.vision_config.image_size if pixel_values is not None:
image_size = self.config.vision_config.image_size
return Gemma3ImagePixelInputs( return Gemma3ImagePixelInputs(
pixel_values=pixel_values, pixel_values=pixel_values,
num_patches=num_patches, num_patches=num_patches,
resolve_bindings={"h": image_size, "w": image_size}, resolve_bindings={"h": image_size, "w": image_size},
) )
elif image_embeds is not None:
return Gemma3ImageEmbeddingInputs(
image_embeds=image_embeds,
type="image_embeds",
)
def _image_pixels_to_features( def _image_pixels_to_features(
self, self,
@ -579,7 +610,9 @@ class Gemma3ForConditionalGeneration(
def _process_image_input( def _process_image_input(
self, self,
image_input: Gemma3ImageInputs, image_input: Gemma3ImageInputs,
) -> list[torch.Tensor]: ) -> torch.Tensor | list[torch.Tensor]:
if image_input["type"] == "image_embeds":
return image_input["image_embeds"]
assert self.vision_tower is not None assert self.vision_tower is not None
pixel_values = image_input["pixel_values"] pixel_values = image_input["pixel_values"]

View File

@ -359,8 +359,9 @@ class MultiModalDataParser:
) )
self.video_needs_metadata = video_needs_metadata self.video_needs_metadata = video_needs_metadata
def _is_embeddings( @classmethod
self, data: object def is_embeddings(
cls, data: object
) -> TypeGuard[torch.Tensor | list[torch.Tensor]]: ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data.ndim == 3 return data.ndim == 3
@ -420,7 +421,7 @@ class MultiModalDataParser:
): ):
return None return None
if self._is_embeddings(data): if self.is_embeddings(data):
return AudioEmbeddingItems(data) return AudioEmbeddingItems(data)
data_items: list[AudioItem] data_items: list[AudioItem]
@ -458,7 +459,7 @@ class MultiModalDataParser:
if self._is_empty(data): if self._is_empty(data):
return None return None
if self._is_embeddings(data): if self.is_embeddings(data):
return ImageEmbeddingItems(data) return ImageEmbeddingItems(data)
if ( if (
@ -484,7 +485,7 @@ class MultiModalDataParser:
if self._is_empty(data): if self._is_empty(data):
return None return None
if self._is_embeddings(data): if self.is_embeddings(data):
return VideoEmbeddingItems(data) return VideoEmbeddingItems(data)
data_items: list[VideoItem] data_items: list[VideoItem]

View File

@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
@ -340,7 +341,12 @@ class Processor:
mm_uuids: dict[str, list[str | None] | str] = {} mm_uuids: dict[str, list[str | None] | str] = {}
for modality, data in mm_data.items(): for modality, data in mm_data.items():
n = len(data) if isinstance(data, list) else 1 # Hash each item for embedding inputs.
n = (
len(data)
if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
else 1
)
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids return mm_uuids