From e30859dff3d93bd3e289f6e996afbb59ac475b72 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Wed, 26 Nov 2025 21:00:15 +0800
Subject: [PATCH] [Bugfix] Fix handling of image embeds in models (#29480)

Signed-off-by: DarkLight1337
---
 vllm/model_executor/models/deepseek_vl2.py    | 15 ++-------------
 vllm/model_executor/models/llava_next.py      |  2 +-
 vllm/model_executor/models/llava_onevision.py |  2 +-
 3 files changed, 4 insertions(+), 15 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index e7b48e0f4e554..1b6e4110039c4 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -48,7 +48,6 @@ from vllm.transformers_utils.configs.deepseek_vl2 import (
 )
 from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
-from vllm.utils.collection_utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.utils.torch_utils import set_default_torch_dtype
 
@@ -595,19 +594,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def _process_image_input(
         self, image_input: DeepseekVL2ImageInputs
-    ) -> list[torch.Tensor]:
+    ) -> torch.Tensor | list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
-            image_data = image_input["data"]
-            if is_list_of(image_data, torch.Tensor):
-                # it's already a list of tensors
-                return image_data
-            if len(image_data.shape) == 3:
-                # 3D tensor
-                return list(torch.unbind(image_data, dim=0))
-            raise ValueError(
-                "We expect batched 2D tensors; "
-                "this can be either a list of 2D tensors or a single 3D tensor."
-            )
+            return image_input["data"]
 
         pixel_values = image_input["data"]
         images_spatial_crop = image_input["images_spatial_crop"]
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 98b1b46045c3d..b995cac47ac1c 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -460,7 +460,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
         image_input: LlavaNextImageInputs,
     ) -> torch.Tensor | list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
-            return [image_input["data"]]
+            return image_input["data"]
 
         patch_embeddings = self._process_image_pixels(image_input)
 
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 322bde94ff66d..4e243ade68358 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -763,7 +763,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
         image_input: LlavaOnevisionImageInputs,
     ) -> torch.Tensor | list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
-            return [image_input["data"]]
+            return image_input["data"]
 
         patch_embeddings = self._process_image_pixels(image_input)
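
Note (not part of the patch): the change above makes the "image_embeds" branch return the
user-supplied embeddings unchanged, so the return type becomes
`torch.Tensor | list[torch.Tensor]`. Below is a minimal standalone sketch of why that is
sufficient, assuming downstream code can consume either a single batched 3D tensor or a list of
per-image 2D tensors. `flatten_embeds` and the toy shapes are hypothetical and exist only for
this illustration; they are not vLLM APIs.

    # Minimal sketch (assumption: consumers accept either embedding layout).
    import torch


    def flatten_embeds(
        embeds: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor:
        """Flatten image embeddings into one (num_tokens, hidden_size) tensor.

        Accepts either a 3D tensor of shape (num_images, tokens, hidden_size)
        or a list of 2D tensors, each of shape (tokens_i, hidden_size).
        """
        if isinstance(embeds, torch.Tensor):
            # (num_images, tokens, hidden) -> (num_images * tokens, hidden)
            return embeds.flatten(0, 1)
        # Per-image tensors may have different token counts; concatenate them.
        return torch.cat(embeds, dim=0)


    # Both input forms yield the same flattened layout:
    batched = torch.randn(2, 4, 8)           # 2 images, 4 tokens each, hidden=8
    per_image = list(torch.unbind(batched))  # equivalent list of 2D tensors

    assert torch.equal(flatten_embeds(batched), flatten_embeds(per_image))

Because either layout can be flattened the same way by the caller, wrapping the payload in an
extra list (as the old LLaVA-NeXT/OneVision code did) or eagerly unbinding a 3D tensor (as the
old DeepSeek-VL2 code did) is unnecessary, and passing `image_input["data"]` through directly
handles both cases.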