mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-30 23:10:06 +08:00
Allow Gemma3 to take image embeddings (#28483)
Signed-off-by: tingtinggithub <streamttt@gmail.com>
This commit is contained in:
parent
f36292dbee
commit
cb15ee28db
@ -669,7 +669,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
|||||||
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
|
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
|
||||||
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
|
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
|
||||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
|
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
|
||||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
|
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
|
||||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
||||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
|
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
|
||||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
|
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import math
|
import math
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal, TypeAlias
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -20,7 +20,12 @@ from vllm.multimodal.inputs import (
|
|||||||
MultiModalFieldConfig,
|
MultiModalFieldConfig,
|
||||||
MultiModalKwargsItems,
|
MultiModalKwargsItems,
|
||||||
)
|
)
|
||||||
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
|
from vllm.multimodal.parse import (
|
||||||
|
ImageEmbeddingItems,
|
||||||
|
ImageProcessorItems,
|
||||||
|
ImageSize,
|
||||||
|
MultiModalDataItems,
|
||||||
|
)
|
||||||
from vllm.multimodal.processing import (
|
from vllm.multimodal.processing import (
|
||||||
BaseMultiModalProcessor,
|
BaseMultiModalProcessor,
|
||||||
BaseProcessingInfo,
|
BaseProcessingInfo,
|
||||||
@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
|
|||||||
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
||||||
|
|
||||||
|
|
||||||
Gemma3ImageInputs = Gemma3ImagePixelInputs
|
class Gemma3ImageEmbeddingInputs(TensorSchema):
|
||||||
|
type: Literal["image_embeds"] = "image_embeds"
|
||||||
|
image_embeds: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("ni", "nf", "hs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
|
||||||
|
|
||||||
|
|
||||||
class Gemma3ProcessingInfo(BaseProcessingInfo):
|
class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||||
@ -178,8 +191,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
|||||||
def get_image_repl(
|
def get_image_repl(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
image_width: int,
|
image_width: int | None,
|
||||||
image_height: int,
|
image_height: int | None,
|
||||||
|
num_crops: int | None = None,
|
||||||
processor: Gemma3Processor | None,
|
processor: Gemma3Processor | None,
|
||||||
) -> PromptUpdateDetails[str]:
|
) -> PromptUpdateDetails[str]:
|
||||||
if processor is None:
|
if processor is None:
|
||||||
@ -187,11 +201,13 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
|||||||
|
|
||||||
boi_token = processor.boi_token
|
boi_token = processor.boi_token
|
||||||
|
|
||||||
num_crops = self.get_num_crops(
|
if num_crops is None:
|
||||||
image_width=image_width,
|
assert image_width is not None and image_height is not None
|
||||||
image_height=image_height,
|
num_crops = self.get_num_crops(
|
||||||
processor=processor,
|
image_width=image_width,
|
||||||
)
|
image_height=image_height,
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
|
||||||
if num_crops == 0:
|
if num_crops == 0:
|
||||||
image_text = boi_token
|
image_text = boi_token
|
||||||
@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
|||||||
return dict(
|
return dict(
|
||||||
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
|
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
|
||||||
num_patches=MultiModalFieldConfig.batched("image"),
|
num_patches=MultiModalFieldConfig.batched("image"),
|
||||||
|
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
def _get_prompt_updates(
|
||||||
@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
|||||||
image_token = hf_processor.boi_token
|
image_token = hf_processor.boi_token
|
||||||
|
|
||||||
def get_replacement_gemma3(item_idx: int):
|
def get_replacement_gemma3(item_idx: int):
|
||||||
images = mm_items.get_items("image", ImageProcessorItems)
|
images = mm_items.get_items(
|
||||||
|
"image", (ImageEmbeddingItems, ImageProcessorItems)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(images, ImageEmbeddingItems):
|
||||||
|
# For image embedding inputs, only support no crops cases
|
||||||
|
# since it's not supported in hf processor anyway
|
||||||
|
return self.info.get_image_repl(
|
||||||
|
image_width=None,
|
||||||
|
image_height=None,
|
||||||
|
num_crops=0,
|
||||||
|
processor=hf_processor,
|
||||||
|
)
|
||||||
|
|
||||||
image_size = images.get_image_size(item_idx)
|
image_size = images.get_image_size(item_idx)
|
||||||
return self.info.get_image_repl(
|
return self.info.get_image_repl(
|
||||||
@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration(
|
|||||||
pixel_values = kwargs.pop("pixel_values", None)
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
num_patches = kwargs.pop("num_patches", None)
|
num_patches = kwargs.pop("num_patches", None)
|
||||||
image_embeds = kwargs.pop("image_embeds", None)
|
image_embeds = kwargs.pop("image_embeds", None)
|
||||||
assert image_embeds is None, "Gemma3 does not support image_embeds."
|
|
||||||
if pixel_values is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
image_size = self.config.vision_config.image_size
|
if pixel_values is not None:
|
||||||
|
image_size = self.config.vision_config.image_size
|
||||||
return Gemma3ImagePixelInputs(
|
return Gemma3ImagePixelInputs(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
num_patches=num_patches,
|
num_patches=num_patches,
|
||||||
resolve_bindings={"h": image_size, "w": image_size},
|
resolve_bindings={"h": image_size, "w": image_size},
|
||||||
)
|
)
|
||||||
|
elif image_embeds is not None:
|
||||||
|
return Gemma3ImageEmbeddingInputs(
|
||||||
|
image_embeds=image_embeds,
|
||||||
|
type="image_embeds",
|
||||||
|
)
|
||||||
|
|
||||||
def _image_pixels_to_features(
|
def _image_pixels_to_features(
|
||||||
self,
|
self,
|
||||||
@ -579,7 +610,9 @@ class Gemma3ForConditionalGeneration(
|
|||||||
def _process_image_input(
|
def _process_image_input(
|
||||||
self,
|
self,
|
||||||
image_input: Gemma3ImageInputs,
|
image_input: Gemma3ImageInputs,
|
||||||
) -> list[torch.Tensor]:
|
) -> torch.Tensor | list[torch.Tensor]:
|
||||||
|
if image_input["type"] == "image_embeds":
|
||||||
|
return image_input["image_embeds"]
|
||||||
assert self.vision_tower is not None
|
assert self.vision_tower is not None
|
||||||
|
|
||||||
pixel_values = image_input["pixel_values"]
|
pixel_values = image_input["pixel_values"]
|
||||||
|
|||||||
@ -359,8 +359,9 @@ class MultiModalDataParser:
|
|||||||
)
|
)
|
||||||
self.video_needs_metadata = video_needs_metadata
|
self.video_needs_metadata = video_needs_metadata
|
||||||
|
|
||||||
def _is_embeddings(
|
@classmethod
|
||||||
self, data: object
|
def is_embeddings(
|
||||||
|
cls, data: object
|
||||||
) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
|
) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
|
||||||
if isinstance(data, torch.Tensor):
|
if isinstance(data, torch.Tensor):
|
||||||
return data.ndim == 3
|
return data.ndim == 3
|
||||||
@ -420,7 +421,7 @@ class MultiModalDataParser:
|
|||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self._is_embeddings(data):
|
if self.is_embeddings(data):
|
||||||
return AudioEmbeddingItems(data)
|
return AudioEmbeddingItems(data)
|
||||||
|
|
||||||
data_items: list[AudioItem]
|
data_items: list[AudioItem]
|
||||||
@ -458,7 +459,7 @@ class MultiModalDataParser:
|
|||||||
if self._is_empty(data):
|
if self._is_empty(data):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self._is_embeddings(data):
|
if self.is_embeddings(data):
|
||||||
return ImageEmbeddingItems(data)
|
return ImageEmbeddingItems(data)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -484,7 +485,7 @@ class MultiModalDataParser:
|
|||||||
if self._is_empty(data):
|
if self._is_empty(data):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self._is_embeddings(data):
|
if self.is_embeddings(data):
|
||||||
return VideoEmbeddingItems(data)
|
return VideoEmbeddingItems(data)
|
||||||
|
|
||||||
data_items: list[VideoItem]
|
data_items: list[VideoItem]
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||||
from vllm.multimodal.cache import processor_cache_from_config
|
from vllm.multimodal.cache import processor_cache_from_config
|
||||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
|
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
|
||||||
|
from vllm.multimodal.parse import MultiModalDataParser
|
||||||
from vllm.multimodal.processing import EncDecMultiModalProcessor
|
from vllm.multimodal.processing import EncDecMultiModalProcessor
|
||||||
from vllm.multimodal.utils import argsort_mm_positions
|
from vllm.multimodal.utils import argsort_mm_positions
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
@ -340,7 +341,12 @@ class Processor:
|
|||||||
|
|
||||||
mm_uuids: dict[str, list[str | None] | str] = {}
|
mm_uuids: dict[str, list[str | None] | str] = {}
|
||||||
for modality, data in mm_data.items():
|
for modality, data in mm_data.items():
|
||||||
n = len(data) if isinstance(data, list) else 1
|
# Hash each item for embedding inputs.
|
||||||
|
n = (
|
||||||
|
len(data)
|
||||||
|
if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
|
||||||
|
else 1
|
||||||
|
)
|
||||||
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
|
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
|
||||||
return mm_uuids
|
return mm_uuids
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user