diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 840e40c946fcf..250d3968715ba 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -149,14 +149,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
                                                 config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
-    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape)[1:] != [
-                3, self.config.vision_config.image_size,
-                self.config.vision_config.image_size
-        ]:
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+        actual_dims = tuple(data.shape[1:])
+
+        if actual_dims != expected_dims:
+            expected_expr = ("batch_size", *map(str, expected_dims))
             raise ValueError(
-                "The expected image tensor shape is batch dimension plus "
-                "channel, height and width.")
+                f"The expected shape of pixel values is {expected_expr}. "
+                f"You supplied {tuple(data.shape)}.")
 
         return data
 
@@ -173,7 +175,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
         return LlavaImagePixelInputs(
             type="pixel_values",
-            data=self._validate_image_data(pixel_values),
+            data=self._validate_pixel_values(pixel_values),
         )
 
     def _select_image_features(self, image_features: torch.Tensor, *,
@@ -226,18 +228,25 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
-        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
+        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
+
         Tokenizer outputs:
-        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
-        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
-        The to-be-inserted image has a size of 576 (24 * 24) along the context
-        length dimension.
-        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
-        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
-        9047, 13566, 29901].
-        There will be 576 `32000` in the `input_ids`.
-        (32000 is the token id for `<image>`.)
+        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
+        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
+
+        To reserve space in KV cache, we have to insert placeholder tokens
+        before they are inputted to the model, so the input processor prepends
+        additional image tokens (denoted as `32000`), resulting in:
+        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
+        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
+        29901]`.
+
+        We insert 575 tokens so that including the original image token in the
+        input, there are a total of 576 (24 * 24) image tokens, which
+        corresponds to the number of image tokens inputted to the language
+        model, i.e. the number of image tokens outputted by the visual encoder.
 
         This way, the `positions` and `attn_metadata` are consistent with the
         `input_ids`.
@@ -246,6 +255,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each input image.
+
+        See also:
+            :class:`LlavaImageInputs`
         """
         image_input = self._parse_and_validate_image_input(**kwargs)
 
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index ce1e9307a5ecd..7e06f1e95dab1 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -47,7 +47,8 @@ class LlavaNextImagePixelInputs(TypedDict):
     """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
 
-    Note that `num_patches` may be different for each batch.
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
     """
 
     image_sizes: NotRequired[torch.Tensor]
@@ -255,40 +256,20 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self, data: Union[torch.Tensor, List[torch.Tensor]]
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
 
-        def _validate_shape(data: torch.Tensor):
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
 
-            dim = data.dim()
-            height = width = self.config.vision_config.image_size
-            # All 4d image tensors have the same number of patches,
-            # so data is a 5d batch of these tensors
-            if dim == 5:
-                if list(data.shape)[2:] != [
-                        3, self.config.vision_config.image_size,
-                        self.config.vision_config.image_size
-                ]:
-                    raise ValueError(
-                        "Expected pixel value tensor in shape of: (batch size, "
-                        f"patch number, 3, {height}, {width}), got {data.shape}"
-                    )
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[1:])
 
-            # 4d image tensors have different number of patches,
-            # so data is each individual tensor.
-            elif dim == 4:
-                if list(data.shape)[1:] != [
-                        3, self.config.vision_config.image_size,
-                        self.config.vision_config.image_size
-                ]:
-                    raise ValueError(
-                        "Expected pixel value tensor in shape of: (patch "
-                        f"number, 3, {height}, {width}), got {data.shape}")
-            else:
+            if actual_dims != expected_dims:
+                expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    f"Invalid pixel value tensor of shape {data.shape}")
+                    "The expected shape of pixel values in each batch element "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
 
-        if isinstance(data, torch.Tensor):
-            _validate_shape(data)
-        else:
-            [_validate_shape(d) for d in data]
+        for d in data:
+            _validate_shape(d)
 
         return data
 
@@ -464,18 +445,33 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
-        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
+        `"A chat between a curious human and an artificial intelligence
+        assistant. The assistant gives helpful, detailed, and polite answers to
+        the human's questions.
+        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
+
         Tokenizer outputs:
-        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
-        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
-        The to-be-inserted image has a size of 576 (24 * 24) along the context
-        length dimension.
-        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
-        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
-        9047, 13566, 29901].
-        There will be 576 `32000` in the `input_ids`.
-        (32000 is the token id for `<image>`.)
+        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
+        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
+        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
+        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
+        9047, 13566, 29901]`.
+
+        To reserve space in KV cache, we have to insert placeholder tokens
+        before they are inputted to the model, so the input processor prepends
+        additional image tokens (denoted as `32000`), resulting in:
+        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
+        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
+        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
+        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
+        319, 1799, 9047, 13566, 29901]`.
+
+        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
+        model depends on the original size of the input image. Including the
+        original image token in the input, the required number of image tokens
+        is given by :func:`get_llava_next_image_feature_size`.
 
         This way, the `positions` and `attn_metadata` are consistent with the
         `input_ids`.
@@ -484,15 +480,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each grid patch for each input image.
-                Expects a batch with shape `[1, num_patches, 3, h, w]`.
             image_sizes: The original `(height, width)` for each input image.
-                Expects a batch with shape `[1, 2]`.
-
+        See also:
-        Each input maps to huggingface implementation, as follows:
-
-        - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L690
-        - `image_sizes`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L691
+            :class:`LlavaNextImageInputs`
         """
         image_input = self._parse_and_validate_image_input(**kwargs)
 
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index b087e485d9a8c..1c6bd106b53f5 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -263,7 +263,8 @@ class Phi3VImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
 
-    Note that `num_patches` may be different for each batch.
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
     """
 
     image_sizes: torch.Tensor
@@ -466,8 +467,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
         if list(data.shape[1:]) != [2]:
             raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
+                f"The expected shape of image sizes is batch dimension plus "
+                f"{[2]}. You supplied {tuple(data.shape)}.")
 
         return data
 
@@ -475,19 +476,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         self, data: Union[torch.Tensor, List[torch.Tensor]]
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
 
-        def _validate_shape(data: torch.Tensor):
-            if list(data.shape)[2:] != [
-                    3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
-                    CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
-            ]:
-                raise ValueError(
-                    "The expected pixel value tensor shape is batch dimension "
-                    "plus patch number, channel, height and width.")
+        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
+        expected_dims = (3, h, w)
 
-        if isinstance(data, torch.Tensor):
-            _validate_shape(data)
-        else:
-            [_validate_shape(d) for d in data]
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[1:])
+
+            if actual_dims != expected_dims:
+                expected_expr = ("num_patches", *map(str, expected_dims))
+                raise ValueError(
+                    "The expected shape of pixel values in each batch element "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
 
         return data
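For reference beyond the diff itself, the shape check that all three models now share, and the patch-grid arithmetic behind the 576 placeholder tokens discussed in the LLaVA-1.5 docstring, can be sketched as standalone code. This is a minimal illustration, not part of the patch: the helper name `validate_pixel_values` and the hard-coded `IMAGE_SIZE = 336` / `PATCH_SIZE = 14` values (CLIP ViT-L/14-336, the vision tower these models use) are assumptions; in the patched code the size comes from `vision_config` or `CLIP_VIT_LARGE_PATCH14_336_CONFIG`.

```python
# Minimal sketch of the shared validation pattern introduced by this diff.
# Names and constants here are illustrative, not taken from vLLM.
from typing import List, Tuple, Union

import torch

IMAGE_SIZE = 336  # CLIP ViT-L/14-336 input resolution (assumed)
PATCH_SIZE = 14   # ViT patch size (assumed)


def validate_pixel_values(
    data: Union[torch.Tensor, List[torch.Tensor]],
    image_size: int = IMAGE_SIZE,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Check that every batch element has shape (num_patches, 3, H, W)."""
    expected_dims: Tuple[int, ...] = (3, image_size, image_size)

    # Iterating over a 5-D tensor yields its 4-D batch elements, so the same
    # loop covers both the batched-tensor and list-of-tensors cases.
    for d in data:
        actual_dims = tuple(d.shape[1:])
        if actual_dims != expected_dims:
            expected_expr = ("num_patches", *map(str, expected_dims))
            raise ValueError(
                "The expected shape of pixel values in each batch element "
                f"is {expected_expr}. You supplied {tuple(d.shape)}.")

    return data


# The 576 (24 * 24) placeholder tokens in the LLaVA-1.5 docstring come from
# the ViT patch grid: (336 // 14) ** 2 = 576.
assert (IMAGE_SIZE // PATCH_SIZE) ** 2 == 576

validate_pixel_values([torch.zeros(5, 3, 336, 336)])  # one image, 5 patches: OK
```

LLaVA-1.5 applies the same comparison, but against `("batch_size", 3, H, W)` on a single batched tensor, since it has no per-image patch dimension.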