mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 04:34:54 +08:00
[Models] Optimise and simplify _validate_and_reshape_mm_tensor (#24742)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
684b6870e1
commit
57f94e88ea
@ -1338,7 +1338,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
||||||
f"Got ndim: {mm_input.ndim} "
|
f"Got ndim: {mm_input.ndim} "
|
||||||
f"(shape={mm_input.shape})")
|
f"(shape={mm_input.shape})")
|
||||||
return torch.concat(list(mm_input))
|
return mm_input.reshape(-1, mm_input.shape[-1])
|
||||||
else:
|
else:
|
||||||
return torch.concat(mm_input)
|
return torch.concat(mm_input)
|
||||||
|
|
||||||
|
|||||||
@ -1378,7 +1378,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
||||||
f"Got ndim: {mm_input.ndim} "
|
f"Got ndim: {mm_input.ndim} "
|
||||||
f"(shape={mm_input.shape})")
|
f"(shape={mm_input.shape})")
|
||||||
return torch.concat(list(mm_input))
|
return mm_input.reshape(-1, mm_input.shape[-1])
|
||||||
else:
|
else:
|
||||||
return torch.concat(mm_input)
|
return torch.concat(mm_input)
|
||||||
|
|
||||||
|
|||||||
@ -1611,12 +1611,12 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
|
|||||||
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
||||||
f"Got ndim: {mm_input.ndim} "
|
f"Got ndim: {mm_input.ndim} "
|
||||||
f"(shape={mm_input.shape})")
|
f"(shape={mm_input.shape})")
|
||||||
return torch.concat(list(mm_input))
|
return mm_input.reshape(-1, mm_input.shape[-1])
|
||||||
elif is_list_of(mm_input, torch.Tensor):
|
elif is_list_of(mm_input, torch.Tensor):
|
||||||
if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
|
if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
|
||||||
for p in mm_input):
|
for p in mm_input):
|
||||||
return mm_input
|
return mm_input
|
||||||
return torch.concat(list(mm_input))
|
return torch.concat(mm_input)
|
||||||
|
|
||||||
def _parse_and_validate_image_input(
|
def _parse_and_validate_image_input(
|
||||||
self, **kwargs: object) -> Optional[KeyeImageInputs]:
|
self, **kwargs: object) -> Optional[KeyeImageInputs]:
|
||||||
|
|||||||
@ -491,14 +491,14 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
|
|||||||
if mm_input.ndim == expected_dim:
|
if mm_input.ndim == expected_dim:
|
||||||
return mm_input
|
return mm_input
|
||||||
elif mm_input.ndim == expected_dim + 1:
|
elif mm_input.ndim == expected_dim + 1:
|
||||||
return torch.concat(list(mm_input))
|
return mm_input.reshape(-1, *mm_input.shape[2:])
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{name} should be {expected_dim}D or "
|
f"{name} should be {expected_dim}D or "
|
||||||
f"batched {expected_dim}D tensor."
|
f"batched {expected_dim}D tensor."
|
||||||
f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
|
f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
|
||||||
else:
|
else:
|
||||||
return torch.concat(list(mm_input))
|
return torch.concat(mm_input)
|
||||||
|
|
||||||
def _parse_and_validate_image_input(
|
def _parse_and_validate_image_input(
|
||||||
self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
|
self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
|
||||||
|
|||||||
@ -669,7 +669,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
raise ValueError(f"Incorrect type of {name}. "
|
raise ValueError(f"Incorrect type of {name}. "
|
||||||
f"Got type: {type(mm_input)}")
|
f"Got type: {type(mm_input)}")
|
||||||
if isinstance(mm_input, torch.Tensor):
|
if isinstance(mm_input, torch.Tensor):
|
||||||
return torch.concat(list(mm_input))
|
return mm_input.reshape(-1, *mm_input.shape[2:])
|
||||||
else:
|
else:
|
||||||
return torch.concat(mm_input)
|
return torch.concat(mm_input)
|
||||||
|
|
||||||
|
|||||||
@ -551,6 +551,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
|||||||
raise ValueError(f"Incorrect type of {name}. "
|
raise ValueError(f"Incorrect type of {name}. "
|
||||||
f"Got type: {type(mm_input)}")
|
f"Got type: {type(mm_input)}")
|
||||||
if isinstance(mm_input, torch.Tensor):
|
if isinstance(mm_input, torch.Tensor):
|
||||||
|
if dim == 0:
|
||||||
|
return mm_input.reshape(-1, *mm_input.shape[2:])
|
||||||
return torch.concat(list(mm_input), dim=dim)
|
return torch.concat(list(mm_input), dim=dim)
|
||||||
else:
|
else:
|
||||||
return torch.concat(mm_input, dim=dim)
|
return torch.concat(mm_input, dim=dim)
|
||||||
|
|||||||
@ -986,7 +986,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
||||||
f"Got ndim: {mm_input.ndim} "
|
f"Got ndim: {mm_input.ndim} "
|
||||||
f"(shape={mm_input.shape})")
|
f"(shape={mm_input.shape})")
|
||||||
return torch.concat(list(mm_input))
|
return mm_input.reshape(-1, mm_input.shape[-1])
|
||||||
else:
|
else:
|
||||||
return torch.concat(mm_input)
|
return torch.concat(mm_input)
|
||||||
|
|
||||||
|
|||||||
@ -342,7 +342,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
raise ValueError(f"Incorrect type of {name}. "
|
raise ValueError(f"Incorrect type of {name}. "
|
||||||
f"Got type: {type(mm_input)}")
|
f"Got type: {type(mm_input)}")
|
||||||
if isinstance(mm_input, torch.Tensor):
|
if isinstance(mm_input, torch.Tensor):
|
||||||
return torch.concat(list(mm_input))
|
return mm_input.reshape(-1, *mm_input.shape[2:])
|
||||||
else:
|
else:
|
||||||
return torch.concat(mm_input)
|
return torch.concat(mm_input)
|
||||||
|
|
||||||
|
|||||||
@ -1167,7 +1167,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
||||||
f"Got ndim: {mm_input.ndim} "
|
f"Got ndim: {mm_input.ndim} "
|
||||||
f"(shape={mm_input.shape})")
|
f"(shape={mm_input.shape})")
|
||||||
return torch.concat(list(mm_input))
|
return mm_input.reshape(-1, mm_input.shape[-1])
|
||||||
else:
|
else:
|
||||||
return torch.concat(mm_input)
|
return torch.concat(mm_input)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user