[Model] Add Gemma3 GGUF multimodal support (#27772)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Luciano Martins 2025-11-18 13:56:29 -03:00 committed by GitHub
parent 49a986ecd4
commit c2612371ad
14 changed files with 752 additions and 86 deletions

View File

@ -30,7 +30,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
gguf >= 0.17.0
mistral_common[image] >= 1.8.5
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml

View File

@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, NamedTuple
import pytest
from huggingface_hub import hf_hub_download
from pytest import MarkDecorator
from tests.quantization.utils import is_quant_method_supported
from vllm.assets.image import ImageAsset
from vllm.utils.torch_utils import set_default_torch_num_threads
from ....conftest import PromptImageInput, VllmRunner
from ...utils import check_logprobs_close
class GGUFMMTestConfig(NamedTuple):
original_model: str
gguf_repo: str
gguf_backbone: str
gguf_mmproj: str
prompt: list[str]
mm_data: dict[Literal["images"], PromptImageInput]
max_model_len: int = 4096
marks: list[MarkDecorator] = []
@property
def gguf_model(self):
hf_hub_download(self.gguf_repo, filename=self.gguf_mmproj)
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
GEMMA3_CONFIG = GGUFMMTestConfig(
original_model="google/gemma-3-4b-it",
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
gguf_mmproj="mmproj-model-f16-4B.gguf",
prompt=["<start_of_image>Describe this image in detail:"],
mm_data={"images": [ImageAsset("stop_sign").pil_image]},
marks=[pytest.mark.core_model],
)
MODELS_TO_TEST = [GEMMA3_CONFIG]
def run_multimodal_gguf_test(
vllm_runner: type[VllmRunner],
model: GGUFMMTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
):
# Run gguf model.
with (
set_default_torch_num_threads(1),
vllm_runner(
model_name=model.gguf_model,
enforce_eager=True,
tokenizer_name=model.original_model,
dtype=dtype,
max_model_len=model.max_model_len,
) as gguf_model,
):
gguf_outputs = gguf_model.generate_greedy_logprobs(
prompts=model.prompt,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
**model.mm_data,
)
# Run unquantized model.
with vllm_runner(
model_name=model.original_model,
enforce_eager=True, # faster tests
dtype=dtype,
max_model_len=model.max_model_len,
) as original_model:
original_outputs = original_model.generate_greedy_logprobs(
prompts=model.prompt,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
**model.mm_data,
)
check_logprobs_close(
outputs_0_lst=original_outputs,
outputs_1_lst=gguf_outputs,
name_0="original",
name_1="gguf",
)
@pytest.mark.skipif(
not is_quant_method_supported("gguf"),
reason="gguf is not supported on this GPU type.",
)
@pytest.mark.parametrize(
"model",
[
pytest.param(test_config, marks=test_config.marks)
for test_config in MODELS_TO_TEST
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(
vllm_runner: type[VllmRunner],
model: GGUFMMTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
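For reference, a minimal offline-inference sketch of what this test exercises (illustrative only, not part of the diff; the sampling settings are arbitrary): the mmproj file and the GGUF backbone are downloaded into the same snapshot directory, and the original Hugging Face model is passed as the tokenizer so the processor can be resolved.

from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Both files end up in the same local snapshot directory, so the mmproj
# file can be auto-detected next to the backbone at load time.
hf_hub_download("google/gemma-3-4b-it-qat-q4_0-gguf", filename="mmproj-model-f16-4B.gguf")
backbone = hf_hub_download("google/gemma-3-4b-it-qat-q4_0-gguf", filename="gemma-3-4b-it-q4_0.gguf")

llm = LLM(model=backbone, tokenizer="google/gemma-3-4b-it", max_model_len=4096)
outputs = llm.generate(
    {
        "prompt": "<start_of_image>Describe this image in detail:",
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    },
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)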

View File

@ -78,6 +78,12 @@ DOLPHIN_CONFIG = GGUFTestConfig(
gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
)
GEMMA3_CONFIG = GGUFTestConfig(
original_model="google/gemma-3-270m-it",
gguf_repo="ggml-org/gemma-3-270m-it-qat-GGUF",
gguf_filename="gemma-3-270m-it-qat-Q4_0.gguf",
)
MODELS = [
# LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458
QWEN2_CONFIG,
@ -85,6 +91,7 @@ MODELS = [
GPT2_CONFIG,
STABLELM_CONFIG,
DOLPHIN_CONFIG,
GEMMA3_CONFIG,
# STARCODER_CONFIG, # broken
]
@ -148,7 +155,7 @@ def check_model_outputs(
"model",
[pytest.param(test_config, marks=test_config.marks) for test_config in MODELS],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [1])

View File

@ -33,10 +33,14 @@ from vllm.transformers_utils.config import (
try_get_generation_config,
try_get_safetensors_metadata,
try_get_tokenizer_config,
uses_custom_attention_masks,
uses_mrope,
)
from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
@ -450,6 +454,12 @@ class ModelConfig:
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
if check_gguf_file(self.model):
raise ValueError(
"Using a tokenizer is mandatory when loading a GGUF model. "
"Please specify the tokenizer path or name using the "
"--tokenizer argument."
)
self.tokenizer = self.model
if self.tokenizer_revision is None:
self.tokenizer_revision = self.revision
@ -508,6 +518,10 @@ class ModelConfig:
hf_overrides_kw=hf_overrides_kw,
hf_overrides_fn=hf_overrides_fn,
)
hf_config = maybe_patch_hf_config_from_gguf(
self.model,
hf_config,
)
self.hf_config = hf_config
if dict_overrides:
@ -1605,6 +1619,10 @@ class ModelConfig:
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)
@property
def uses_custom_attention_masks(self) -> bool:
return uses_custom_attention_masks(self.hf_config)
@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None

View File

@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from collections.abc import Callable, Mapping
from types import MappingProxyType
from typing import Any, Optional
import gguf
@ -26,7 +27,11 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod,
VocabParallelEmbedding,
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
@ -65,18 +70,70 @@ class GGUFConfig(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_gguf(prefix, self.unquantized_modules):
if is_layer_skipped_gguf(
prefix, self.unquantized_modules, self.packed_modules_mapping
):
return UnquantizedLinearMethod()
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
if is_layer_skipped_gguf(
prefix, self.unquantized_modules, self.packed_modules_mapping
):
return UnquantizedEmbeddingMethod()
return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self, layer.moe_config)
return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
"""
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
return any(module_name in prefix for module_name in unquantized_modules)
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
if self.unquantized_modules is not None:
self.unquantized_modules = hf_to_vllm_mapper.apply_list(
self.unquantized_modules
)
def is_layer_skipped_gguf(
prefix: str,
unquantized_modules: list[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
):
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the GGUF checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
proj_name = prefix.split(".")[-1]
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = any(
shard_prefix in module_name for module_name in unquantized_modules
)
if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision."
)
else:
is_skipped = any(module_name in prefix for module_name in unquantized_modules)
assert is_skipped is not None
return is_skipped
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
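To illustrate the fused-shard handling above, a short sketch (the prefixes and the packed-modules mapping below are hypothetical, not taken from a real model config):

fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
unquantized = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

# Every shard of the fused layer is unquantized, so the fused prefix is
# skipped as a whole and falls back to UnquantizedLinearMethod.
assert is_layer_skipped_gguf(
    "model.layers.0.self_attn.qkv_proj", unquantized, fused_mapping
)
# If only some shards appeared in `unquantized`, the check would instead
# raise ValueError because the shards disagree on precision.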

View File

@ -7,10 +7,11 @@ import gguf
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
@ -21,8 +22,11 @@ from vllm.model_executor.model_loader.weight_utils import (
get_gguf_weight_type_map,
gguf_quant_weights_iterator,
)
from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
class GGUFModelLoader(BaseModelLoader):
"""
@ -67,7 +71,15 @@ class GGUFModelLoader(BaseModelLoader):
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
"""
config = model_config.hf_config
# Get text config to handle both nested (multimodal) and flat
# (text-only) config structures. For multimodal models like
# Gemma3Config, this returns config.text_config. For text-only
# models, this returns config itself.
text_config = config.get_text_config()
model_type = config.model_type
is_multimodal = (
hasattr(config, "vision_config") and config.vision_config is not None
)
gguf_to_hf_name_map = {}
# hack: ggufs have a different name than transformers
if model_type == "cohere":
@ -115,24 +127,167 @@ class GGUFModelLoader(BaseModelLoader):
break
if arch is None:
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
num_layers = config.num_hidden_layers
name_map = gguf.get_tensor_name_map(arch, num_layers)
text_num_layers = text_config.num_hidden_layers
text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)
if is_multimodal:
mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
vision_num_layers = config.vision_config.num_hidden_layers
vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
else:
vision_name_map = None
# Create dummy model to extract parameter names
# For multimodal: use AutoModelForImageTextToText to get
# language + vision + projector params
# For text-only: use AutoModelForCausalLM to get language model params
auto_cls = (
AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
)
with torch.device("meta"):
dummy_model = AutoModelForCausalLM.from_config(
dummy_model = auto_cls.from_config(
config, trust_remote_code=model_config.trust_remote_code
)
state_dict = dummy_model.state_dict()
state_dict = dummy_model.state_dict()
if hf_checkpoint_map := getattr(
dummy_model, "_checkpoint_conversion_mapping", None
):
def revert_hf_rename(name: str) -> str:
for original_name, hf_name in hf_checkpoint_map.items():
if hf_name in name:
name = name.replace(hf_name, original_name).lstrip("^")
return name
state_dict = {
revert_hf_rename(name): tensor for name, tensor in state_dict.items()
}
def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
"""
Map HuggingFace parameter name to GGUF tensor name.
This function handles the mismatch between HF parameter naming
conventions and gguf-py's expected format:
1. Strips the 'language_model.' prefix (common in multimodal models)
2. Converts '_weight' suffix to '.weight' (Gemma3 compatibility)
3. Searches vision_name_map for multimodal parameters
4. Falls back to text_name_map for language model parameters
Args:
hf_name: Full HuggingFace parameter name (e.g.,
'model.multi_modal_projector.mm_soft_emb_norm.weight')
Returns:
GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
or None if no mapping found
"""
# Strip 'language_model.' prefix for multimodal models - gguf-py
# tensor mappings expect parameter names without this prefix.
# Note: 'model.' prefix should be KEPT for text-only models as
# gguf-py expects it.
if hf_name.startswith("language_model."):
hf_name = hf_name[15:] # Remove 'language_model.'
# Parse parameter name and suffix
if hf_name.endswith((".weight", ".bias")):
base_name, suffix = hf_name.rsplit(".", 1)
else:
base_name, suffix = hf_name, ""
# Handle '_weight' suffix (Gemma3 naming: parameter ends with
# '_weight' instead of '.weight')
if base_name.endswith("_weight"):
base_name = base_name[:-7] # Remove '_weight'
suffix = "weight"
gguf_name = None
# Priority 1: Search vision/projector parameters for multimodal models
if vision_name_map is not None:
gguf_name = vision_name_map.get_name(base_name)
# Priority 2: Search text backbone parameters
if gguf_name is None:
gguf_name = text_name_map.get_name(base_name)
if gguf_name is None:
return None
return gguf_name + "." + suffix
# Build mapping and track unmapped parameters
unmapped_params = []
for hf_name in state_dict:
name, suffix = hf_name.rsplit(".", 1)
gguf_name = name_map.get_name(name)
gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)
# Track mapping success
if gguf_name_with_suffix is not None:
gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
elif hf_name not in gguf_to_hf_name_map.values():
# Parameter not in manual overrides either
unmapped_params.append(hf_name)
# All parameters must be mapped: both vision/projector and backbone
if unmapped_params:
raise RuntimeError(
f"Failed to map GGUF parameters "
f"({len(unmapped_params)}): "
f"{unmapped_params}"
)
return gguf_to_hf_name_map
def _get_gguf_weight_type(
self,
model_config: ModelConfig,
model_name_or_path: str,
gguf_to_hf_name_map: dict[str, str],
) -> dict[str, str]:
weight_type_map = get_gguf_weight_type_map(
model_config.model, gguf_to_hf_name_map
)
is_multimodal = hasattr(model_config.hf_config, "vision_config")
if is_multimodal:
mmproj_file = detect_gguf_multimodal(model_name_or_path)
assert mmproj_file is not None, (
"Could not find mm_proj file for multimodal GGUF model"
)
logger.info("Loading extra mm_proj weights from %s...", mmproj_file)
mm_proj_weight_type_map = get_gguf_weight_type_map(
mmproj_file, gguf_to_hf_name_map
)
weight_type_map.update(mm_proj_weight_type_map)
return weight_type_map
def _get_weights_iterator(
self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str]
self,
model_config: ModelConfig,
model_name_or_path: str,
gguf_to_hf_name_map: dict[str, str],
) -> Generator[tuple[str, torch.Tensor], None, None]:
return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
"""
Iterate over GGUF model weights, loading from both main model file and
mmproj.gguf for multimodal Gemma3 models.
For Gemma3 multimodal GGUF models:
- Main file (gemma-3-*.gguf): Language model weights (model.*)
- mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)
Yields:
Tuples of (parameter_name, tensor) for all model weights
"""
hf_config = model_config.hf_config
is_multimodal = hasattr(hf_config, "vision_config")
if is_multimodal:
# Load mm_proj (mm_encoder + projector) for multimodal weights
mmproj_file = detect_gguf_multimodal(model_name_or_path)
assert mmproj_file is not None, (
"Could not find mm_proj file for multimodal GGUF model"
)
yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model)
@ -141,7 +296,7 @@ class GGUFModelLoader(BaseModelLoader):
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map)
self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
)
def load_model(
@ -156,14 +311,19 @@ class GGUFModelLoader(BaseModelLoader):
):
model_config.hf_config.update({"tie_word_embeddings": True})
weight_type_map = get_gguf_weight_type_map(model_config.model, gguf_weights_map)
weight_type_map = self._get_gguf_weight_type(
model_config, local_model_path, gguf_weights_map
)
# filter out unquantized modules to skip
unquant_names = [
name.removesuffix(".weight")
for name, weight_type in weight_type_map.items()
if weight_type == "F32" and name.endswith(".weight")
if weight_type in ("F32", "F16", "BF16") and name.endswith(".weight")
]
logger.debug(
"GGUF unquantized modules: %s",
unquant_names,
)
vllm_config.quant_config.unquantized_modules.extend(unquant_names)
target_device = torch.device(device_config.device)
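As a standalone illustration of the name normalization performed by find_hf_name_in_tensor_map (the helper and the parameter names below are examples written for this note, not code from the diff), the prefix and suffix handling boils down to:

def normalize_hf_name(hf_name: str) -> tuple[str, str]:
    # "language_model." is stripped; a plain "model." prefix is kept,
    # since gguf-py's tensor maps expect it.
    if hf_name.startswith("language_model."):
        hf_name = hf_name[len("language_model."):]
    if hf_name.endswith((".weight", ".bias")):
        base, suffix = hf_name.rsplit(".", 1)
    else:
        base, suffix = hf_name, ""
    # Gemma3 projector parameters end in "_weight" rather than ".weight".
    if base.endswith("_weight"):
        base, suffix = base[: -len("_weight")], "weight"
    return base, suffix

# Text-backbone parameter: prefix stripped, suffix split off, then looked
# up in the text tensor name map.
print(normalize_hf_name("language_model.model.layers.0.mlp.gate_proj.weight"))
# -> ("model.layers.0.mlp.gate_proj", "weight")

# Gemma3 multimodal projector parameter with the "_weight" naming; the
# base name is then looked up in the vision/projector tensor name map.
print(normalize_hf_name("model.multi_modal_projector.mm_input_projection_weight"))
# -> ("model.multi_modal_projector.mm_input_projection", "weight")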

View File

@ -836,7 +836,11 @@ def gguf_quant_weights_iterator(
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""
Iterate over the quant weights in the model gguf files and convert
them to torch tensors
them to torch tensors.
Be careful of the order of yielding weight types and weight data:
all weight types must be yielded before any weight data. Otherwise,
loading weights for a packed layer whose shards use different quant
types would fail.
"""
reader = gguf.GGUFReader(gguf_file)
@ -846,7 +850,7 @@ def gguf_quant_weights_iterator(
weight_type = tensor.tensor_type
name = gguf_to_hf_name_map[tensor.name]
if weight_type.name != "F32":
if weight_type.name not in ("F32", "BF16", "F16"):
weight_type_name = name.replace("weight", "qweight_type")
weight_type = torch.tensor(weight_type)
yield weight_type_name, weight_type
@ -856,7 +860,7 @@ def gguf_quant_weights_iterator(
weight = tensor.data
weight_type = tensor.tensor_type
name = gguf_to_hf_name_map[tensor.name]
if weight_type.name != "F32":
if weight_type.name not in ("F32", "BF16", "F16"):
name = name.replace("weight", "qweight")
param = torch.tensor(weight)
yield name, param

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
from typing import Annotated, Any, Literal
import torch
from torch import nn
@ -20,12 +20,7 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
@ -76,15 +71,7 @@ class Gemma3ImagePixelInputs(TensorSchema):
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
class Gemma3ImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"] = "image_embeds"
image_embeds: Annotated[
torch.Tensor,
TensorShape("ni", "nf", "hs"),
]
Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
@ -191,9 +178,8 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_image_repl(
self,
*,
image_width: int | None,
image_height: int | None,
num_crops: int | None = None,
image_width: int,
image_height: int,
processor: Gemma3Processor | None,
) -> PromptUpdateDetails[str]:
if processor is None:
@ -201,13 +187,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
boi_token = processor.boi_token
if num_crops is None:
assert image_width is not None and image_height is not None
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
)
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
)
if num_crops == 0:
image_text = boi_token
@ -337,7 +321,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@ -350,19 +333,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_token = hf_processor.boi_token
def get_replacement_gemma3(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems)
)
if isinstance(images, ImageEmbeddingItems):
# For image embedding inputs, only support no crops cases
# since it's not supported in hf processor anyway
return self.info.get_image_repl(
image_width=None,
image_height=None,
num_crops=0,
processor=hf_processor,
)
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
return self.info.get_image_repl(
@ -586,19 +557,17 @@ class Gemma3ForConditionalGeneration(
pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
return None
if pixel_values is not None:
image_size = self.config.vision_config.image_size
return Gemma3ImagePixelInputs(
pixel_values=pixel_values,
num_patches=num_patches,
resolve_bindings={"h": image_size, "w": image_size},
)
elif image_embeds is not None:
return Gemma3ImageEmbeddingInputs(
image_embeds=image_embeds,
type="image_embeds",
)
image_size = self.config.vision_config.image_size
return Gemma3ImagePixelInputs(
pixel_values=pixel_values,
num_patches=num_patches,
resolve_bindings={"h": image_size, "w": image_size},
)
def _image_pixels_to_features(
self,
@ -610,9 +579,7 @@ class Gemma3ForConditionalGeneration(
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
) -> torch.Tensor | list[torch.Tensor]:
if image_input["type"] == "image_embeds":
return image_input["image_embeds"]
) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
@ -629,13 +596,33 @@ class Gemma3ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
return self._process_image_input(image_input)
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# Early return for text-only inference (no multimodal data)
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
# Use interface default with OOV handling enabled
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
input_ids: torch.Tensor,
@ -657,6 +644,79 @@ class Gemma3ForConditionalGeneration(
return hidden_states
def generate_attention_masks(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mask_dtype: torch.dtype,
) -> dict[str, Any]:
"""Generate custom attention masks for Gemma3 multimodal inputs.
This is called by V1 engine's gpu_model_runner during preprocessing
to generate attention masks that allow bidirectional attention between
image tokens while maintaining causal attention for text.
"""
# NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
# This is a HACK. Fix this.
start_indices = (positions == 0).cpu().nonzero()
num_seqs = len(start_indices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_indices[i]
end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
seq_lens.append(end_idx - start_idx)
global_attn_masks = []
local_attn_masks = []
start_idx = 0
for seq_idx, seq_len in enumerate(seq_lens):
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
# Find image token positions
img_pos = input_token_ids == self.config.image_token_index
start_idx = end_idx
# Create a global causal mask
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
# Fill the lower triangle with 0 (causal attention)
global_attn_mask = global_attn_mask.triu(diagonal=1)
# Enable bidirectional attention between image tokens
img_mask = torch.zeros_like(global_attn_mask)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)
# GGUF compatibility: config might be Gemma3TextConfig directly
text_config = getattr(self.config, "text_config", self.config)
sliding_window = text_config.sliding_window
if sliding_window is not None:
# Create a local causal mask with the configured sliding window
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(
local_attn_mask == 0, global_attn_mask, float("-inf")
)
local_attn_masks.append(local_attn_mask)
return {
"has_images": True,
"seq_lens": seq_lens,
"global_attn_masks": global_attn_masks,
"local_attn_masks": local_attn_masks,
}
def prepare_attn_masks(
self,
input_ids: torch.Tensor,

View File

@ -827,6 +827,7 @@ class SiglipVisionModel(nn.Module):
) -> None:
super().__init__()
self.quant_config = quant_config
self.vision_model = SiglipVisionTransformer(
config,
quant_config,
@ -911,12 +912,38 @@ class SiglipVisionModel(nn.Module):
break
else:
param = params_dict[name]
param = maybe_swap_ffn_param(
name, param, loaded_weight, params_dict, self.quant_config
)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def maybe_swap_ffn_param(
name: str,
param: torch.Tensor,
loaded_weight: torch.Tensor,
params_dict: dict[str, torch.Tensor],
quant_config: QuantizationConfig,
) -> torch.Tensor:
if not (quant_config and quant_config.get_name() == "gguf") or ".fc" not in name:
return param
# Some GGUF models have fc1 and fc2 weights swapped
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", 0)
output_size = param.size(output_dim) * tp_size
weight_out_size = loaded_weight.size(output_dim)
if ".fc1." in name and output_size != weight_out_size:
new_name = name.replace(".fc1.", ".fc2.")
param = params_dict[new_name]
elif ".fc2." in name and output_size != weight_out_size:
new_name = name.replace(".fc2.", ".fc1.")
param = params_dict[new_name]
return param
# Adapted from: https://github.com/huggingface/transformers/blob/v4.54.1/src/transformers/models/siglip/modeling_siglip.py#L200
class SiglipTextEmbeddings(nn.Module):
def __init__(self, config: SiglipTextConfig):

View File

@ -477,6 +477,17 @@ def is_interleaved(config: PretrainedConfig) -> bool:
return False
def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
"""Detect if model uses custom attention mask generation for multimodal.
Some multimodal models require custom attention masks that enable
bidirectional attention between image tokens while maintaining causal
attention for text tokens. Currently applies to Gemma3 multimodal models.
"""
architectures = getattr(config, "architectures", [])
return "Gemma3ForConditionalGeneration" in architectures
def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
"""
Update kwargs for AutoConfig initialization based on model_type

View File

@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""GGUF utility functions."""
from pathlib import Path
import gguf
from gguf.constants import Keys, VisionProjectorType
from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
def detect_gguf_multimodal(model: str) -> Path | None:
"""Check if GGUF model has multimodal projector file.
Args:
model: Model path string
Returns:
Path to mmproj file if found, None otherwise
"""
if not model.endswith(".gguf"):
return None
try:
model_path = Path(model)
if not model_path.is_file():
return None
model_dir = model_path.parent
mmproj_patterns = ["mmproj.gguf", "mmproj-*.gguf", "*mmproj*.gguf"]
for pattern in mmproj_patterns:
mmproj_files = list(model_dir.glob(pattern))
if mmproj_files:
return mmproj_files[0]
return None
except Exception:
return None
def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | None":
"""Extract vision config parameters from mmproj.gguf metadata.
Reads vision encoder configuration from GGUF metadata fields using
standardized GGUF constants. Automatically detects the projector type
(e.g., gemma3, llama4) and applies model-specific parameters accordingly.
The function extracts standard CLIP vision parameters from GGUF metadata
and applies projector-type-specific customizations. For unknown projector
types, it uses safe defaults from SiglipVisionConfig.
Args:
mmproj_path: Path to mmproj.gguf file (str or Path)
Returns:
SiglipVisionConfig if extraction succeeds, None if any required
field is missing from the GGUF metadata
Raises:
Exception: Exceptions from GGUF reading (file not found, corrupted
file, etc.) propagate directly from gguf.GGUFReader
"""
reader = gguf.GGUFReader(str(mmproj_path))
# Detect projector type to apply model-specific parameters
projector_type = None
projector_type_field = reader.get_field(Keys.Clip.PROJECTOR_TYPE)
if projector_type_field:
try:
projector_type = bytes(projector_type_field.parts[-1]).decode("utf-8")
except (AttributeError, UnicodeDecodeError) as e:
logger.warning("Failed to decode projector type from GGUF: %s", e)
# Map GGUF field constants to SiglipVisionConfig parameters.
# Uses official GGUF constants from gguf-py for standardization.
# Format: {gguf_constant: (param_name, dtype)}
VISION_CONFIG_FIELDS = {
Keys.ClipVision.EMBEDDING_LENGTH: ("hidden_size", int),
Keys.ClipVision.FEED_FORWARD_LENGTH: ("intermediate_size", int),
Keys.ClipVision.BLOCK_COUNT: ("num_hidden_layers", int),
Keys.ClipVision.Attention.HEAD_COUNT: ("num_attention_heads", int),
Keys.ClipVision.IMAGE_SIZE: ("image_size", int),
Keys.ClipVision.PATCH_SIZE: ("patch_size", int),
Keys.ClipVision.Attention.LAYERNORM_EPS: ("layer_norm_eps", float),
}
# Extract and validate all required fields
config_params = {}
for gguf_key, (param_name, dtype) in VISION_CONFIG_FIELDS.items():
field = reader.get_field(gguf_key)
if field is None:
logger.warning(
"Missing required vision config field '%s' in mmproj.gguf",
gguf_key,
)
return None
# Extract scalar value from GGUF field and convert to target type
config_params[param_name] = dtype(field.parts[-1])
# Apply model-specific parameters based on projector type
if projector_type == VisionProjectorType.GEMMA3:
# Gemma3 doesn't use the vision pooling head (multihead attention)
# This is a vLLM-specific parameter used in SiglipVisionTransformer
config_params["vision_use_head"] = False
logger.info("Detected Gemma3 projector, disabling vision pooling head")
# Add other projector-type-specific customizations here as needed
# elif projector_type == VisionProjectorType.LLAMA4:
# config_params["vision_use_head"] = ...
# Create config with extracted parameters
# Note: num_channels and attention_dropout use SiglipVisionConfig defaults
# (3 and 0.0 respectively) which are correct for all models
config = SiglipVisionConfig(**config_params)
if projector_type:
logger.info(
"Extracted vision config from mmproj.gguf (projector_type: %s)",
projector_type,
)
else:
logger.info("Extracted vision config from mmproj.gguf metadata")
return config
def maybe_patch_hf_config_from_gguf(
model: str,
hf_config: PretrainedConfig,
) -> PretrainedConfig:
"""Patch HF config for GGUF models.
Applies GGUF-specific patches to HuggingFace config:
1. For multimodal models: patches architecture and vision config
2. For all GGUF models: overrides vocab_size from embedding tensor
This ensures compatibility with GGUF models that have extended
vocabularies (e.g., Unsloth) where the GGUF file contains more
tokens than the HuggingFace tokenizer config specifies.
Args:
model: Model path string
hf_config: HuggingFace config to patch in-place
Returns:
Updated HuggingFace config
"""
# Patch multimodal config if mmproj.gguf exists
mmproj_path = detect_gguf_multimodal(model)
if mmproj_path is not None:
vision_config = extract_vision_config_from_gguf(str(mmproj_path))
# Create HF config for Gemma3 multimodal
text_config = hf_config.get_text_config()
is_gemma3 = hf_config.model_type in ("gemma3", "gemma3_text")
if vision_config is not None and is_gemma3:
new_hf_config = Gemma3Config.from_text_vision_configs(
text_config=text_config,
vision_config=vision_config,
architectures=["Gemma3ForConditionalGeneration"],
)
hf_config = new_hf_config
return hf_config
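A small usage sketch of the helpers in this file (the local paths are placeholders; the mmproj file must sit next to the GGUF backbone, as in the HF repo layout used by the tests):

from vllm.transformers_utils.gguf_utils import (
    detect_gguf_multimodal,
    extract_vision_config_from_gguf,
)

backbone = "/models/gemma-3-4b-it-q4_0.gguf"

mmproj = detect_gguf_multimodal(backbone)
if mmproj is not None:
    # Rebuild a SiglipVisionConfig from the mmproj metadata; returns None
    # if a required ClipVision field is missing.
    vision_config = extract_vision_config_from_gguf(str(mmproj))
    if vision_config is not None:
        print(vision_config.image_size, vision_config.num_hidden_layers)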

View File

@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar
from vllm.transformers_utils.utils import convert_model_repo_to_path
from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
@ -236,9 +236,20 @@ def cached_processor_from_config(
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any,
) -> _P:
if check_gguf_file(model_config.model):
assert not check_gguf_file(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load processor."
)
model = model_config.tokenizer
revision = model_config.tokenizer_revision
else:
model = model_config.model
revision = model_config.revision
return cached_get_processor_without_dynamic_kwargs(
model_config.model,
revision=model_config.revision,
model,
revision=revision,
trust_remote_code=model_config.trust_remote_code,
processor_cls=processor_cls, # type: ignore[arg-type]
**_merge_mm_kwargs(model_config, processor_cls, **kwargs),
@ -339,9 +350,19 @@ def cached_image_processor_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
if check_gguf_file(model_config.model):
assert not check_gguf_file(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load image processor."
)
model = model_config.tokenizer
revision = model_config.tokenizer_revision
else:
model = model_config.model
revision = model_config.revision
return cached_get_image_processor(
model_config.model,
revision=model_config.revision,
model,
revision=revision,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs),
)

View File

@ -27,6 +27,7 @@ def is_cloud_storage(model_or_path: str) -> bool:
return is_s3(model_or_path) or is_gcs(model_or_path)
@cache
def check_gguf_file(model: str | PathLike) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)

View File

@ -324,6 +324,7 @@ class GPUModelRunner(
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
@ -2346,6 +2347,24 @@ class GPUModelRunner(
**self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
}
# Generate custom attention masks for models that require them.
# V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
# Check mm_features (mm_embeds is empty during decode).
has_mm_features = any(
req_state.mm_features for req_state in self.requests.values()
)
if (
self.uses_custom_attention_masks
and has_mm_features
and hasattr(self.model, "generate_attention_masks")
):
mask_kwargs = self.model.generate_attention_masks(
self.input_ids.gpu[:num_scheduled_tokens],
self.positions.gpu[:num_scheduled_tokens],
mask_dtype=self.model.dtype,
)
model_kwargs.update(mask_kwargs)
elif self.enable_prompt_embeds and is_first_rank:
# Get the input embeddings for the tokens that are not input embeds,
# then put them into the appropriate positions.