[Bugfix] Fix handling of image embeds in models (#29480)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit e30859dff3
parent 452a7c9f7c
@@ -48,7 +48,6 @@ from vllm.transformers_utils.configs.deepseek_vl2 import (
 )
 from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
-from vllm.utils.collection_utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.utils.torch_utils import set_default_torch_dtype
 
@@ -595,19 +594,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def _process_image_input(
         self, image_input: DeepseekVL2ImageInputs
-    ) -> list[torch.Tensor]:
+    ) -> torch.Tensor | list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
-            image_data = image_input["data"]
-            if is_list_of(image_data, torch.Tensor):
-                # it's already a list of tensors
-                return image_data
-            if len(image_data.shape) == 3:
-                # 3D tensor
-                return list(torch.unbind(image_data, dim=0))
-            raise ValueError(
-                "We expect batched 2D tensors; "
-                "this can be either a list of 2D tensors or a single 3D tensor."
-            )
+            return image_input["data"]
 
         pixel_values = image_input["data"]
         images_spatial_crop = image_input["images_spatial_crop"]
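
Note: a minimal standalone sketch of what this hunk changes for DeepseekVL2. The helper names and tensor shapes below are hypothetical, not part of the commit: the old code coerced image_embeds data into a list of 2D tensors and raised on anything else, while the new code returns the data as-is, matching the widened torch.Tensor | list[torch.Tensor] return type.

import torch

# Toy stand-ins (hypothetical shapes) for precomputed image embeddings.
embeds_3d = torch.randn(2, 16, 32)                     # 2 images x 16 tokens x hidden 32
embeds_list = [torch.randn(16, 32) for _ in range(2)]  # one 2D tensor per image

def process_old(data):
    # Old behavior: normalize to a list of 2D tensors, reject anything else.
    if isinstance(data, list):
        return data
    if data.ndim == 3:
        return list(torch.unbind(data, dim=0))
    raise ValueError("expected a list of 2D tensors or a single 3D tensor")

def process_new(data):
    # New behavior: pass the embeddings through untouched; callers now
    # accept either a tensor or a list of tensors.
    return data

assert isinstance(process_old(embeds_3d), list)  # old path unbinds the batch
assert process_new(embeds_3d) is embeds_3d       # new path is a pure pass-through
assert process_new(embeds_list) is embeds_list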
@@ -460,7 +460,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         image_input: LlavaNextImageInputs,
     ) -> torch.Tensor | list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
-            return [image_input["data"]]
+            return image_input["data"]
 
         patch_embeddings = self._process_image_pixels(image_input)
 
@@ -763,7 +763,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         image_input: LlavaOnevisionImageInputs,
     ) -> torch.Tensor | list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
-            return [image_input["data"]]
+            return image_input["data"]
 
         patch_embeddings = self._process_image_pixels(image_input)
 
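
Note: the LlavaNext and LlavaOnevision hunks above make the same one-line fix. A hedged sketch of why the extra list wrapping mattered, assuming image_embeds data arrives as a list with one tensor per image (the shapes and helper names below are hypothetical):

import torch

# Hypothetical request with 2 images, each mapped to 576 embedding tokens.
per_image = [torch.randn(576, 4096), torch.randn(576, 4096)]

def process_old(data):
    # Old behavior: always wrap in a one-element list, so a multi-image
    # request collapses into a single entry.
    return [data]

def process_new(data):
    # New behavior: return the data directly, keeping one entry per image.
    return data

assert len(process_old(per_image)) == 1  # 2 images mis-grouped as one item
assert len(process_new(per_image)) == 2  # one embedding tensor per image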