mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 09:51:20 +08:00
[Bugfix] Fix mrope in Transformers Backend (#26087)
Signed-off-by: raushan <raushan@huggingface.co> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
0340f45553
commit
ab5e7d93f4
@ -222,8 +222,7 @@ VLM_TEST_SETTINGS = {
|
|||||||
vllm_runner_kwargs={
|
vllm_runner_kwargs={
|
||||||
"model_impl": "transformers",
|
"model_impl": "transformers",
|
||||||
},
|
},
|
||||||
# FIXME: Investigate mrope issue
|
marks=[large_gpu_mark(min_gb=32)],
|
||||||
marks=[large_gpu_mark(min_gb=32), pytest.mark.skip(reason="Mrope issue")],
|
|
||||||
),
|
),
|
||||||
#### Extended model tests
|
#### Extended model tests
|
||||||
"aria": VLMTestInfo(
|
"aria": VLMTestInfo(
|
||||||
|
|||||||
@ -79,7 +79,6 @@ from .utils import (
|
|||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
PPMissingLayer,
|
PPMissingLayer,
|
||||||
WeightsMapper,
|
WeightsMapper,
|
||||||
flatten_bn,
|
|
||||||
make_empty_intermediate_tensors_factory,
|
make_empty_intermediate_tensors_factory,
|
||||||
maybe_prefix,
|
maybe_prefix,
|
||||||
)
|
)
|
||||||
@ -347,12 +346,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
|||||||
|
|
||||||
def _get_mm_fields_config(
|
def _get_mm_fields_config(
|
||||||
self,
|
self,
|
||||||
hf_inputs,
|
hf_inputs: BatchFeature,
|
||||||
hf_processor_mm_kwargs,
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
num_image_patches: torch.Tensor = None,
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
):
|
|
||||||
# HF Processors always return a mask but vLLM doesn't need it
|
# HF Processors always return a mask but vLLM doesn't need it
|
||||||
hf_inputs.pop("attention_mask", None)
|
hf_inputs.pop("attention_mask", None)
|
||||||
|
num_image_patches = hf_inputs.get("num_image_patches")
|
||||||
mm_fields = {
|
mm_fields = {
|
||||||
key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
|
key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
|
||||||
for key in hf_inputs
|
for key in hf_inputs
|
||||||
@ -360,41 +359,24 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
|||||||
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
|
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
|
||||||
"image", num_image_patches
|
"image", num_image_patches
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Keep these as batched, as they always have batch size as first dim
|
||||||
|
mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
|
||||||
|
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
|
||||||
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
|
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
|
||||||
return mm_fields
|
return mm_fields
|
||||||
|
|
||||||
def _apply_hf_processor_text_mm(
|
def _get_hf_mm_data(
|
||||||
self,
|
self,
|
||||||
prompt_text: str,
|
|
||||||
mm_items: MultiModalDataItems,
|
mm_items: MultiModalDataItems,
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
) -> tuple[Mapping[str, object], Mapping[str, object]]:
|
||||||
tokenization_kwargs: Mapping[str, object],
|
|
||||||
) -> tuple[list[int], BatchFeature, bool]:
|
|
||||||
"""
|
"""
|
||||||
Apply the HF processor on the prompt text and multi-modal data
|
In contrast to the base class, this method always adds
|
||||||
together.
|
`return_mm_token_type_ids` to the processor data
|
||||||
|
|
||||||
In addition, return whether prompt replacements have been applied.
|
|
||||||
"""
|
"""
|
||||||
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
|
processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
|
||||||
processor_data["return_mm_token_type_ids"] = True
|
processor_data["return_mm_token_type_ids"] = True
|
||||||
|
return processor_data, passthrough_data
|
||||||
processed_data = self._call_hf_processor(
|
|
||||||
prompt=prompt_text,
|
|
||||||
mm_data=processor_data,
|
|
||||||
mm_kwargs=hf_processor_mm_kwargs,
|
|
||||||
tok_kwargs=tokenization_kwargs,
|
|
||||||
)
|
|
||||||
processed_data.update(passthrough_data)
|
|
||||||
|
|
||||||
(prompt_ids,) = processed_data.pop("input_ids").tolist()
|
|
||||||
mm_token_type_ids = (
|
|
||||||
processed_data.pop("mm_token_type_ids")
|
|
||||||
if "mm_token_type_ids" in processed_data
|
|
||||||
else processed_data.pop("token_type_ids")
|
|
||||||
) # for gemma3 only
|
|
||||||
|
|
||||||
return prompt_ids, processed_data, mm_token_type_ids
|
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@ -421,18 +403,28 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
|||||||
# into string
|
# into string
|
||||||
prompt = hf_processor.decode(prompt)
|
prompt = hf_processor.decode(prompt)
|
||||||
|
|
||||||
(prompt_ids, processed_data, mm_token_type_ids) = (
|
# Bypass cached processor and always apply to the full set of mm inputs
|
||||||
self._apply_hf_processor_text_mm(
|
# NOTE: we can't just set caching=False because base class method
|
||||||
prompt_text=prompt,
|
# transforms outputs to `MultiModalKwargs` which is not going to
|
||||||
mm_items=mm_items,
|
# work for Transformers. We have a lot of logic tied to
|
||||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
# `mm_tokens_per_modality` below
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
|
||||||
)
|
prompt_text=prompt,
|
||||||
|
mm_items=mm_items,
|
||||||
|
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# HF processor will return `mm_token_type_ids` from which
|
# For gemma3 we check `token_type_ids` as the key
|
||||||
# we can infer mm_placeholders. Until then hardcode to make code run
|
token_type_key = (
|
||||||
# Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
|
"mm_token_type_ids"
|
||||||
|
if "mm_token_type_ids" in processed_data
|
||||||
|
else "token_type_ids"
|
||||||
|
)
|
||||||
|
mm_token_type_ids = processed_data.pop(token_type_key)
|
||||||
|
|
||||||
|
# We can infer vLLM style placeholder from token type ids, if we split
|
||||||
|
# it for each input `mm_data`.
|
||||||
mm_positions = torch.where(mm_token_type_ids == 1)[1]
|
mm_positions = torch.where(mm_token_type_ids == 1)[1]
|
||||||
images = mm_items.get_items("image", ImageProcessorItems)
|
images = mm_items.get_items("image", ImageProcessorItems)
|
||||||
multimodal_config = self.info.ctx.model_config.multimodal_config
|
multimodal_config = self.info.ctx.model_config.multimodal_config
|
||||||
@ -462,17 +454,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
|||||||
]
|
]
|
||||||
mm_placeholders = {"image": ranges}
|
mm_placeholders = {"image": ranges}
|
||||||
|
|
||||||
num_image_patches = (
|
processed_data["num_image_patches"] = torch.tensor(
|
||||||
torch.tensor(mm_tokens_per_modality["num_image_patches"])
|
mm_tokens_per_modality["num_image_patches"]
|
||||||
if "num_image_patches" in mm_tokens_per_modality
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
processed_data["num_image_patches"] = num_image_patches
|
|
||||||
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||||
processed_data,
|
processed_data,
|
||||||
self._get_mm_fields_config(
|
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
|
||||||
processed_data, hf_processor_mm_kwargs, num_image_patches
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use overrides if provided; fallback to data-dependent hashing.
|
# Use overrides if provided; fallback to data-dependent hashing.
|
||||||
@ -531,8 +518,6 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
self.ignore_unexpected_suffixes.append(".bias")
|
self.ignore_unexpected_suffixes.append(".bias")
|
||||||
|
|
||||||
# Set correct attn and init on "meta" to delay allocating GPU tensors
|
# Set correct attn and init on "meta" to delay allocating GPU tensors
|
||||||
# TODO: @raushan, use the public `model.set_attn_implementation()`
|
|
||||||
# method once its checks are fixed in Transformers.
|
|
||||||
self.text_config._attn_implementation = "vllm"
|
self.text_config._attn_implementation = "vllm"
|
||||||
with init_on_device_without_buffers("meta"):
|
with init_on_device_without_buffers("meta"):
|
||||||
self.model: PreTrainedModel = AutoModel.from_config(
|
self.model: PreTrainedModel = AutoModel.from_config(
|
||||||
@ -844,17 +829,6 @@ class TransformersForCausalLM(TransformersBase):
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
|
|
||||||
"""Flatten until a list of tensors can be concatenated then do concat"""
|
|
||||||
|
|
||||||
def _can_concat(x: list[torch.Tensor]):
|
|
||||||
return len(set(map(lambda _x: _x.shape[1:], x))) == 1
|
|
||||||
|
|
||||||
if _can_concat(x):
|
|
||||||
return torch.concat(x)
|
|
||||||
return flatten_and_concat(flatten_bn(x))
|
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
MultiModalProcessor,
|
MultiModalProcessor,
|
||||||
info=MultiModalProcessingInfo,
|
info=MultiModalProcessingInfo,
|
||||||
@ -935,9 +909,6 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
|
|||||||
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
|
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
|
||||||
|
|
||||||
if isinstance(vision_embeddings, torch.Tensor):
|
if isinstance(vision_embeddings, torch.Tensor):
|
||||||
if isinstance(num_image_patches, list):
|
|
||||||
num_image_patches = torch.cat(num_image_patches)
|
|
||||||
|
|
||||||
if vision_embeddings.ndim == 2:
|
if vision_embeddings.ndim == 2:
|
||||||
vision_embeddings = vision_embeddings.unsqueeze(0)
|
vision_embeddings = vision_embeddings.unsqueeze(0)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user