From ab656f2c2feb834a9d3d90456f684200edd2abd8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 18 Mar 2025 20:54:40 +0800 Subject: [PATCH] [Bugfix] Loosen type check to avoid errors in V1 (#15021) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/blip2.py | 12 +++++------- vllm/model_executor/models/chameleon.py | 7 +++---- vllm/model_executor/models/deepseek_vl2.py | 2 +- vllm/model_executor/models/glm4v.py | 2 +- vllm/model_executor/models/internvl.py | 6 ++++-- vllm/model_executor/models/llava_next_video.py | 15 ++++++--------- vllm/model_executor/models/llava_onevision.py | 5 +---- vllm/model_executor/models/paligemma.py | 10 ++++------ vllm/model_executor/models/qwen_vl.py | 6 +++--- 9 files changed, 28 insertions(+), 37 deletions(-) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 47362e3d89763..7adca4f0dc868 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -25,7 +25,7 @@ from vllm.sequence import IntermediateTensors from .blip import BlipVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, init_vllm_registered_model, +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) # We use this internally as placeholders since there is no image token @@ -565,12 +565,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): return None if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # Remove the N dimension until multiple images are supported. 
- pixel_values = pixel_values.squeeze(1) + pixel_values = flatten_bn(pixel_values, concat=True) return Blip2ImagePixelInputs( type="pixel_values", @@ -578,12 +577,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): + if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") - # Remove the N dimension until multiple images are supported. - image_embeds = image_embeds.squeeze(1) + image_embeds = flatten_bn(image_embeds, concat=True) return Blip2ImageEmbeddingInputs( type="image_embeds", diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 66bf85b59d1e2..ebcd36148e073 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -39,7 +39,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) @@ -972,12 +972,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, if pixel_values is None: return None - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # Remove the N dimension until multiple images are supported. 
- pixel_values = pixel_values.squeeze(1) + pixel_values = flatten_bn(pixel_values, concat=True) return ChameleonImagePixelInputs( type="pixel_values", diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 6ea8de8450bc7..0faf895964bb6 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -478,7 +478,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): flatten_bn(images_spatial_crop, concat=True))) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): + if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 9889b7e4de40a..c190a45855919 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -578,7 +578,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, pixel_values = kwargs.pop("pixel_values", None) if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e91d0ba1b382a..d31b623b5bc71 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -838,7 +838,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): return None if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): + if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of image embeddings. 
" f"Got type: {type(image_embeds)}") @@ -856,7 +856,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values_flat)}") - assert isinstance(image_num_patches, (torch.Tensor, list)) + if not isinstance(image_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. " + f"Got type: {type(image_num_patches)}") return InternVLImagePixelInputs( type="pixel_values", diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 5eb56d6711f3b..8b1a8c9da6804 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -349,21 +349,18 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, List[b, Tensor(nb_frames, nb_channels, height, width)] } """ - pixel_values = kwargs.pop("pixel_values_videos", None) + pixel_values_videos = kwargs.pop("pixel_values_videos", None) - if pixel_values is None: + if pixel_values_videos is None: return None - if not (is_list_of(pixel_values, - (torch.Tensor)) # different shape videos - or isinstance(pixel_values, - torch.Tensor)): # same shape videos - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + if not isinstance(pixel_values_videos, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel_values_videos. 
" + f"Got type: {type(pixel_values_videos)}") return LlavaNextVideoPixelInputs( type="pixel_values_videos", - data=pixel_values, + data=pixel_values_videos, ) def _select_image_features(self, image_features: torch.Tensor, *, diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 52ec0abcdc5b5..6a2328f950b84 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -574,10 +574,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, if pixel_values_videos is None: return None - if not (is_list_of(pixel_values_videos, - torch.Tensor) # different shape videos - or isinstance(pixel_values_videos, - torch.Tensor)): # same shape videos + if not isinstance(pixel_values_videos, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel_values_videos. " f"Got type: {type(pixel_values_videos)}") diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8a773607ce4ed..6fedb8c819849 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -23,7 +23,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, init_vllm_registered_model, +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) from .vision import get_vision_encoder_info @@ -270,12 +270,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # Remove the N dimension until multiple images are supported. 
- pixel_values = pixel_values.squeeze(1) + pixel_values = flatten_bn(pixel_values, concat=True) return PaliGemmaImagePixelInputs( type="pixel_values", @@ -287,8 +286,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") - # Remove the N dimension until multiple images are supported. - image_embeds = image_embeds.squeeze(1) + image_embeds = flatten_bn(image_embeds, concat=True) return PaliGemmaImageEmbeddingInputs( type="image_embeds", diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 1a39d2e74b1ee..4e9d02ae0abdb 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -711,7 +711,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, image_embeds = kwargs.pop("image_embeds", None) if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") @@ -722,13 +722,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): + if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") return QwenImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=flatten_bn(image_embeds, concat=True), ) return None