mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
[Model][Bugfix] Fix batching with multi-image in PixtralHF (#9518)
This commit is contained in:
parent
ec6bd6c4c6
commit
5241aa1494
@ -287,6 +287,34 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return data
|
||||
|
||||
def _validate_image_sizes(self, images: List[torch.Tensor],
|
||||
sizes: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
if not isinstance(sizes, list):
|
||||
sizes = [sizes]
|
||||
|
||||
total_images = sum(size.numel() // 2 for size in sizes)
|
||||
if total_images != len(images):
|
||||
raise ValueError("Mismatch in number of images. "
|
||||
f"Expected {total_images}, got {len(images)}")
|
||||
img_idx = 0
|
||||
for size in sizes:
|
||||
# Flatten the size tensor to a list of (height, width) pairs
|
||||
size = size.view(-1, 2).tolist()
|
||||
for expected_h, expected_w in size:
|
||||
if img_idx >= len(images):
|
||||
raise ValueError("Ran out of images before sizes. "
|
||||
f"{img_idx} >= {len(images)}")
|
||||
img = images[img_idx]
|
||||
if img.shape[-2:] != (expected_h, expected_w):
|
||||
raise ValueError(
|
||||
"Image size mismatch. Expected "
|
||||
f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
|
||||
if img.shape[-3] != 3:
|
||||
raise ValueError("Image channel mismatch. Expected 3, "
|
||||
f"got {img.shape[-3]}")
|
||||
img_idx += 1
|
||||
return images
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[LlavaImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
@ -305,20 +333,28 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# so we need to produce a list of tensors
|
||||
if image_sizes is not None:
|
||||
images = pixel_values
|
||||
if isinstance(images, torch.Tensor):
|
||||
# if passed as batch take all images
|
||||
NN, N, B, C, W, H = images.shape
|
||||
images = images.reshape(NN * N * B, C, W, H)
|
||||
images = [images[i] for i in range(images.size(0))]
|
||||
elif isinstance(images, list):
|
||||
# if passed as list flatten lists of tensors
|
||||
while isinstance(images, list) and len(images) == 1:
|
||||
images = images[0]
|
||||
|
||||
# TODO: Add validation based on image_sizes
|
||||
def flatten_to_3d_tensors(item):
|
||||
if isinstance(item, torch.Tensor):
|
||||
if item.dim() >= 3:
|
||||
return [t for t in item.view(-1, *item.shape[-3:])]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected tensor dimension: {item.dim()}")
|
||||
elif isinstance(item, list):
|
||||
return [
|
||||
t for subitem in item
|
||||
for t in flatten_to_3d_tensors(subitem)
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unexpected type: {type(item)}")
|
||||
|
||||
# Restructure the batched images into a list of lists of images
|
||||
images = flatten_to_3d_tensors(pixel_values)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=images,
|
||||
data=self._validate_image_sizes(images, image_sizes),
|
||||
)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
|
||||
@ -907,17 +907,18 @@ class PixtralHFVisionModel(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
pixel_values: tensor of token features for
|
||||
all tokens of all images of shape (N_toks, D)
|
||||
pixel_values: Each image to be processed will be a separate tensor
|
||||
in pixel_values. This means it will be a list of tensors
|
||||
because multiple requests batched can have multiple images,
|
||||
each with their own shape potentially
|
||||
|
||||
Returns:
|
||||
image_features: tensor of token features for
|
||||
all tokens of all images of shape (N_toks, D)
|
||||
"""
|
||||
# pass images through initial convolution independently
|
||||
patch_embeds_list = [
|
||||
self.patch_conv(
|
||||
img.reshape(-1, img.shape[-3], img.shape[-2],
|
||||
img.shape[-1]).to(self.dtype))
|
||||
self.patch_conv(img.unsqueeze(0).to(self.dtype))
|
||||
for img in pixel_values
|
||||
]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user