Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] should use stack instead of concat (#22972)
Signed-off-by: 947132885 <947132885@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 4d4061b6e7
commit fe0411fc6f
@@ -694,6 +694,17 @@ class TransformersForCausalLM(TransformersBase):
         return logits
 
 
+def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
+    """Flatten until a list of tensors can be concatenated then do concat"""
+
+    def _can_concat(x: list[torch.Tensor]):
+        return len(set(map(lambda _x: _x.shape[1:], x))) == 1
+
+    if _can_concat(x):
+        return torch.concat(x)
+    return flatten_and_concat(flatten_bn(x))
+
+
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
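For context on what the new helper does, here is a minimal standalone sketch of its behavior. The `_unroll_once` helper is a hypothetical stand-in for vLLM's `flatten_bn`, which the real code calls on the recursive branch; it simply strips one leading batch/nesting dimension per pass, and the shapes below are made up for illustration.

import torch


def _unroll_once(x: list[torch.Tensor]) -> list[torch.Tensor]:
    # Hypothetical stand-in for vLLM's flatten_bn: strip one leading
    # batch/nesting dimension from every element of the list.
    return [piece for item in x for piece in item.unbind(0)]


def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
    """Flatten until the tensors can be concatenated, then concatenate."""

    def _can_concat(x: list[torch.Tensor]):
        # Concatenation along dim 0 only works when all trailing shapes match.
        return len(set(map(lambda _x: _x.shape[1:], x))) == 1

    if _can_concat(x):
        return torch.concat(x)
    return flatten_and_concat(_unroll_once(x))


# Matching trailing shapes: concatenate directly along dim 0.
a = torch.randn(2, 3, 224, 224)            # prompt with 2 images
b = torch.randn(5, 3, 224, 224)            # prompt with 5 images
print(flatten_and_concat([a, b]).shape)    # torch.Size([7, 3, 224, 224])

# Mismatched trailing shapes: unroll one level, then the patches line up.
c = torch.randn(1, 16, 1176)               # 1 image, 16 patches
d = torch.randn(1, 64, 1176)               # 1 image, 64 patches
print(flatten_and_concat([c, d]).shape)    # torch.Size([80, 1176])

The key property is that tensors are only concatenated once their trailing shapes agree, so per-prompt image batches of different sizes end up stacked along a single leading dimension instead of being merged into the wrong axis.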
@@ -766,8 +777,7 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
         if isinstance(pixel_values, torch.Tensor):
             pixel_values = flatten_bn(pixel_values).to(self.dtype)
         elif is_list_of(pixel_values, torch.Tensor):
-            pixel_values = flatten_bn(flatten_bn(pixel_values),
-                                      concat=True).to(self.dtype)
+            pixel_values = flatten_and_concat(pixel_values).to(self.dtype)
         else:
             raise ValueError(
                 f"Unsupported pixel_values type {type(pixel_values)}. "
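Below is a rough illustration of the failure mode this hunk addresses, assuming `flatten_bn` behaves like the simplified stand-in in the sketch (an assumption, not the exact vLLM implementation). Under that assumption, the old double `flatten_bn(..., concat=True)` call unrolls each per-prompt image batch into per-image tensors and then concatenates them along the channel axis, whereas the new `flatten_and_concat` keeps the image axis intact.

import torch


def flatten_bn(x, *, concat=False):
    # Simplified stand-in for vLLM's flatten_bn (an assumption, not the
    # exact implementation): tensors get their first two dims merged;
    # lists are either concatenated or unrolled one nesting level.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    return torch.cat(x) if concat else [x_n for x_b in x for x_n in x_b]


def flatten_and_concat(x):
    # The helper added in the first hunk, with its flatten step pointed at
    # the stand-in above so this sketch runs outside vLLM.
    if len(set(t.shape[1:] for t in x)) == 1:
        return torch.concat(x)
    return flatten_and_concat(flatten_bn(x))


# Two prompts contributing 2 and 5 images, each of shape (3, 224, 224).
pixel_values = [torch.randn(2, 3, 224, 224), torch.randn(5, 3, 224, 224)]

# Old path: the inner flatten_bn unrolls the per-image dimension into
# (3, 224, 224) tensors, so concat=True merges images along the channel axis.
print(flatten_bn(flatten_bn(pixel_values), concat=True).shape)
# torch.Size([21, 224, 224])  (image and channel dims collapsed together)

# New path: trailing shapes already match, so images concatenate along a
# single leading dimension, preserving the channel axis.
print(flatten_and_concat(pixel_values).shape)
# torch.Size([7, 3, 224, 224])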