[Model] Always use Transformers backend for PaliGemma and Gemma3-MM (#26715)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung · committed by GitHub · 2025-10-17 13:03:35 +08:00
parent 9c2c2287a0
commit 8c017b3490
12 changed files with 54 additions and 1219 deletions

View File

@@ -16,8 +16,8 @@
| meta-llama/Llama-4-* | Llama4ForConditionalGeneration | ❌ |
| microsoft/Phi-3-mini-128k-instruct | Phi3ForCausalLM | 🟨 |
| microsoft/phi-4 | Phi3ForCausalLM | ❌ |
| google/gemma-3-27b-it | Gemma3ForConditionalGeneration | 🟨 |
| google/gemma-3-4b-it | Gemma3ForConditionalGeneration | ❌ |
| google/gemma-3-27b-it | TransformersForMultimodalLM | 🟨 |
| google/gemma-3-4b-it | TransformersForMultimodalLM | ❌ |
| deepseek-ai/DeepSeek-R1 | DeepseekV3ForCausalLM | ❌ |
| deepseek-ai/DeepSeek-V3 | DeepseekV3ForCausalLM | ❌ |
| RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 | LlamaForCausalLM | ✅ |

View File

@@ -650,7 +650,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
@@ -679,7 +678,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ |
| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. | | |
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
@@ -704,6 +702,8 @@ Some models are supported only via the [Transformers backend](#transformers). Th
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------|
| `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | ✅︎ | ✅︎ |
<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
&nbsp;&nbsp;&nbsp;&nbsp;• For example, to use DeepSeek-VL2 series models:
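The example invocation that follows this sentence falls outside the diff context. As a minimal hedged sketch (architecture name and model ID taken from the DeepSeek-VL2 row above), the override can be passed programmatically like so:

```python
from vllm import LLM

# Sketch only: map the checkpoint's HF architecture to the name vLLM expects
# via hf_overrides; the model ID comes from the table above.
llm = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)
```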
@@ -712,21 +712,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
!!! warning
Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:
V0 correctly implements the model's attention pattern:
- Uses bidirectional attention between the image tokens corresponding to the same image
- Uses causal attention for other tokens
- Implemented via (naive) PyTorch SDPA with masking tensors
- Note: May use significant memory for long prompts with image
V1 currently uses a simplified attention pattern:
- Uses causal attention for all tokens, including image tokens
- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}`
- Will be updated in the future to support the correct behavior
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
For `Gemma3ForConditionalGeneration`, `{"do_pan_and_scan": true}` is not yet supported by the Transformers backend.
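A minimal sketch of the new setup for Gemma 3 image inputs, assuming the chat template string used in the test settings below and leaving `do_pan_and_scan` unset since the Transformers backend does not support it yet:

```python
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Sketch only: Gemma 3 multimodal now always runs through the Transformers backend.
# mm_processor_kwargs={"do_pan_and_scan": True} is intentionally omitted.
llm = LLM(
    model="google/gemma-3-4b-it",
    max_model_len=2048,
    max_num_seqs=2,
    limit_mm_per_prompt={"image": 1},
)

prompt = (
    "<bos><start_of_turn>user\n"
    "<start_of_image>What is the season?<end_of_turn>\n"
    "<start_of_turn>model\n"
)
image = ImageAsset("cherry_blossom").pil_image

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```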
!!! note
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
@@ -778,9 +764,6 @@ Some models are supported only via the [Transformers backend](#transformers). Th
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: <https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630>
!!! warning
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
!!! note
For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.

View File

@@ -248,7 +248,8 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
max_model_len=2048,
max_num_seqs=2,
mm_processor_kwargs={"do_pan_and_scan": True},
# TODO: Support this in transformers backend
# mm_processor_kwargs={"do_pan_and_scan": True},
limit_mm_per_prompt={modality: 1},
)

View File

@@ -3,7 +3,7 @@
import numpy as np
import pytest
MODELS = ["google/gemma-2b", "google/gemma-2-2b", "google/gemma-3-4b-it"]
MODELS = ["google/gemma-2b", "google/gemma-2-2b"]
@pytest.mark.parametrize("model", MODELS)
@@ -14,14 +14,8 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None:
model,
load_format="dummy",
) as llm:
if model == "google/gemma-3-4b-it":
normalizers = llm.llm.collective_rpc(
lambda self: self.model_runner.model.language_model.model.normalizer.cpu().item() # noqa: E501
)
config = llm.llm.llm_engine.model_config.hf_config.text_config
else:
normalizers = llm.llm.collective_rpc(
lambda self: self.model_runner.model.model.normalizer.cpu().item()
)
config = llm.llm.llm_engine.model_config.hf_config
normalizers = llm.apply_model(
lambda model: model.model.normalizer.cpu().item()
)
config = llm.llm.llm_engine.model_config.hf_config
assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3)
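For context, `apply_model` hands the callable the already-unwrapped model on each worker and collects one result per worker, which is what lets the test above drop the manual `collective_rpc` / `model_runner` plumbing. A hedged sketch of the pattern:

```python
# Sketch only: run an arbitrary callable against the loaded model on every worker.
with vllm_runner("google/gemma-2-2b", load_format="dummy") as llm:
    arch_names = llm.apply_model(lambda model: type(model).__name__)
    # One entry per worker; they should all agree.
    assert len(set(arch_names)) == 1
```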

View File

@@ -113,25 +113,6 @@ VLM_TEST_SETTINGS = {
dtype="bfloat16" if current_platform.is_cpu() else "auto",
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"paligemma": VLMTestInfo(
models=["google/paligemma-3b-mix-224"],
test_type=VLMTestType.IMAGE,
prompt_formatter=identity,
img_idx_to_prompt=lambda idx: "",
# Paligemma uses its own sample prompts because the default one fails
single_image_prompts=IMAGE_ASSETS.prompts(
{
"stop_sign": "caption es",
"cherry_blossom": "What is in the picture?",
}
),
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
dtype="bfloat16",
marks=[
pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")
],
),
"qwen2_5_vl": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO),
@@ -196,14 +177,24 @@ VLM_TEST_SETTINGS = {
# Gemma3 has bidirectional mask on images
"gemma3-transformers": VLMTestInfo(
models=["google/gemma-3-4b-it"],
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda vid_prompt: f"<'<bos><start_of_turn>user\n{vid_prompt}<start_of_image><end_of_turn>\n<start_of_turn>model\n", # noqa: E501
max_model_len=4096,
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts(
{
"stop_sign": "<start_of_image>What's the content in the center of the image?", # noqa: E501
"cherry_blossom": "<start_of_image>What is the season?",
}
),
multi_image_prompt="<start_of_image><start_of_image>Describe the two images in detail.", # noqa: E501
max_model_len=8192,
auto_cls=AutoModelForImageTextToText,
# TODO: Support `do_pan_and_scan` in transformers backend
# patch_hf_runner=model_utils.gemma3_patch_hf_runner,
vllm_output_post_proc=model_utils.gemma3_vllm_to_hf_output,
image_size_factors=[(0.25, 0.5, 1.0)],
vllm_runner_kwargs={
"model_impl": "transformers",
# "mm_processor_kwargs": {"do_pan_and_scan": True},
},
marks=[pytest.mark.core_model],
),
@@ -222,6 +213,27 @@ VLM_TEST_SETTINGS = {
},
marks=[pytest.mark.core_model],
),
# PaliGemma has PrefixLM attention
"paligemma-transformers": VLMTestInfo(
models=["google/paligemma-3b-mix-224"],
test_type=VLMTestType.IMAGE,
prompt_formatter=identity,
img_idx_to_prompt=lambda idx: "",
# PaliGemma uses its own sample prompts because the default one fails
single_image_prompts=IMAGE_ASSETS.prompts(
{
"stop_sign": "caption es",
"cherry_blossom": "What is in the picture?",
}
),
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
image_size_factors=[(0.25, 0.5, 1.0)],
vllm_runner_kwargs={
"model_impl": "transformers",
},
marks=[pytest.mark.core_model],
),
# Pixel values from processor are not 4D or 5D arrays
"qwen2_5_vl-transformers": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
@@ -348,24 +360,6 @@ VLM_TEST_SETTINGS = {
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[large_gpu_mark(min_gb=32)],
),
"gemma3": VLMTestInfo(
models=["google/gemma-3-4b-it"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts(
{
"stop_sign": "<start_of_image>What's the content in the center of the image?", # noqa: E501
"cherry_blossom": "<start_of_image>What is the season?",
}
),
multi_image_prompt="<start_of_image><start_of_image>Describe the two images in detail.", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
num_logprobs=10,
),
"glm4v": VLMTestInfo(
models=["zai-org/glm-4v-9b"],
test_type=VLMTestType.IMAGE,

View File

@@ -328,16 +328,6 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model.processor = processor
orig_generate = hf_model.model.generate
def _generate(self, *args, **kwargs):
# FIXME: https://github.com/huggingface/transformers/issues/38333
kwargs["disable_compile"] = True
return orig_generate(*args, **kwargs)
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model

View File

@@ -222,7 +222,6 @@ def _test_processing_correctness(
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"ovis": False,
"ovis2_5": False,
"paligemma": False,
"ultravox": False,
"whisper": False,
}
@@ -333,7 +332,6 @@ def _test_processing_correctness_one(
"deepseek-ai/deepseek-vl2-tiny",
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"google/gemma-3n-E2B-it",
"zai-org/glm-4v-9b",
"zai-org/GLM-4.1V-9B-Thinking",
@@ -370,8 +368,6 @@ def _test_processing_correctness_one(
"AIDC-AI/Ovis1.6-Llama3.2-3B",
"AIDC-AI/Ovis2-1B",
"AIDC-AI/Ovis2.5-2B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-3.5-vision-instruct",
"microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409",

View File

@@ -48,7 +48,6 @@ ARCH_NEEDS_EXTRAS = [
"Idefics3ForConditionalGeneration",
"LlavaForConditionalGeneration",
"MiniCPMV",
"PaliGemmaForConditionalGeneration",
]
REPO_ID_TO_SKIP = {
"nm-testing/pixtral-12b-FP8-dynamic": "duplicated test",

View File

@@ -1,710 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal
import torch
from torch import nn
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalPromptUpdates,
MultiModalPromptUpdatesApplyResult,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
replace_token_matches,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
logger = init_logger(__name__)
class Gemma3ImagePixelInputs(TensorSchema):
"""
Dimensions:
- p: Number of patches total (over each image over each prompt in the
batch)
- c: Number of channels (3)
- h: Height of each patch
- w: Width of each patch
- bn: Batch size * number of images
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Gemma3Config)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def _resolve_image_kwargs(
self,
processor: Gemma3Processor,
keys: set[str],
) -> dict[str, Any]:
image_processor = processor.image_processor
kwargs = processor._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
)
images_kwargs = kwargs["images_kwargs"]
def _resolve_kw(key: str):
val = getattr(image_processor, key)
if val is None:
val = images_kwargs[key]
return val
return {k: _resolve_kw(k) for k in keys}
def get_num_crops(
self,
*,
image_width: int,
image_height: int,
processor: Gemma3Processor | None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
images_kwargs = self._resolve_image_kwargs(
processor,
{
"do_pan_and_scan",
"pan_and_scan_min_crop_size",
"pan_and_scan_max_num_crops",
"pan_and_scan_min_ratio_to_activate",
},
)
do_pan_and_scan = images_kwargs["do_pan_and_scan"]
pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"]
pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
pan_and_scan_min_ratio_to_activate = images_kwargs[
"pan_and_scan_min_ratio_to_activate"
]
if not do_pan_and_scan:
return 0
if envs.VLLM_USE_V1:
logger.warning_once(
"`do_pan_and_scan=True` has suboptimal results on V1 "
"because of the simplified attention pattern being used."
)
# Based on Gemma3ImageProcessor.pan_and_scan
if image_width >= image_height:
if image_width / image_height < pan_and_scan_min_ratio_to_activate:
return 0
num_crops_w = min(
int(math.floor(image_width / pan_and_scan_min_crop_size)),
int(math.floor(image_width / image_height + 0.5)),
)
num_crops_w = max(2, num_crops_w)
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
num_crops_h = 1
else:
if image_height / image_width < pan_and_scan_min_ratio_to_activate:
return 0
num_crops_h = min(
int(math.floor(image_height / pan_and_scan_min_crop_size)),
int(math.floor(image_height / image_width + 0.5)),
)
num_crops_h = max(2, num_crops_h)
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
num_crops_w = 1
crop_size_w = int(math.ceil(image_width / num_crops_w))
crop_size_h = int(math.ceil(image_height / num_crops_h))
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
return 0
return num_crops_w * num_crops_h
def get_image_repl(
self,
*,
image_width: int,
image_height: int,
processor: Gemma3Processor | None,
) -> PromptUpdateDetails[str]:
if processor is None:
processor = self.get_hf_processor()
boi_token = processor.boi_token
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
)
if num_crops == 0:
image_text = boi_token
else:
crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
image_text = (
f"Here is the original image {boi_token} and here are some "
f"crops to help you see better {crops_image_tokens}"
)
repl_full = image_text.replace(boi_token, processor.full_image_sequence)
tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Gemma3Processor | None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
)
image_seq_len = processor.image_seq_length
return (num_crops + 1) * image_seq_len
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
images_kwargs = self._resolve_image_kwargs(
processor, {"pan_and_scan_max_num_crops"}
)
max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
# Result in the max possible feature size (h:w = max_num_crops:1)
return ImageSize(height=50 * max_num_crops, width=50)
class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.boi_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
)
}
class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt,
mm_data,
mm_kwargs,
tok_kwargs,
)
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
parsed_images = (
self._get_data_parser()
.parse_mm_data({"image": images})
.get_items("image", ImageProcessorItems)
)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
num_crops = [
self.info.get_num_crops(
image_width=size.width,
image_height=size.height,
processor=hf_processor,
)
for size in image_sizes
]
processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_patches = hf_inputs.get("num_patches", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.boi_token
def get_replacement_gemma3(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
return self.info.get_image_repl(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_gemma3,
)
]
def _apply_token_matches(
self,
prompt: list[int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
# tokens, we have to combine them to be consistent with
# the output of the tokenizer
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
token_ids = replace_token_matches(
token_ids,
[newline_1, newline_2],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_1],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_2],
[newline_4],
)
return token_ids, res
def _find_mm_placeholders(
self,
new_token_ids: list[int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
def get_repl_toks(tok: int) -> list[int]:
if tok == newline_3:
return [newline_1, newline_2]
if tok == newline_4:
return [newline_2, newline_2]
return [tok]
repl_token_ids = list[int]()
repl_orig_idxs = list[int]()
for orig_idx, orig_tok in enumerate(new_token_ids):
repl_toks = get_repl_toks(orig_tok)
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates)
return {
modality: [
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens,
is_embed=p.is_embed,
)
for p in placeholders
]
for modality, placeholders in repls.items()
}
class Gemma3MultiModalProjector(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.zeros(
config.vision_config.hidden_size, config.text_config.hidden_size
)
)
self.mm_soft_emb_norm = GemmaRMSNorm(
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
)
self.patches_per_image = int(
config.vision_config.image_size // config.vision_config.patch_size
)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(
kernel_size=self.kernel_size, stride=self.kernel_size
)
def forward(self, vision_outputs: torch.Tensor):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, seq_length, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = torch.matmul(
normed_vision_outputs, self.mm_input_projection_weight
)
return projected_vision_outputs.type_as(vision_outputs)
@MULTIMODAL_REGISTRY.register_processor(
Gemma3MultiModalProcessor,
info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder,
)
class Gemma3ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<start_of_image>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Gemma3ForCausalLM"],
)
logit_scale = getattr(config, "logit_scale", 1.0)
if hasattr(self.language_model, "logits_processor"):
# The logits processor can be unset if we're using
# automatic conversion to pooling model.
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
@property
def dtype(self):
return next(self.parameters()).dtype
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Gemma3ImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
return None
image_size = self.config.vision_config.image_size
return Gemma3ImagePixelInputs(
pixel_values=pixel_values,
num_patches=num_patches,
resolve_bindings={"h": image_size, "w": image_size},
)
def _image_pixels_to_features(
self,
vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
) -> torch.Tensor:
return vision_tower(pixel_values)
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"]
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
image_embeds = self.multi_modal_projector(image_features)
return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
return self._process_image_input(image_input)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs,
)
return hidden_states
def prepare_attn_masks(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mask_dtype: torch.dtype,
**kwargs,
):
kwargs["has_images"] = True
# NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
# This is a HACK. Fix this.
start_indices = (positions == 0).cpu().nonzero()
num_seqs = len(start_indices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_indices[i].item()
if i < num_seqs - 1:
end_idx = start_indices[i + 1].item()
else:
end_idx = len(input_ids)
seq_lens.append(end_idx - start_idx)
kwargs["seq_lens"] = seq_lens
global_attn_masks = []
local_attn_masks = []
start_idx = 0
for seq_len in seq_lens:
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
start_idx = end_idx
# Create a global causal mask.
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
# Fill the lower triangle with 0.
global_attn_mask = global_attn_mask.triu(diagonal=1)
# Consider the bidirectional attention between image tokens.
img_mask = torch.zeros_like(global_attn_mask)
img_pos = input_token_ids == self.config.image_token_index
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)
sliding_window = self.config.text_config.sliding_window
if sliding_window is not None:
# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(
local_attn_mask == 0, global_attn_mask, float("-inf")
)
local_attn_masks.append(local_attn_mask)
kwargs["global_attn_masks"] = global_attn_masks
kwargs["local_attn_masks"] = local_attn_masks
return kwargs
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector",
tower_model="vision_tower",
)
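The `prepare_attn_masks` helper deleted above is the part the Transformers backend now covers. A standalone sketch of the mask it built (causal everywhere, bidirectional between image tokens), assuming an additive mask where 0 means attend and -inf means masked:

```python
import torch

def gemma3_style_mask(token_ids: torch.Tensor, image_token_id: int) -> torch.Tensor:
    """Sketch of the deleted logic: causal mask, but image tokens see each other."""
    seq_len = token_ids.numel()
    mask = torch.full((seq_len, seq_len), float("-inf"))
    mask = mask.triu(diagonal=1)          # causal: only future positions stay -inf
    is_img = token_ids == image_token_id
    both_img = is_img[:, None] & is_img[None, :]
    mask[both_img] = 0.0                  # bidirectional attention among image tokens
    return mask
```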

View File

@@ -1,412 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, TypeAlias
import torch
from torch import nn
from transformers import BatchFeature, PaliGemmaConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalInputs,
MultiModalKwargsItems,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
MultiModalDataItems,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptIndexTargets,
PromptInsertion,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
from .vision import get_vision_encoder_info
logger = init_logger(__name__)
class PaliGemmaImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height
- w: Width
"""
type: Literal["pixel_values"] = "pixel_values"
data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
class PaliGemmaImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
PaliGemmaImageInputs: TypeAlias = (
PaliGemmaImagePixelInputs | PaliGemmaImageEmbeddingInputs
)
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, projection_dim: int):
super().__init__()
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear(image_features)
return hidden_states
class PaliGemmaProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(PaliGemmaConfig)
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": 1}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
vision_encoder_info = self.get_vision_encoder_info()
return vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
)
class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
width=max_image_size,
height=max_image_size,
num_images=num_images,
overrides=image_overrides,
)
}
class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if not mm_data:
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
tokenizer = self.info.get_tokenizer()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
def get_insertion(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems)
)
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
image_tokens = [image_token_id] * num_image_tokens
return PromptUpdateDetails.select_token_id(
image_tokens + [bos_token_id],
embed_token_id=image_token_id,
)
# Paligemma 1 and 2 have different tokenizer.add_bos_token
# Insert <image>*n + <bos> after <bos> for Paligemma 1
# Insert <image>*n + <bos> for Paligemma 2
return [
PromptInsertion(
modality="image",
target=PromptIndexTargets.prefix(
[bos_token_id] if tokenizer.add_bos_token else []
),
insertion=get_insertion,
)
]
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
mm_inputs = super().apply(
prompt,
mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer()
newline_prompt = "\n"
newline_token_id = tokenizer.encode(newline_prompt)[-1] # 108
# Force to add newline at the end of prompt for paligemma's format
# This step can NOT be replaced by current PromptUpdate methods
if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
prompt_token_ids.append(newline_token_id)
mm_inputs["prompt_token_ids"] = prompt_token_ids
return mm_inputs
@MULTIMODAL_REGISTRY.register_processor(
PaliGemmaMultiModalProcessor,
info=PaliGemmaProcessingInfo,
dummy_inputs=PaliGemmaDummyInputsBuilder,
)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return None
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim,
)
self.quant_config = quant_config
if config.text_config.model_type == "gemma":
config.text_config.architectures = ["GemmaForCausalLM"]
else:
config.text_config.architectures = ["Gemma2ForCausalLM"]
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> PaliGemmaImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
pixel_values = flatten_bn(pixel_values, concat=True)
h = w = self.config.vision_config.image_size
return PaliGemmaImagePixelInputs(
type="pixel_values",
data=pixel_values,
resolve_bindings={"h": h, "w": w},
)
if image_embeds is not None:
image_embeds = flatten_bn(image_embeds, concat=True)
return PaliGemmaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(
self,
vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
) -> torch.Tensor:
target_dtype = vision_tower.get_input_embeddings().weight.dtype
image_features = vision_tower(pixel_values.to(dtype=target_dtype))
return image_features
def _process_image_input(
self,
image_input: PaliGemmaImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
pixel_values = image_input["data"]
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
return self.multi_modal_projector(image_features)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
vision_embeddings = self._process_image_input(image_input)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
return vision_embeddings
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -263,7 +263,6 @@ _MULTIMODAL_MODELS = {
"Ernie4_5_VLMoeForConditionalGeneration",
),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"Gemma3nForConditionalGeneration": (
"gemma3n_mm",
"Gemma3nForConditionalGeneration",
@@ -329,10 +328,6 @@ _MULTIMODAL_MODELS = {
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"Ovis": ("ovis", "Ovis"),
"Ovis2_5": ("ovis2_5", "Ovis2_5"),
"PaliGemmaForConditionalGeneration": (
"paligemma",
"PaliGemmaForConditionalGeneration",
),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"), # noqa: E501
@@ -405,6 +400,14 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
"transformers",
"TransformersMultiModalForCausalLM",
),
"Gemma3ForConditionalGeneration": (
"transformers",
"TransformersMultiModalForCausalLM",
),
"PaliGemmaForConditionalGeneration": (
"transformers",
"TransformersMultiModalForCausalLM",
),
}
_TRANSFORMERS_BACKEND_MODELS = {

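With these registry entries, `Gemma3ForConditionalGeneration` and `PaliGemmaForConditionalGeneration` resolve to `TransformersMultiModalForCausalLM` without any user action. A hedged sketch of pinning the backend explicitly, mirroring the `model_impl` setting used in the updated tests (assuming the engine argument of the same name):

```python
from vllm import LLM

# Sketch only: the registry now falls back to the Transformers backend on its own;
# model_impl="transformers" simply makes that choice explicit.
llm = LLM(model="google/paligemma-3b-mix-224", model_impl="transformers")
```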
View File

@@ -59,9 +59,6 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
"Qwen2ForCausalLM": _ROCM_SWA_REASON,
"MistralForCausalLM": _ROCM_SWA_REASON,
"MixtralForCausalLM": _ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration": (
"ROCm flash attention does not yet fully support 32-bit precision on PaliGemma"
),
"Phi3VForCausalLM": (
"ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "