From 57f94e88ea1ed2e48ea8ea9b01e9591f0d79557b Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Fri, 12 Sep 2025 16:37:37 +0100 Subject: [PATCH] [Models] Optimise and simplify `_validate_and_reshape_mm_tensor` (#24742) Signed-off-by: Lukas Geiger --- vllm/model_executor/models/ernie45_vl.py | 2 +- vllm/model_executor/models/glm4_1v.py | 2 +- vllm/model_executor/models/keye.py | 4 ++-- vllm/model_executor/models/keye_vl1_5.py | 4 ++-- vllm/model_executor/models/midashenglm.py | 2 +- vllm/model_executor/models/qwen2_5_omni_thinker.py | 2 ++ vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/model_executor/models/qwen2_audio.py | 2 +- vllm/model_executor/models/qwen2_vl.py | 2 +- 9 files changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index bcff65a717ab..f49dd36b0ab9 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -1338,7 +1338,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 869287fc0268..4ed07bd060cf 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -1378,7 +1378,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 04824db1b6dd..cb4cd60a8917 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1611,12 +1611,12 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) elif is_list_of(mm_input, torch.Tensor): if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2 for p in mm_input): return mm_input - return torch.concat(list(mm_input)) + return torch.concat(mm_input) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[KeyeImageInputs]: diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 605c6d3eaf64..b3c44d132435 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -491,14 +491,14 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, if mm_input.ndim == expected_dim: return mm_input elif mm_input.ndim == expected_dim + 1: - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, *mm_input.shape[2:]) else: raise ValueError( f"{name} should be {expected_dim}D or " f"batched {expected_dim}D tensor." f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})") else: - return torch.concat(list(mm_input)) + return torch.concat(mm_input) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]: diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 858d4e7e34cf..e314ae357ecd 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -669,7 +669,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, *mm_input.shape[2:]) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index f8a943d4cab3..a7e71309b607 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -551,6 +551,8 @@ class Qwen2_5OmniConditionalGenerationMixin: raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): + if dim == 0: + return mm_input.reshape(-1, *mm_input.shape[2:]) return torch.concat(list(mm_input), dim=dim) else: return torch.concat(mm_input, dim=dim) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 98f9c0cf4c16..5d35a7054659 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -986,7 +986,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 54ec7b862748..c797b71b5d2e 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -342,7 +342,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, *mm_input.shape[2:]) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 89af79c3b5fd..cf15dfa67743 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1167,7 +1167,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input)