Allow Gemma3 to take image embeddings (#28483)
Signed-off-by: tingtinggithub <streamttt@gmail.com>
Parent: f36292dbee
Commit: cb15ee28db
docs/models/supported_models.md

@@ -669,7 +669,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
 | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
-| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
+| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
 | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
 | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
 | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
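The `E` in `I<sup>E+</sup>` marks models that accept precomputed image embeddings in place of raw images. A minimal offline-inference sketch of what this enables; the prompt placeholder and the shape `(1, 256, 2560)` (feature tokens and hidden size roughly matching `google/gemma-3-4b-it`) are illustrative assumptions, not values taken from this diff:

```python
import torch
from vllm import LLM

llm = LLM(model="google/gemma-3-4b-it")

# One image's embeddings, already projected into the LM's hidden space:
# (num_images, num_feature_tokens, hidden_size), i.e. (ni, nf, hs) below.
image_embeds = torch.rand(1, 256, 2560, dtype=torch.bfloat16)

outputs = llm.generate({
    "prompt": "<start_of_image>Describe the image.",
    "multi_modal_data": {"image": image_embeds},
})
print(outputs[0].outputs[0].text)
```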
vllm/model_executor/models/gemma3_mm.py

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias

 import torch
 from torch import nn
@@ -20,7 +20,12 @@ from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
     MultiModalKwargsItems,
 )
-from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import (
+    ImageEmbeddingItems,
+    ImageProcessorItems,
+    ImageSize,
+    MultiModalDataItems,
+)
 from vllm.multimodal.processing import (
     BaseMultiModalProcessor,
     BaseProcessingInfo,
@@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]


-Gemma3ImageInputs = Gemma3ImagePixelInputs
+class Gemma3ImageEmbeddingInputs(TensorSchema):
+    type: Literal["image_embeds"] = "image_embeds"
+    image_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("ni", "nf", "hs"),
+    ]
+
+
+Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs


 class Gemma3ProcessingInfo(BaseProcessingInfo):
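Reading the `TensorShape("ni", "nf", "hs")` annotation: `ni` images, `nf` feature tokens per image, `hs` hidden size. A rough standalone equivalent of the shape constraint, as a sketch rather than the actual `TensorSchema` validation:

```python
import torch

def check_image_embeds(image_embeds: torch.Tensor) -> None:
    # Expect a 3-D tensor: (num_images, num_feature_tokens, hidden_size).
    if image_embeds.ndim != 3:
        raise ValueError(f"expected (ni, nf, hs), got {tuple(image_embeds.shape)}")

check_image_embeds(torch.rand(2, 256, 2560))  # OK: two images, shapes illustrative
```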
@@ -178,8 +191,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
     def get_image_repl(
         self,
         *,
-        image_width: int,
-        image_height: int,
+        image_width: int | None,
+        image_height: int | None,
+        num_crops: int | None = None,
         processor: Gemma3Processor | None,
     ) -> PromptUpdateDetails[str]:
         if processor is None:
@@ -187,11 +201,13 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):

         boi_token = processor.boi_token

-        num_crops = self.get_num_crops(
-            image_width=image_width,
-            image_height=image_height,
-            processor=processor,
-        )
+        if num_crops is None:
+            assert image_width is not None and image_height is not None
+            num_crops = self.get_num_crops(
+                image_width=image_width,
+                image_height=image_height,
+                processor=processor,
+            )

         if num_crops == 0:
             image_text = boi_token
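The reworked signature effectively supports two calling conventions: pixel inputs pass a real size and let `get_num_crops` run, while embedding inputs pin `num_crops` to 0. A sketch of the contract, assuming `info` is a `Gemma3ProcessingInfo` and `hf_processor` a loaded `Gemma3Processor`; the 896 size is illustrative:

```python
# Pixel path: size known, num_crops computed internally from it.
repl = info.get_image_repl(
    image_width=896, image_height=896, processor=hf_processor
)

# Embedding path: no pixel size exists, so the caller pins num_crops=0.
repl = info.get_image_repl(
    image_width=None, image_height=None, num_crops=0, processor=hf_processor
)
```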
@@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
             num_patches=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
         )

     def _get_prompt_updates(
@@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         image_token = hf_processor.boi_token

         def get_replacement_gemma3(item_idx: int):
-            images = mm_items.get_items("image", ImageProcessorItems)
+            images = mm_items.get_items(
+                "image", (ImageEmbeddingItems, ImageProcessorItems)
+            )
+
+            if isinstance(images, ImageEmbeddingItems):
+                # For image embedding inputs, only support the no-crops case,
+                # since crops are not supported by the HF processor anyway.
+                return self.info.get_image_repl(
+                    image_width=None,
+                    image_height=None,
+                    num_crops=0,
+                    processor=hf_processor,
+                )

             image_size = images.get_image_size(item_idx)
             return self.info.get_image_repl(
@@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration(
         pixel_values = kwargs.pop("pixel_values", None)
         num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
-        assert image_embeds is None, "Gemma3 does not support image_embeds."
-        if pixel_values is None:
-            return None
-
-        image_size = self.config.vision_config.image_size
-
-        return Gemma3ImagePixelInputs(
-            pixel_values=pixel_values,
-            num_patches=num_patches,
-            resolve_bindings={"h": image_size, "w": image_size},
-        )
+        if pixel_values is not None:
+            image_size = self.config.vision_config.image_size
+            return Gemma3ImagePixelInputs(
+                pixel_values=pixel_values,
+                num_patches=num_patches,
+                resolve_bindings={"h": image_size, "w": image_size},
+            )
+        elif image_embeds is not None:
+            return Gemma3ImageEmbeddingInputs(
+                image_embeds=image_embeds,
+                type="image_embeds",
+            )

     def _image_pixels_to_features(
         self,
def _image_pixels_to_features(
|
||||
self,
|
||||
@ -579,7 +610,9 @@ class Gemma3ForConditionalGeneration(
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: Gemma3ImageInputs,
|
||||
) -> list[torch.Tensor]:
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["image_embeds"]
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = image_input["pixel_values"]
|
||||
|
||||
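That early return is the crux of the feature: embedding inputs bypass the vision tower, and by implication the multimodal projector as well, so supplied embeddings must already be in the language model's hidden space. A simplified restatement of the dispatch, not the actual method body:

```python
def process_image_input(image_input, vision_tower, projector):
    if image_input["type"] == "image_embeds":
        # Precomputed embeddings skip encoding entirely.
        return image_input["image_embeds"]
    # Pixel inputs still take the full encode-then-project path
    # (condensed here from the surrounding methods).
    features = vision_tower(image_input["pixel_values"])
    return projector(features)
```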
vllm/multimodal/parse.py

@@ -359,8 +359,9 @@ class MultiModalDataParser:
         )
         self.video_needs_metadata = video_needs_metadata

-    def _is_embeddings(
-        self, data: object
+    @classmethod
+    def is_embeddings(
+        cls, data: object
     ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
         if isinstance(data, torch.Tensor):
             return data.ndim == 3
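Promoting `_is_embeddings` to a public classmethod lets callers test data without constructing a parser. Based on the tensor branch shown here (the list branch falls outside this hunk), usage looks roughly like this, with shapes illustrative:

```python
import torch
from vllm.multimodal.parse import MultiModalDataParser

MultiModalDataParser.is_embeddings(torch.rand(2, 256, 2560))  # True: 3-D tensor
MultiModalDataParser.is_embeddings(torch.rand(256, 2560))     # False: 2-D tensor
```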
@@ -420,7 +421,7 @@ class MultiModalDataParser:
         ):
             return None

-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return AudioEmbeddingItems(data)

         data_items: list[AudioItem]
@@ -458,7 +459,7 @@ class MultiModalDataParser:
         if self._is_empty(data):
             return None

-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return ImageEmbeddingItems(data)

         if (
@@ -484,7 +485,7 @@ class MultiModalDataParser:
         if self._is_empty(data):
             return None

-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return VideoEmbeddingItems(data)

         data_items: list[VideoItem]
vllm/v1/engine/processor.py

@@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.cache import processor_cache_from_config
 from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
@@ -340,7 +341,12 @@ class Processor:

         mm_uuids: dict[str, list[str | None] | str] = {}
         for modality, data in mm_data.items():
-            n = len(data) if isinstance(data, list) else 1
+            # Hash each item for embedding inputs.
+            n = (
+                len(data)
+                if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
+                else 1
+            )
             mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
         return mm_uuids
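The subtlety: a batch of embeddings arrives as a single 3-D tensor rather than a list, yet it still represents several items. Since `len()` on a tensor returns the size of its first dimension, the same expression counts items in both cases and each embedded image gets its own UUID. A small illustration, shapes assumed:

```python
import torch

embeds = torch.rand(3, 256, 2560)  # three images' embeddings in one tensor
n = len(embeds)                    # len() of a tensor == size of dim 0, so 3
uuids = [f"req42-image-{i}" for i in range(n)]
# ['req42-image-0', 'req42-image-1', 'req42-image-2']
```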