[Model][Bugfix] Fix batching with multi-image in PixtralHF (#9518)

Authored by Michael Goin on 2024-10-21 14:20:07 -04:00; committed by GitHub.
parent ec6bd6c4c6
commit 5241aa1494
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 16 deletions

View File

@ -287,6 +287,34 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return data
def _validate_image_sizes(self, images: List[torch.Tensor],
sizes: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(sizes, list):
sizes = [sizes]
total_images = sum(size.numel() // 2 for size in sizes)
if total_images != len(images):
raise ValueError("Mismatch in number of images. "
f"Expected {total_images}, got {len(images)}")
img_idx = 0
for size in sizes:
# Flatten the size tensor to a list of (height, width) pairs
size = size.view(-1, 2).tolist()
for expected_h, expected_w in size:
if img_idx >= len(images):
raise ValueError("Ran out of images before sizes. "
f"{img_idx} >= {len(images)}")
img = images[img_idx]
if img.shape[-2:] != (expected_h, expected_w):
raise ValueError(
"Image size mismatch. Expected "
f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
if img.shape[-3] != 3:
raise ValueError("Image channel mismatch. Expected 3, "
f"got {img.shape[-3]}")
img_idx += 1
return images
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@ -305,20 +333,28 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# so we need to produce a list of tensors
if image_sizes is not None:
images = pixel_values
if isinstance(images, torch.Tensor):
# if passed as batch take all images
NN, N, B, C, W, H = images.shape
images = images.reshape(NN * N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# if passed as list flatten lists of tensors
while isinstance(images, list) and len(images) == 1:
images = images[0]
# TODO: Add validation based on image_sizes
def flatten_to_3d_tensors(item):
    """Recursively flatten arbitrarily nested lists of tensors into a
    flat list of 3-D tensors, collapsing any leading batch dimensions
    so each element keeps only its trailing three dimensions.

    Raises:
        ValueError: For a tensor with fewer than 3 dimensions, or for
            an item that is neither a tensor nor a list.
    """
    if isinstance(item, torch.Tensor):
        if item.dim() < 3:
            raise ValueError(f"Unexpected tensor dimension: {item.dim()}")
        # Merge every leading dim; keep the last three as-is.
        return list(item.view(-1, *item.shape[-3:]))
    if isinstance(item, list):
        flattened = []
        for subitem in item:
            flattened.extend(flatten_to_3d_tensors(subitem))
        return flattened
    raise ValueError(f"Unexpected type: {type(item)}")
# Restructure the batched images into a list of lists of images
images = flatten_to_3d_tensors(pixel_values)
return LlavaImagePixelInputs(
type="pixel_values",
data=images,
data=self._validate_image_sizes(images, image_sizes),
)
return LlavaImagePixelInputs(

View File

@ -907,17 +907,18 @@ class PixtralHFVisionModel(nn.Module):
) -> torch.Tensor:
"""
Args:
pixel_values: tensor of token features for
all tokens of all images of shape (N_toks, D)
pixel_values: Each image to be processed will be a separate tensor
in pixel_values. This means it will be a list of tensors
because multiple requests batched can have multiple images,
each with their own shape potentially
Returns:
image_features: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
# pass images through initial convolution independently
patch_embeds_list = [
self.patch_conv(
img.reshape(-1, img.shape[-3], img.shape[-2],
img.shape[-1]).to(self.dtype))
self.patch_conv(img.unsqueeze(0).to(self.dtype))
for img in pixel_values
]