[VLM] Cleanup validation and update docs (#6149)

Author: Cyrus Leung
Date: 2024-07-05 13:49:38 +08:00 (committed via GitHub)
Parent: a41357e941
Commit: ea4b570483
3 changed files with 86 additions and 81 deletions

Changed file: vllm/model_executor/models/llava.py

@@ -149,14 +149,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
                                                 config.vocab_size, logit_scale)
         self.sampler = Sampler()

-    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape)[1:] != [
-                3, self.config.vision_config.image_size,
-                self.config.vision_config.image_size
-        ]:
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+        actual_dims = tuple(data.shape[1:])
+
+        if actual_dims != expected_dims:
+            expected_expr = ("batch_size", *map(str, expected_dims))
             raise ValueError(
-                "The expected image tensor shape is batch dimension plus "
-                "channel, height and width.")
+                f"The expected shape of pixel values is {expected_expr}. "
+                f"You supplied {tuple(data.shape)}.")

         return data
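For reference, the new check reduces to a single tuple comparison. A minimal standalone sketch of the same logic (the free function name is mine, and image_size=336 is an assumed CLIP-ViT-L/14-336 value rather than one read from the model config):

    import torch

    def validate_pixel_values(data: torch.Tensor,
                              image_size: int = 336) -> torch.Tensor:
        # Expected layout: (batch_size, 3, image_size, image_size)
        expected_dims = (3, image_size, image_size)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    validate_pixel_values(torch.rand(2, 3, 336, 336))  # ok
    validate_pixel_values(torch.rand(2, 3, 224, 224))  # raises ValueError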
@@ -173,7 +175,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         return LlavaImagePixelInputs(
             type="pixel_values",
-            data=self._validate_image_data(pixel_values),
+            data=self._validate_pixel_values(pixel_values),
         )

     def _select_image_features(self, image_features: torch.Tensor, *,
@@ -226,18 +228,25 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
-        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
+        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
+
         Tokenizer outputs:
-        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
-        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
-        The to-be-inserted image has a size of 576 (24 * 24) along the context
-        length dimension.
-        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
-        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
-        9047, 13566, 29901].
-        There will be 576 `32000` in the `input_ids`.
-        (32000 is the token id for `<image>`.)
+        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
+        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
+
+        To reserve space in KV cache, we have to insert placeholder tokens
+        before they are inputted to the model, so the input processor prepends
+        additional image tokens (denoted as `32000`), resulting in:
+        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
+        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
+        29901]`.
+
+        We insert 575 tokens so that including the original image token in the
+        input, there are a total of 576 (24 * 24) image tokens, which
+        corresponds to the number of image tokens inputted to the language
+        model, i.e. the number of image tokens outputted by the visual encoder.
+
         This way, the `positions` and `attn_metadata` are consistent
         with the `input_ids`.
@@ -246,6 +255,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each input image.
+
+        See also:
+            :class:`LlavaImageInputs`
         """
         image_input = self._parse_and_validate_image_input(**kwargs)
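The docstring change above describes the placeholder expansion the input processor performs before the prompt reaches the model. A rough sketch of that bookkeeping, reusing the token id 32000 and the fixed 576 (24 * 24) feature count from the example (the helper below is illustrative, not vLLM's actual input-processor API):

    from typing import List

    IMAGE_TOKEN_ID = 32000   # token id of `<image>` in the docstring example
    NUM_IMAGE_TOKENS = 576   # 24 * 24 tokens produced by the visual encoder

    def expand_image_token(input_ids: List[int]) -> List[int]:
        # Repeat the single `<image>` token so the sequence reserves one
        # position (and one KV-cache slot) per visual-encoder output token.
        expanded: List[int] = []
        for tok in input_ids:
            if tok == IMAGE_TOKEN_ID:
                expanded.extend([IMAGE_TOKEN_ID] * NUM_IMAGE_TOKENS)
            else:
                expanded.append(tok)
        return expanded

    # The tokenized prompt contains one 32000; after expansion there are 576,
    # matching the docstring's `[1, 3148, 1001, ..., 32000, ..., 32000, ...]`.
    prompt_ids = [1, 3148, 1001, 29901, 29871, 32000, 29871, 13]
    assert expand_image_token(prompt_ids).count(IMAGE_TOKEN_ID) == 576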

Changed file: vllm/model_executor/models/llava_next.py

@@ -47,7 +47,8 @@ class LlavaNextImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

-    Note that `num_patches` may be different for each batch.
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
     """

     image_sizes: NotRequired[torch.Tensor]
@@ -255,40 +256,20 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             self, data: Union[torch.Tensor, List[torch.Tensor]]
     ) -> Union[torch.Tensor, List[torch.Tensor]]:

-        def _validate_shape(data: torch.Tensor):
-            dim = data.dim()
-            height = width = self.config.vision_config.image_size
-
-            # All 4d image tensors have the same number of patches,
-            # so data is a 5d batch of these tensors
-            if dim == 5:
-                if list(data.shape)[2:] != [
-                        3, self.config.vision_config.image_size,
-                        self.config.vision_config.image_size
-                ]:
-                    raise ValueError(
-                        "Expected pixel value tensor in shape of: (batch size, "
-                        f"patch number, 3, {height}, {width}), got {data.shape}"
-                    )
-            # 4d image tensors have different number of patches,
-            # so data is each individual tensor.
-            elif dim == 4:
-                if list(data.shape)[1:] != [
-                        3, self.config.vision_config.image_size,
-                        self.config.vision_config.image_size
-                ]:
-                    raise ValueError(
-                        "Expected pixel value tensor in shape of: (patch "
-                        f"number, 3, {height}, {width}), got {data.shape}")
-            else:
-                raise ValueError(
-                    f"Invalid pixel value tensor of shape {data.shape}")
-
-        if isinstance(data, torch.Tensor):
-            _validate_shape(data)
-        else:
-            [_validate_shape(d) for d in data]
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[1:])
+
+            if actual_dims != expected_dims:
+                expected_expr = ("num_patches", *map(str, expected_dims))
+                raise ValueError(
+                    "The expected shape of pixel values in each batch element "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)

         return data
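The reworked validator accepts either a batched 5D tensor or, when images have different patch counts, a list of 4D tensors, and checks the trailing (3, H, W) dims of every element. A small illustration of why the list form is needed (336 is an assumed CLIP-ViT-L/14-336 image size; the patch counts are arbitrary):

    from typing import List, Union

    import torch

    # Two images with different grid layouts cannot be stacked into one tensor.
    image_a = torch.rand(5, 3, 336, 336)  # 1 base patch + 4 grid patches
    image_b = torch.rand(3, 3, 336, 336)  # 1 base patch + 2 grid patches

    pixel_values: Union[torch.Tensor, List[torch.Tensor]]
    try:
        pixel_values = torch.stack([image_a, image_b])  # fails: mismatched dim 0
    except RuntimeError:
        pixel_values = [image_a, image_b]

    # Each element still satisfies (num_patches, 3, 336, 336), which is all
    # that `_validate_shape` checks per batch element.
    for d in pixel_values:
        assert tuple(d.shape[1:]) == (3, 336, 336)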
@@ -464,18 +445,33 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
-        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
+        `"A chat between a curious human and an artificial intelligence
+        assistant. The assistant gives helpful, detailed, and polite answers to
+        the human's questions.
+        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
+
         Tokenizer outputs:
-        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
-        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
-        The to-be-inserted image has a size of 576 (24 * 24) along the context
-        length dimension.
-        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
-        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
-        9047, 13566, 29901].
-        There will be 576 `32000` in the `input_ids`.
-        (32000 is the token id for `<image>`.)
+        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
+        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
+        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
+        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
+        9047, 13566, 29901]`.
+
+        To reserve space in KV cache, we have to insert placeholder tokens
+        before they are inputted to the model, so the input processor prepends
+        additional image tokens (denoted as `32000`), resulting in:
+        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
+        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
+        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
+        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
+        319, 1799, 9047, 13566, 29901]`.
+
+        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
+        model depends on the original size of the input image. Including the
+        original image token in the input, the required number of image tokens
+        is given by :func:`get_llava_next_image_feature_size`.
+
         This way, the `positions` and `attn_metadata` are consistent
         with the `input_ids`.
@@ -484,15 +480,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each grid patch for each input image.
-                Expects a batch with shape `[1, num_patches, 3, h, w]`.
             image_sizes: The original `(height, width)` for each input image.
-                Expects a batch with shape `[1, 2]`.
-
-        Each input maps to huggingface implementation, as follows:
-
-        - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L690
-        - `image_sizes`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L691
+
+        See also:
+            :class:`LlavaNextImageInputs`
         """
         image_input = self._parse_and_validate_image_input(**kwargs)
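As the updated docstring notes, the placeholder count here is not the fixed 576 of LLaVA-1.5 but depends on the original image size via get_llava_next_image_feature_size. A hedged sketch of that idea, treating the feature-size computation as a black box (the helper and the sample value below are illustrative only):

    from typing import Callable, List, Tuple

    IMAGE_TOKEN_ID = 32000

    def expand_image_token_for_size(
        input_ids: List[int],
        image_size: Tuple[int, int],
        feature_size_fn: Callable[[Tuple[int, int]], int],
    ) -> List[int]:
        # Unlike LLaVA-1.5, the number of placeholders varies with the
        # original (height, width) of the image.
        num_image_tokens = feature_size_fn(image_size)
        return [
            tok
            for t in input_ids
            for tok in ([IMAGE_TOKEN_ID] * num_image_tokens
                        if t == IMAGE_TOKEN_ID else [t])
        ]

    # Placeholder feature-size function for illustration; the real value comes
    # from get_llava_next_image_feature_size in vLLM.
    ids = expand_image_token_for_size([1, 32000, 13], (480, 640),
                                      lambda size: 1176)
    assert ids.count(IMAGE_TOKEN_ID) == 1176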

Changed file: vllm/model_executor/models/phi3v.py

@@ -263,7 +263,8 @@ class Phi3VImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

-    Note that `num_patches` may be different for each batch.
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
     """

     image_sizes: torch.Tensor
@@ -466,8 +467,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
         if list(data.shape[1:]) != [2]:
             raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
+                f"The expected shape of image sizes is batch dimension plus "
+                f"{[2]}. You supplied {tuple(data.shape)}.")

         return data
@@ -475,19 +476,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
             self, data: Union[torch.Tensor, List[torch.Tensor]]
     ) -> Union[torch.Tensor, List[torch.Tensor]]:

-        def _validate_shape(data: torch.Tensor):
-            if list(data.shape)[2:] != [
-                    3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
-                    CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
-            ]:
-                raise ValueError(
-                    "The expected pixel value tensor shape is batch dimension "
-                    "plus patch number, channel, height and width.")
-
-        if isinstance(data, torch.Tensor):
-            _validate_shape(data)
-        else:
-            [_validate_shape(d) for d in data]
+        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[1:])
+
+            if actual_dims != expected_dims:
+                expected_expr = ("num_patches", *map(str, expected_dims))
+                raise ValueError(
+                    "The expected shape of pixel values in each batch element "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)

         return data
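Taken together, the two Phi-3-Vision validators accept an image_sizes tensor of shape (batch_size, 2) and pixel values whose per-image tensors end in (3, 336, 336), 336 being CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size. An illustrative pair of inputs (the concrete sizes and patch counts are made up):

    from typing import List

    import torch

    image_sizes = torch.tensor([[480, 640],   # original (height, width) per image
                                [336, 336]])  # shape: (batch_size, 2)

    pixel_values: List[torch.Tensor] = [
        torch.rand(5, 3, 336, 336),  # image 1: 1 + 4 crops
        torch.rand(2, 3, 336, 336),  # image 2: 1 + 1 crop
    ]

    # The same conditions `_validate_image_sizes` / `_validate_pixel_values` enforce:
    assert list(image_sizes.shape[1:]) == [2]
    for d in pixel_values:
        assert tuple(d.shape[1:]) == (3, 336, 336)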