[VLM] Cleanup validation and update docs (#6149)
parent a41357e941
commit ea4b570483
@@ -149,14 +149,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
                                                 config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
-    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape)[1:] != [
-                3, self.config.vision_config.image_size,
-                self.config.vision_config.image_size
-        ]:
-            raise ValueError(
-                "The expected image tensor shape is batch dimension plus "
-                "channel, height and width.")
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+        actual_dims = tuple(data.shape[1:])
+
+        if actual_dims != expected_dims:
+            expected_expr = ("batch_size", *map(str, expected_dims))
+            raise ValueError(
+                f"The expected shape of pixel values is {expected_expr}. "
+                f"You supplied {tuple(data.shape)}.")
 
         return data
 
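As a quick illustration of the reworked check (a standalone sketch, not part of the diff; image_size = 336 is an assumption matching CLIP ViT-L/14-336, the usual LLaVA-1.5 vision tower), the helper only accepts tensors shaped (batch_size, 3, H, W) and now reports both the expected and the supplied shape:

import torch

# Standalone sketch of the validation logic above; 336 is assumed.
image_size = 336
expected_dims = (3, image_size, image_size)

def validate_pixel_values(data: torch.Tensor) -> torch.Tensor:
    actual_dims = tuple(data.shape[1:])
    if actual_dims != expected_dims:
        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")
    return data

validate_pixel_values(torch.zeros(2, 3, 336, 336))      # passes
# validate_pixel_values(torch.zeros(2, 336, 336, 3))    # raises ValueError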
@@ -173,7 +175,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
         return LlavaImagePixelInputs(
             type="pixel_values",
-            data=self._validate_image_data(pixel_values),
+            data=self._validate_pixel_values(pixel_values),
         )
 
     def _select_image_features(self, image_features: torch.Tensor, *,
@@ -226,18 +228,25 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
-        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
+        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
+
         Tokenizer outputs:
-        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
-        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
-        The to-be-inserted image has a size of 576 (24 * 24) along the context
-        length dimension.
-        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
-        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
-        9047, 13566, 29901].
-        There will be 576 `32000` in the `input_ids`.
-        (32000 is the token id for `<image>`.)
+        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
+        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
+
+        To reserve space in KV cache, we have to insert placeholder tokens
+        before they are inputted to the model, so the input processor prepends
+        additional image tokens (denoted as `32000`), resulting in:
+        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
+        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
+        29901]`.
+
+        We insert 575 tokens so that including the original image token in the
+        input, there are a total of 576 (24 * 24) image tokens, which
+        corresponds to the number of image tokens inputted to the language
+        model, i.e. the number of image tokens outputted by the visual encoder.
 
         This way, the `positions` and `attn_metadata` are consistent
         with the `input_ids`.
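The expansion described in this docstring can be pictured with a small sketch (illustrative only, not the actual vLLM input processor; `expand_image_tokens` and the truncated `prompt_ids` list are made up for the example). Each `<image>` token (id 32000) is replaced by 576 copies so the KV cache reserves one slot per visual feature:

# Illustrative sketch of the placeholder expansion described above.
IMAGE_TOKEN_ID = 32000          # token id for "<image>" in the LLaVA vocab
NUM_IMAGE_TOKENS = 24 * 24      # 576 visual features from the encoder

def expand_image_tokens(input_ids: list) -> list:
    out = []
    for tok in input_ids:
        # Repeat the image token so one slot exists per visual feature.
        out.extend([tok] * NUM_IMAGE_TOKENS if tok == IMAGE_TOKEN_ID else [tok])
    return out

prompt_ids = [1, 3148, 1001, 29901, 29871, 32000, 29871, 13]  # truncated example
expanded = expand_image_tokens(prompt_ids)
assert expanded.count(IMAGE_TOKEN_ID) == 576
assert len(expanded) == len(prompt_ids) + 575   # 575 extra placeholder tokens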
@@ -246,6 +255,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each input image.
+
+        See also:
+            :class:`LlavaImageInputs`
         """
         image_input = self._parse_and_validate_image_input(**kwargs)
 
@@ -47,7 +47,8 @@ class LlavaNextImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
 
-    Note that `num_patches` may be different for each batch.
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
     """
 
     image_sizes: NotRequired[torch.Tensor]
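For intuition, the two accepted layouts for the pixel data look like this (an illustrative sketch with made-up patch counts; the 336 x 336 patch size is an assumption):

import torch

# Same number of patches per image -> a single batched 5-D tensor works.
batched = torch.zeros(2, 5, 3, 336, 336)   # (batch_size, 1 + num_patches, C, H, W)

# Different numbers of patches per image -> pass a list of 4-D tensors instead.
as_list = [
    torch.zeros(5, 3, 336, 336),   # image 1: 1 + 4 patches
    torch.zeros(3, 3, 336, 336),   # image 2: 1 + 2 patches
]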
@@ -255,40 +256,20 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self, data: Union[torch.Tensor, List[torch.Tensor]]
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
 
-        def _validate_shape(data: torch.Tensor):
-            dim = data.dim()
-            height = width = self.config.vision_config.image_size
-            # All 4d image tensors have the same number of patches,
-            # so data is a 5d batch of these tensors
-            if dim == 5:
-                if list(data.shape)[2:] != [
-                        3, self.config.vision_config.image_size,
-                        self.config.vision_config.image_size
-                ]:
-                    raise ValueError(
-                        "Expected pixel value tensor in shape of: (batch size, "
-                        f"patch number, 3, {height}, {width}), got {data.shape}"
-                    )
-            # 4d image tensors have different number of patches,
-            # so data is each individual tensor.
-            elif dim == 4:
-                if list(data.shape)[1:] != [
-                        3, self.config.vision_config.image_size,
-                        self.config.vision_config.image_size
-                ]:
-                    raise ValueError(
-                        "Expected pixel value tensor in shape of: (patch "
-                        f"number, 3, {height}, {width}), got {data.shape}")
-            else:
-                raise ValueError(
-                    f"Invalid pixel value tensor of shape {data.shape}")
-
-        if isinstance(data, torch.Tensor):
-            _validate_shape(data)
-        else:
-            [_validate_shape(d) for d in data]
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[1:])
+
+            if actual_dims != expected_dims:
+                expected_expr = ("num_patches", *map(str, expected_dims))
+                raise ValueError(
+                    "The expected shape of pixel values in each batch element "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
 
         return data
 
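A standalone sketch of how the new per-element check behaves (not the diff itself; image_size = 336 is assumed): iterating over `data` works for both a batched 5-D tensor and a list of 4-D tensors, and only the trailing (3, H, W) dims are pinned down, so the patch count may vary per image:

import torch

image_size = 336                     # assumed vision_config.image_size
expected_dims = (3, image_size, image_size)

def validate_each(data):
    # Iterating a 5-D tensor yields its 4-D slices, so the same loop
    # handles a batched tensor and a list of per-image tensors.
    for d in data:
        if tuple(d.shape[1:]) != expected_dims:
            raise ValueError(f"Bad patch shape {tuple(d.shape)}")
    return data

validate_each(torch.zeros(2, 5, 3, 336, 336))                               # batched
validate_each([torch.zeros(5, 3, 336, 336), torch.zeros(3, 3, 336, 336)])   # ragged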
@@ -464,18 +445,33 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
-        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
+        `"A chat between a curious human and an artificial intelligence
+        assistant. The assistant gives helpful, detailed, and polite answers to
+        the human's questions.
+        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
+
         Tokenizer outputs:
-        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
-        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
-        The to-be-inserted image has a size of 576 (24 * 24) along the context
-        length dimension.
-        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
-        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
-        9047, 13566, 29901].
-        There will be 576 `32000` in the `input_ids`.
-        (32000 is the token id for `<image>`.)
+        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
+        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
+        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
+        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
+        9047, 13566, 29901]`.
+
+        To reserve space in KV cache, we have to insert placeholder tokens
+        before they are inputted to the model, so the input processor prepends
+        additional image tokens (denoted as `32000`), resulting in:
+        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
+        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
+        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
+        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
+        319, 1799, 9047, 13566, 29901]`.
+
+        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
+        model depends on the original size of the input image. Including the
+        original image token in the input, the required number of image tokens
+        is given by :func:`get_llava_next_image_feature_size`.
 
         This way, the `positions` and `attn_metadata` are consistent
         with the `input_ids`.
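For reference, the base figure of 576 comes from the vision tower's patch grid (a quick arithmetic check; the full LLaVA-NeXT count additionally depends on the selected grid and unpadding, as computed by get_llava_next_image_feature_size):

# CLIP ViT-L/14 at 336 x 336 resolution: 336 / 14 = 24 patches per side.
patches_per_side = 336 // 14
assert patches_per_side == 24
assert patches_per_side ** 2 == 576   # base number of image features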
@@ -484,15 +480,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each grid patch for each input image.
-                Expects a batch with shape `[1, num_patches, 3, h, w]`.
             image_sizes: The original `(height, width)` for each input image.
-                Expects a batch with shape `[1, 2]`.
 
         See also:
-            Each input maps to huggingface implementation, as follows:
-
-            - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L690
-            - `image_sizes`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L691
+            :class:`LlavaNextImageInputs`
         """
         image_input = self._parse_and_validate_image_input(**kwargs)
 
@@ -263,7 +263,8 @@ class Phi3VImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
 
-    Note that `num_patches` may be different for each batch.
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
     """
 
     image_sizes: torch.Tensor
@@ -466,8 +467,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
         if list(data.shape[1:]) != [2]:
             raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
+                f"The expected shape of image sizes is batch dimension plus "
+                f"{[2]}. You supplied {tuple(data.shape)}.")
 
         return data
 
@@ -475,19 +476,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         self, data: Union[torch.Tensor, List[torch.Tensor]]
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
 
-        def _validate_shape(data: torch.Tensor):
-            if list(data.shape)[2:] != [
-                    3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
-                    CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
-            ]:
-                raise ValueError(
-                    "The expected pixel value tensor shape is batch dimension "
-                    "plus patch number, channel, height and width.")
-
-        if isinstance(data, torch.Tensor):
-            _validate_shape(data)
-        else:
-            [_validate_shape(d) for d in data]
+        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[1:])
+
+            if actual_dims != expected_dims:
+                expected_expr = ("num_patches", *map(str, expected_dims))
+                raise ValueError(
+                    "The expected shape of pixel values in each batch element "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
 
         return data
 
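The Phi-3-Vision variant of the same check pins the patch size to the fixed CLIP ViT-L/14-336 config rather than `config.vision_config`. Assuming that config's image_size of 336, each batch element is expected to look like the sketch below (the tensor here is made up for illustration):

import torch

expected_dims = (3, 336, 336)   # assumed CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size

element = torch.zeros(17, 3, 336, 336)   # (1 + num_patches, C, H, W) for one image
assert tuple(element.shape[1:]) == expected_dims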