Allow Gemma3 to take image embeddings (#28483)

Signed-off-by: tingtinggithub <streamttt@gmail.com>
tingtinggithub 2025-11-15 04:18:08 -08:00 committed by GitHub
parent f36292dbee
commit cb15ee28db
4 changed files with 69 additions and 29 deletions
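For orientation, a minimal usage sketch of what this change enables: handing Gemma 3 pre-computed image embeddings through vLLM's `multi_modal_data` field instead of raw pixels. The embedding tensor follows the `("ni", "nf", "hs")` schema introduced below (num images × features per image × hidden size); the model name, dimension values, and placeholder token here are illustrative assumptions, not taken from this commit.

```python
import torch
from vllm import LLM

# Assumed sizes for illustration: Gemma 3 emits a fixed number of image
# features per image, and "hs" must match the language model's hidden size.
num_images, num_features, hidden_size = 2, 256, 2560

# Pre-computed image embeddings, e.g. exported earlier from a vision encoder.
image_embeds = torch.randn(num_images, num_features, hidden_size)

llm = LLM(model="google/gemma-3-4b-it")
outputs = llm.generate({
    "prompt": "<start_of_image><start_of_image>Describe both images.",
    "multi_modal_data": {"image": image_embeds},
})
```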


@@ -669,7 +669,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
 | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
-| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
+| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
 | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
 | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
 | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
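For readers of this table: in vLLM's supported-models notation, the <sup>E</sup> superscript marks modalities that accept pre-computed embeddings in place of raw inputs, and <sup>+</sup> marks modalities that allow multiple items per prompt, so this one-character change is what advertises Gemma 3 as embeddings-capable.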


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias
 
 import torch
 from torch import nn
@@ -20,7 +20,12 @@ from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
     MultiModalKwargsItems,
 )
-from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import (
+    ImageEmbeddingItems,
+    ImageProcessorItems,
+    ImageSize,
+    MultiModalDataItems,
+)
 from vllm.multimodal.processing import (
     BaseMultiModalProcessor,
     BaseProcessingInfo,
@@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-Gemma3ImageInputs = Gemma3ImagePixelInputs
+class Gemma3ImageEmbeddingInputs(TensorSchema):
+    type: Literal["image_embeds"] = "image_embeds"
+    image_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("ni", "nf", "hs"),
+    ]
+
+
+Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
 
 
 class Gemma3ProcessingInfo(BaseProcessingInfo):
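A quick sketch of what the new schema admits, using the keyword-construction and dict-style access that the hunks below rely on; the import path and dimension values are assumptions for illustration:

```python
import torch

# Module path assumed from vLLM's layout for this model file.
from vllm.model_executor.models.gemma3_mm import Gemma3ImageEmbeddingInputs

# One tensor covering all images in the prompt: (ni, nf, hs) =
# (num images, num features per image, hidden size); sizes illustrative.
embeds = torch.randn(4, 256, 2560)

inputs = Gemma3ImageEmbeddingInputs(image_embeds=embeds, type="image_embeds")
assert inputs["type"] == "image_embeds"  # TensorSchema allows dict-style access
```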
@@ -178,8 +191,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
     def get_image_repl(
         self,
         *,
-        image_width: int,
-        image_height: int,
+        image_width: int | None,
+        image_height: int | None,
+        num_crops: int | None = None,
         processor: Gemma3Processor | None,
     ) -> PromptUpdateDetails[str]:
         if processor is None:
@@ -187,11 +201,13 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         boi_token = processor.boi_token
 
-        num_crops = self.get_num_crops(
-            image_width=image_width,
-            image_height=image_height,
-            processor=processor,
-        )
+        if num_crops is None:
+            assert image_width is not None and image_height is not None
+            num_crops = self.get_num_crops(
+                image_width=image_width,
+                image_height=image_height,
+                processor=processor,
+            )
 
         if num_crops == 0:
             image_text = boi_token
@@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
             num_patches=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
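Worth noting in the hunk above: `image_embeds` is registered with `MultiModalFieldConfig.batched("image")`, which treats the leading axis of the tensor as the per-image item axis, so a single `(ni, nf, hs)` tensor is cached and scheduled per image the same way the existing `num_patches` field is.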
@@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         image_token = hf_processor.boi_token
 
         def get_replacement_gemma3(item_idx: int):
-            images = mm_items.get_items("image", ImageProcessorItems)
+            images = mm_items.get_items(
+                "image", (ImageEmbeddingItems, ImageProcessorItems)
+            )
+
+            if isinstance(images, ImageEmbeddingItems):
+                # For image embedding inputs, only the no-crops case is supported,
+                # since crops aren't supported by the HF processor anyway.
+                return self.info.get_image_repl(
+                    image_width=None,
+                    image_height=None,
+                    num_crops=0,
+                    processor=hf_processor,
+                )
 
             image_size = images.get_image_size(item_idx)
             return self.info.get_image_repl(
@@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration(
         pixel_values = kwargs.pop("pixel_values", None)
         num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
-        assert image_embeds is None, "Gemma3 does not support image_embeds."
 
-        if pixel_values is None:
-            return None
-
-        image_size = self.config.vision_config.image_size
-        return Gemma3ImagePixelInputs(
-            pixel_values=pixel_values,
-            num_patches=num_patches,
-            resolve_bindings={"h": image_size, "w": image_size},
-        )
+        if pixel_values is not None:
+            image_size = self.config.vision_config.image_size
+            return Gemma3ImagePixelInputs(
+                pixel_values=pixel_values,
+                num_patches=num_patches,
+                resolve_bindings={"h": image_size, "w": image_size},
+            )
+        elif image_embeds is not None:
+            return Gemma3ImageEmbeddingInputs(
+                image_embeds=image_embeds,
+                type="image_embeds",
+            )
 
     def _image_pixels_to_features(
         self,
@@ -579,7 +610,9 @@ class Gemma3ForConditionalGeneration(
 
     def _process_image_input(
         self,
         image_input: Gemma3ImageInputs,
-    ) -> list[torch.Tensor]:
+    ) -> torch.Tensor | list[torch.Tensor]:
+        if image_input["type"] == "image_embeds":
+            return image_input["image_embeds"]
 
         assert self.vision_tower is not None
         pixel_values = image_input["pixel_values"]
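One consequence of the two hunks above: embeddings returned from `_process_image_input` bypass the vision tower entirely and are used as-is, so callers must precompute them all the way into the language model's embedding space. A rough sketch of the equivalence, assuming the model exposes `vision_tower` and `multi_modal_projector` submodules as other vLLM VLMs do:

```python
# What the pixel path computes internally (roughly), and therefore what the
# embedding path expects you to have computed offline beforehand:
vision_features = model.vision_tower(pixel_values)           # patch features
image_embeds = model.multi_modal_projector(vision_features)  # (ni, nf, hs)
```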


@@ -359,8 +359,9 @@ class MultiModalDataParser:
         )
         self.video_needs_metadata = video_needs_metadata
 
-    def _is_embeddings(
-        self, data: object
+    @classmethod
+    def is_embeddings(
+        cls, data: object
     ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
         if isinstance(data, torch.Tensor):
             return data.ndim == 3
@@ -420,7 +421,7 @@ class MultiModalDataParser:
         ):
             return None
 
-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return AudioEmbeddingItems(data)
 
         data_items: list[AudioItem]
@@ -458,7 +459,7 @@ class MultiModalDataParser:
         if self._is_empty(data):
             return None
 
-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return ImageEmbeddingItems(data)
 
         if (
@@ -484,7 +485,7 @@ class MultiModalDataParser:
         if self._is_empty(data):
             return None
 
-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return VideoEmbeddingItems(data)
 
         data_items: list[VideoItem]
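A minimal sketch of the contract the now-public `is_embeddings` classmethod enforces, based on the tensor branch visible above (the list-of-tensors branch is cut off in this hunk, so only the 3-D tensor rule is exercised here):

```python
import torch
from vllm.multimodal.parse import MultiModalDataParser

# A 3-D tensor (items x features x hidden size) counts as embeddings...
assert MultiModalDataParser.is_embeddings(torch.randn(2, 256, 2560))

# ...while 4-D pixel batches and non-tensor data do not.
assert not MultiModalDataParser.is_embeddings(torch.randn(2, 3, 896, 896))
assert not MultiModalDataParser.is_embeddings(["image.png"])
```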


@@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.cache import processor_cache_from_config
 from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
@@ -340,7 +341,12 @@ class Processor:
         mm_uuids: dict[str, list[str | None] | str] = {}
         for modality, data in mm_data.items():
-            n = len(data) if isinstance(data, list) else 1
+            # Hash each item for embedding inputs.
+            n = (
+                len(data)
+                if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
+                else 1
+            )
             mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
 
         return mm_uuids
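The practical effect on auto-generated multimodal UUIDs, mirrored in a self-contained sketch (the `ndim == 3` test stands in for the tensor branch of `is_embeddings` shown earlier; the request id is illustrative):

```python
import torch

request_id, modality = "req-0", "image"
data = torch.randn(2, 256, 2560)  # embeddings for two images

# New behavior: an embedding tensor hashes per item along its leading axis,
# just like a list of images, instead of as a single opaque blob.
is_embeds = isinstance(data, torch.Tensor) and data.ndim == 3
n = len(data) if isinstance(data, list) or is_embeds else 1
print([f"{request_id}-{modality}-{i}" for i in range(n)])
# -> ['req-0-image-0', 'req-0-image-1']
```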