diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index e8a5b6237d4d..fbd763809728 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 from torch import nn
@@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor,
     return inputs_embeds
 
 
+class LlavaImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """Shape: (batch_size, num_channels, height, width)"""
+
+
+class LlavaImageFeatureInputs(TypedDict):
+    type: Literal["image_features"]
+    data: torch.Tensor
+    """Shape: (batch_size, image_feature_size, hidden_size)"""
+
+
+LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
+
+
 class LlavaForConditionalGeneration(VisionLanguageModelBase):
 
     def __init__(self,
@@ -102,6 +117,90 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
                                                 config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
+    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
+        if list(data.shape[1:]) != list(
+                self.vision_language_config.image_input_shape[1:]):
+            raise ValueError(
+                f"The expected image tensor shape is batch dimension plus "
+                f"{self.vision_language_config.image_input_shape[1:]}. "
+                f"You supplied {data.shape}. "
+                f"If you are using vLLM's entrypoint, make sure your "
+                f"supplied image input is consistent with "
+                f"image_input_shape in engine args.")
+
+        return data
+
+    def _parse_and_validate_image_input(
+            self, data: object) -> Optional[LlavaImageInputs]:
+        expected_input_type = self.vision_language_config.image_input_type
+        ImageInputType = VisionLanguageConfig.ImageInputType
+
+        if data is None:
+            return None
+
+        if expected_input_type == ImageInputType.PIXEL_VALUES:
+            if not isinstance(data, torch.Tensor):
+                raise TypeError("Image pixel vector should be a tensor, "
+                                f"but received type: {type(data)}")
+
+            return LlavaImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_image_data(data),
+            )
+        elif expected_input_type == ImageInputType.IMAGE_FEATURES:
+            if not isinstance(data, torch.Tensor):
+                raise TypeError("Image feature vector should be a tensor, "
+                                f"but received type: {type(data)}")
+
+            return LlavaImageFeatureInputs(
+                type="image_features",
+                data=self._validate_image_data(data),
+            )
+
+        return None
+
+    def _select_image_features(self, image_features: torch.Tensor, *,
+                               strategy: str) -> torch.Tensor:
+        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
+        if strategy == "default":
+            return image_features[:, 1:]
+        elif strategy == "full":
+            return image_features
+
+        raise ValueError(f"Unexpected select feature strategy: {strategy}")
+
+    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
+                                  pixel_values: torch.Tensor) -> torch.Tensor:
+        # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
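+        # output_hidden_states=True makes the vision tower return the hidden
+        # states of every layer, so the layer configured by
+        # config.vision_feature_layer can be picked out below.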
+        image_outputs = vision_tower(pixel_values.to(vision_tower.device),
+                                     output_hidden_states=True)
+
+        image_features = image_outputs.hidden_states[
+            self.config.vision_feature_layer]
+
+        return self._select_image_features(
+            image_features,
+            strategy=self.config.vision_feature_select_strategy,
+        )
+
+    def _process_image_pixels(self,
+                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
+        assert self.vision_tower is not None
+
+        pixel_values = inputs["data"]
+
+        return self._image_pixels_to_features(self.vision_tower, pixel_values)
+
+    def _process_image_input(self,
+                             image_input: LlavaImageInputs) -> torch.Tensor:
+        if image_input["type"] == "pixel_values":
+            assert self.vision_tower is not None
+            image_features = self._process_image_pixels(image_input)
+        else:
+            image_features = image_input["data"]
+
+        return self.multi_modal_projector(image_features)
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
@@ -144,42 +243,20 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
                 For PIXEL_VALUES, expecting [1, 3, 336, 336].
                 For IMAGE_FEATURES, expecting [1, 576, 1024].
         """
-        if image_input is not None:
-            if list(image_input.shape[1:]) != list(
-                    self.vision_language_config.image_input_shape[1:]):
-                raise ValueError(
-                    f"The expected image tensor shape is batch dimension "
-                    f"plus "
-                    f"{self.vision_language_config.image_input_shape[1:]}."
-                    f" You supplied {image_input.shape}. "
-                    f"If you are using vLLM's entrypoint, make sure your "
-                    f"supplied image input is consistent with "
-                    f"image_input_shape in engine args.")
-            if self.vision_tower is not None:
-                # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
-                image_outputs = self.vision_tower(image_input,
-                                                  output_hidden_states=True)
-                image_features = image_outputs.hidden_states[
-                    self.config.vision_feature_layer]
-                # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
-                if self.config.vision_feature_select_strategy == "default":
-                    image_features = image_features[:, 1:]
-                elif self.config.vision_feature_select_strategy == "full":
-                    image_features = image_features
-                else:
-                    raise ValueError(
-                        f"Unexpected select feature strategy: "
-                        f"{self.config.vision_feature_select_strategy}")
-            else:
-                image_features = image_input
-            vision_embeddings = self.multi_modal_projector(image_features)
+        parsed_image_input = self._parse_and_validate_image_input(image_input)
+
+        if parsed_image_input is not None:
+            vision_embeddings = self._process_image_input(parsed_image_input)
+
             inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+
             inputs_embeds = _merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.vision_language_config.image_token_id)
+
             input_ids = None
         else:
             inputs_embeds = None
+
         hidden_states = self.language_model(input_ids, positions, kv_caches,