[Bugfix] Fix gemma3 with transformers backend (#23178)
Signed-off-by: raushan <raushan@huggingface.co>
Signed-off-by: Raushan Turganbay <raushan@huggingface.co>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent c02058c222
commit 7cd95dc8a3
@@ -193,6 +193,20 @@ VLM_TEST_SETTINGS = {
         # when processing the 3rd prompt in vLLM
         marks=[pytest.mark.core_model, pytest.mark.skip(reason="Test hangs")],
     ),
+    # Gemma3 has bidirectional mask on images
+    "gemma3-transformers": VLMTestInfo(
+        models=["google/gemma-3-4b-it"],
+        test_type=VLMTestType.IMAGE,
+        prompt_formatter=lambda vid_prompt: f"<bos><start_of_turn>user\n{vid_prompt}<start_of_image><end_of_turn>\n<start_of_turn>model\n",  # noqa: E501
+        max_model_len=4096,
+        auto_cls=AutoModelForImageTextToText,
+        vllm_output_post_proc=model_utils.gemma3_vllm_to_hf_output,
+        image_size_factors=[(0.25, 0.5, 1.0)],
+        vllm_runner_kwargs={
+            "model_impl": "transformers",
+        },
+        marks=[pytest.mark.core_model],
+    ),
     "idefics3-transformers": VLMTestInfo(
         models=["HuggingFaceTB/SmolVLM-256M-Instruct"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
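
For reference, a minimal standalone sketch of what this new test exercises: Gemma3 served through vLLM's Transformers modeling backend (assumes a local GPU and access to the gated google/gemma-3-4b-it weights; the prompt string mirrors the formatter above):

    from vllm import LLM, SamplingParams

    # model_impl="transformers" selects the Transformers fallback backend,
    # mirroring vllm_runner_kwargs in the test settings above.
    llm = LLM(
        model="google/gemma-3-4b-it",
        model_impl="transformers",
        max_model_len=4096,
    )

    prompt = (
        "<bos><start_of_turn>user\nDescribe this image."
        "<start_of_image><end_of_turn>\n<start_of_turn>model\n"
    )
    # Images are attached via multi_modal_data, e.g.
    # llm.generate({"prompt": prompt, "multi_modal_data": {"image": pil_image}})
    outputs = llm.generate(prompt, SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)
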
@@ -342,6 +342,29 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     return hf_model
 
 
+def gemma3_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput:
+    """Sanitize vllm output [gemma-3] to compare with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_id
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    eos_token_id = tokenizer.eos_token_id
+
+    hf_output_ids = [
+        token_id
+        for idx, token_id in enumerate(output_ids)
+        if token_id != image_token_id
+    ]
+
+    hf_output_str = output_str
+    if hf_output_ids[-1] == eos_token_id:
+        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
+
+    return hf_output_ids, hf_output_str, out_logprobs
+
+
 def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for GLM4V."""
     hf_processor = hf_model.processor
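
The sanitizer above only strips image placeholder tokens and, when the sequence ends with EOS, appends the decoded EOS string so the text matches Hugging Face's output. A toy illustration of the same logic with hypothetical token IDs (not Gemma3's real vocabulary):

    # Hypothetical IDs, for illustration only.
    image_token_id, eos_token_id = 99, 1
    output_ids = [99, 99, 5, 7, 11, 1]  # vLLM output keeps image placeholders

    hf_output_ids = [t for t in output_ids if t != image_token_id]
    assert hf_output_ids == [5, 7, 11, 1]
    # Since the last token here is EOS, the decoded EOS text would be
    # appended to output_str before comparing against the HF runner.
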
@@ -68,13 +68,7 @@ from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingIn
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import (
-    MultiModalEmbeddings,
-    SupportsLoRA,
-    SupportsMultiModal,
-    SupportsPP,
-    SupportsQuant,
-)
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -534,7 +528,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         self.attention_instances = self.create_attention_instances()
 
         # Input embeddings
-        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
+        input_embeddings = self.model.get_input_embeddings()
+        if not isinstance(input_embeddings, PPMissingLayer):
+            # Some models use embedding scales
+            self.embed_scale = getattr(input_embeddings, "embed_scale", None)
             names = ("embedding_size", "hidden_size")
             embedding_dim = getattr_iter(self.text_config, names, None)
             assert embedding_dim is not None
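
Gemma-style models scale their word embeddings (by sqrt(hidden_size) in the reference implementation), and Transformers exposes that factor as an embed_scale attribute on the embedding module; caching it here lets get_input_embeddings re-apply the scale later. A sketch of the detection with a stand-in module (the class is illustrative; only the attribute name comes from the diff):

    import torch
    import torch.nn as nn

    class ScaledWordEmbedding(nn.Embedding):
        """Stand-in for a Gemma-style embedding that carries its scale."""

        def __init__(self, vocab: int, dim: int):
            super().__init__(vocab, dim)
            self.embed_scale = torch.tensor(dim**0.5)

    input_embeddings = ScaledWordEmbedding(1000, 64)
    # Exactly what the hunk above does: None for models without a scale.
    embed_scale = getattr(input_embeddings, "embed_scale", None)
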
@@ -671,6 +668,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
+        logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None)
         start, end = get_pp_indices(
             self.text_config.num_hidden_layers,
             self.pp_group.rank_in_group,
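
attn_logit_softcapping is the soft cap on attention scores used by e.g. Gemma2, scores = cap * tanh(scores / cap); reading it from text_config here and forwarding it to vLLM's Attention layer (next hunk) keeps the Transformers backend from silently ignoring the cap. A sketch of the formula:

    import torch

    def soft_cap(scores: torch.Tensor, cap: float) -> torch.Tensor:
        # Roughly linear near zero, saturates smoothly at +/- cap.
        return cap * torch.tanh(scores / cap)

    print(soft_cap(torch.tensor([0.5, 30.0, 120.0]), cap=50.0))
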
@@ -696,6 +694,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
+                logits_soft_cap=logits_soft_cap,
                 per_layer_sliding_window=per_layer_sliding_window,
                 prefix=f"{i}.attn",
                 attn_type=attn_type,
@@ -735,6 +734,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if not self.pp_group.is_first_rank:
             assert intermediate_tensors is not None
@@ -758,6 +758,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             position_ids=position_ids,
             attention_instances=self.attention_instances,
             return_dict=False,
+            **kwargs,
         )[0][0, ...]  # we remove batch dimension for now
 
         if not self.pp_group.is_last_rank:
@@ -819,7 +820,10 @@ class TransformersForCausalLM(TransformersBase):
             self.lm_head = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings()(input_ids)
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        if self.embed_scale is not None:
+            inputs_embeds *= self.embed_scale
+        return inputs_embeds
 
     def compute_logits(
         self,
@@ -845,6 +849,7 @@ class TransformersForCausalLM(TransformersBase):
     enable_if=can_enable_torch_compile,
 )
 class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
+    supports_multimodal_raw_input_only = True
     merge_by_field_config = True
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
@@ -883,13 +888,27 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
+        # Gemma3 and PaliGemma need `token_type_ids` to work correctly.
+        # Other models will not have `token_type_ids` in kwargs.
+        kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
         model_output = super().forward(
-            input_ids, positions, intermediate_tensors, inputs_embeds
+            input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
         )
         return model_output
 
     def get_language_model(self) -> torch.nn.Module:
-        return self.model
+        """`TransformersForMultimodalLM` does not contain a vLLM language model class.
+        Therefore, in order to return a language model vLLM class, we use a wrapper to
+        give `self` the same interface as `TransformersForCausalLM`."""
+
+        class LanguageModelWrapper(TransformersForCausalLM):
+            def __init__(self, multimodal_model):
+                # Don't call super().__init__() to avoid re-initialization
+                self.__dict__.update(multimodal_model.__dict__)
+
+            model = getattr_iter(self.model, ("language_model", "text_model"), None)
+
+        return LanguageModelWrapper(self)
 
     def get_multimodal_embeddings(self, **kwargs):
         pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
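
The forward override whitelists token_type_ids (which Gemma3 and PaliGemma need to build the bidirectional image mask) and drops every other extra kwarg before delegating to the causal-LM forward. A toy illustration of the filter:

    kwargs = {"token_type_ids": [0, 0, 1, 1], "num_image_patches": 256}
    kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
    assert kwargs == {"token_type_ids": [0, 0, 1, 1]}
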
@@ -905,6 +924,7 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
             return None
 
         num_image_patches = kwargs.pop("num_image_patches")
+        kwargs.pop("token_type_ids", None)  # used only in `forward`
         if pixel_values is not None:
             vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
 
@@ -925,46 +945,4 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
 
         return vision_embeddings
 
-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        *,
-        is_multimodal: Optional[torch.Tensor] = None,
-        handle_oov_mm_token: bool = False,
-    ) -> torch.Tensor:
-        """
-        Apply token embeddings to `input_ids`.
-
-        If `multimodal_embeddings` is passed, scatter them into
-        `input_ids` according to the mask `is_multimodal`.
-
-        In case the multi-modal token IDs exceed the vocabulary size of
-        the language model, you can set `handle_oov_mm_token=False`
-        to avoid calling the language model's `get_input_embeddings` method
-        on those tokens.
-        """
-        from .utils import _merge_multimodal_embeddings
-
-        inputs_embeds = self._get_text_embeddings(
-            input_ids,
-            self.model.get_input_embeddings(),
-            is_multimodal=is_multimodal,
-            handle_oov_mm_token=handle_oov_mm_token,
-        )
-
-        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
-            return inputs_embeds
-
-        if is_multimodal is None:
-            raise ValueError(
-                "`get_input_embeddings` now requires `is_multimodal` arg, "
-                "please update your model runner according to "
-                "https://github.com/vllm-project/vllm/pull/16229."
-            )
-
-        return _merge_multimodal_embeddings(
-            inputs_embeds=inputs_embeds,
-            multimodal_embeddings=multimodal_embeddings,
-            is_multimodal=is_multimodal,
-        )
+    get_input_embeddings = SupportsMultiModal.get_input_embeddings
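
Rather than duplicating the scatter/merge logic, the multimodal class now points get_input_embeddings at the shared SupportsMultiModal implementation. The pattern, reduced to its essentials (toy classes, not vLLM's real interfaces):

    class Base:
        def greet(self):
            return "base"

    class Mixin:
        def greet(self):
            return "mixin default"

    class Impl(Base, Mixin):
        # Explicitly select the mixin's method over Base's version, just
        # like `get_input_embeddings = SupportsMultiModal.get_input_embeddings`.
        greet = Mixin.greet

    assert Impl().greet() == "mixin default"
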
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
-from .interfaces import MixtureOfExperts
+from .interfaces import MixtureOfExperts, SupportsMultiModal
 from .transformers import (
     TransformersBase,
     TransformersForCausalLM,
@@ -335,7 +335,5 @@ class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM):
     },
     enable_if=can_enable_torch_compile,
 )
-class TransformersMoEForMultimodalLM(
-    TransformersMoEForCausalLM, TransformersForMultimodalLM
-):
-    pass
+class TransformersMoEForMultimodalLM(TransformersMoEBase, TransformersForMultimodalLM):
+    get_input_embeddings = SupportsMultiModal.get_input_embeddings