[Bugfix] Fix Idefics3 fails during multi-image inference (#11080)

Signed-off-by: B-201 <Joy25810@foxmail.com>
This commit is contained in:
B-201 2024-12-11 17:27:07 +08:00 committed by GitHub
parent 61b1d2f6ae
commit 2e32f5d28d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -60,7 +60,8 @@ class Idefics3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Shape: `(batch_size * num_images * num_patches,
num_channels, height, width)`
"""
pixel_attention_mask: Optional[torch.BoolTensor]
@ -520,13 +521,17 @@ class Idefics3Model(nn.Module):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Idefics3ImagePixelInputs(type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values,
concat=True)),
pixel_attention_mask=flatten_bn(
pixel_attention_mask,
concat=True))
if isinstance(pixel_values, list):
pixel_values = torch.cat(pixel_values, dim=1)
pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
else:
pixel_values = flatten_bn(pixel_values)
pixel_attention_mask = flatten_bn(pixel_attention_mask)
return Idefics3ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask)
raise AssertionError("This line should be unreachable.")