[Bugfix] should use stack instead of concat (#22972)

Signed-off-by: 947132885 <947132885@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
947132885 2025-08-17 16:46:36 +08:00 committed by GitHub
parent 4d4061b6e7
commit fe0411fc6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -694,6 +694,17 @@ class TransformersForCausalLM(TransformersBase):
return logits
def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
    """Flatten until a list of tensors can be concatenated, then concatenate.

    The tensors are concatenated with ``torch.concat`` (i.e. along dim 0) as
    soon as they all share the same shape after their first dimension.
    Otherwise the list is passed through ``flatten_bn`` and the check is
    retried on the result.

    Args:
        x: Non-empty list of tensors to merge.

    Returns:
        A single tensor holding the concatenation of the (possibly
        flattened) input tensors.

    Raises:
        ValueError: If ``x`` is empty, either on entry or after a flattening
            step, since there is then nothing to concatenate.
    """
    # An empty list can never satisfy the shape check below (the set of
    # trailing shapes would be empty), so without this guard the function
    # would recurse until hitting the interpreter's recursion limit.
    if not x:
        raise ValueError(
            "flatten_and_concat() got no tensors to concatenate "
            "(the input list is empty or flattened to nothing)")

    def _can_concat(tensors: list[torch.Tensor]) -> bool:
        # Concatenation along dim 0 only needs the remaining dims to agree.
        return len({t.shape[1:] for t in tensors}) == 1

    if _can_concat(x):
        return torch.concat(x)
    # NOTE(review): flatten_bn is defined outside this block; presumably it
    # unpacks one level of nesting per call so the retry makes progress --
    # confirm against its definition.
    return flatten_and_concat(flatten_bn(x))
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
@ -766,8 +777,7 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
if isinstance(pixel_values, torch.Tensor):
pixel_values = flatten_bn(pixel_values).to(self.dtype)
elif is_list_of(pixel_values, torch.Tensor):
pixel_values = flatten_bn(flatten_bn(pixel_values),
concat=True).to(self.dtype)
pixel_values = flatten_and_concat(pixel_values).to(self.dtype)
else:
raise ValueError(
f"Unsupported pixel_values type {type(pixel_values)}. "