mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-17 06:15:41 +08:00
[Doc] [2/N] Add Fuyu E2E example for multimodal processor (#13331)
This commit is contained in:
parent
54ed913f34
commit
367cb8ce8c
@ -262,6 +262,255 @@ def get_mm_max_tokens_per_item(
|
|||||||
Our [actual code](gh-file:vllm/model_executor/models/llava.py) is more abstracted to support vision encoders other than CLIP.
|
Our [actual code](gh-file:vllm/model_executor/models/llava.py) is more abstracted to support vision encoders other than CLIP.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
::::
|
||||||
|
|
||||||
|
::::{tab-item} Non-consecutive feature tokens: Fuyu
|
||||||
|
:sync: fuyu
|
||||||
|
|
||||||
|
Looking at the code of HF's `FuyuForCausalLM`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322
|
||||||
|
if image_patches is not None and past_key_values is None:
|
||||||
|
patch_embeddings = [
|
||||||
|
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
|
||||||
|
.squeeze(0)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
for patch in image_patches
|
||||||
|
]
|
||||||
|
inputs_embeds = self.gather_continuous_embeddings(
|
||||||
|
word_embeddings=inputs_embeds,
|
||||||
|
continuous_embeddings=patch_embeddings,
|
||||||
|
image_patch_input_indices=image_patches_indices,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The number of placeholder feature tokens for the `i`th item in the batch is `patch_embeddings[i].shape[0]`,
|
||||||
|
which is the same as `image_patches[i].shape[0]`, i.e. `num_total_patches`.
|
||||||
|
|
||||||
|
Unlike LLaVA, Fuyu does not define the number of patches inside the modeling file. Where can we get more information?
|
||||||
|
Considering that the model input comes from the output of `FuyuProcessor`, let's **look at the preprocessing files**.
|
||||||
|
|
||||||
|
The image outputs are obtained by calling `FuyuImageProcessor.preprocess` and then
|
||||||
|
`FuyuImageProcessor.preprocess_with_tokenizer_info` inside `FuyuProcessor`.
|
||||||
|
|
||||||
|
In `FuyuImageProcessor.preprocess`, the images are resized and padded to the target `FuyuImageProcessor.size`,
|
||||||
|
returning the dimensions after resizing (but before padding) as metadata.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544
|
||||||
|
image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
|
||||||
|
batch_images = image_encoding["images"]
|
||||||
|
image_unpadded_heights = image_encoding["image_unpadded_heights"]
|
||||||
|
image_unpadded_widths = image_encoding["image_unpadded_widths"]
|
||||||
|
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L480-L
|
||||||
|
if do_resize:
|
||||||
|
batch_images = [
|
||||||
|
[self.resize(image, size=size, input_data_format=input_data_format) for image in images]
|
||||||
|
for images in batch_images
|
||||||
|
]
|
||||||
|
|
||||||
|
image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
|
||||||
|
image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
|
||||||
|
image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]
|
||||||
|
|
||||||
|
if do_pad:
|
||||||
|
batch_images = [
|
||||||
|
[
|
||||||
|
self.pad_image(
|
||||||
|
image,
|
||||||
|
size=size,
|
||||||
|
mode=padding_mode,
|
||||||
|
constant_values=padding_value,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
for images in batch_images
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
In `FuyuImageProcessor.preprocess_with_tokenizer_info`, the images are split into patches based on this metadata:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425
|
||||||
|
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
|
||||||
|
image_input=tensor_batch_images,
|
||||||
|
image_present=image_present,
|
||||||
|
image_unpadded_h=image_unpadded_heights,
|
||||||
|
image_unpadded_w=image_unpadded_widths,
|
||||||
|
image_placeholder_id=image_placeholder_id,
|
||||||
|
image_newline_id=image_newline_id,
|
||||||
|
variable_sized=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L638-L658
|
||||||
|
image_height, image_width = image.shape[1], image.shape[2]
|
||||||
|
if variable_sized: # variable_sized=True
|
||||||
|
new_h = min(
|
||||||
|
image_height,
|
||||||
|
math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
|
||||||
|
)
|
||||||
|
new_w = min(
|
||||||
|
image_width,
|
||||||
|
math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
|
||||||
|
)
|
||||||
|
image = image[:, :new_h, :new_w]
|
||||||
|
image_height, image_width = new_h, new_w
|
||||||
|
|
||||||
|
num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
|
||||||
|
tensor_of_image_ids = torch.full(
|
||||||
|
[num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
|
||||||
|
)
|
||||||
|
patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
|
||||||
|
assert num_patches == patches.shape[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
The number of patches is in turn defined by `FuyuImageProcessor.get_num_patches`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562
|
||||||
|
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||||
|
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
||||||
|
|
||||||
|
if image_height % patch_height != 0:
|
||||||
|
raise ValueError(f"{image_height=} must be divisible by {patch_height}")
|
||||||
|
if image_width % patch_width != 0:
|
||||||
|
raise ValueError(f"{image_width=} must be divisible by {patch_width}")
|
||||||
|
|
||||||
|
num_patches_per_dim_h = image_height // patch_height
|
||||||
|
num_patches_per_dim_w = image_width // patch_width
|
||||||
|
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
|
||||||
|
```
|
||||||
|
|
||||||
|
We can calculate this in vLLM using this code:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_num_image_patches(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
image_width: int,
|
||||||
|
image_height: int,
|
||||||
|
) -> int:
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
target_width = image_processor.size["width"]
|
||||||
|
target_height = image_processor.size["height"]
|
||||||
|
patch_width = image_processor.patch_size["width"]
|
||||||
|
patch_height = image_processor.patch_size["height"]
|
||||||
|
|
||||||
|
if not (image_width <= target_width and image_height <= target_height):
|
||||||
|
height_scale_factor = target_height / image_height
|
||||||
|
width_scale_factor = target_width / image_width
|
||||||
|
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
|
||||||
|
|
||||||
|
image_height = int(image_height * optimal_scale_factor)
|
||||||
|
image_width = int(image_width * optimal_scale_factor)
|
||||||
|
|
||||||
|
ncols = math.ceil(image_width / patch_width)
|
||||||
|
nrows = math.ceil(image_height / patch_height)
|
||||||
|
return ncols * nrows
|
||||||
|
```
|
||||||
|
|
||||||
|
These image patches correspond to placeholder tokens (`|SPEAKER|`). However, the processor also
|
||||||
|
inserts newline tokens (`|NEWLINE|`) as shown here:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L654-L670
|
||||||
|
tensor_of_image_ids = torch.full(
|
||||||
|
[num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
|
||||||
|
)
|
||||||
|
patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
|
||||||
|
assert num_patches == patches.shape[0]
|
||||||
|
|
||||||
|
if variable_sized:
|
||||||
|
# Now terminate each line with |NEWLINE|.
|
||||||
|
tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width)
|
||||||
|
newline_ids = torch.full(
|
||||||
|
[tensor_of_image_ids.shape[0], 1],
|
||||||
|
image_newline_id,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=image_input.device,
|
||||||
|
)
|
||||||
|
tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1)
|
||||||
|
tensor_of_image_ids = tensor_of_image_ids.reshape(-1)
|
||||||
|
```
|
||||||
|
|
||||||
|
So, the layout of tokens for an image is:
|
||||||
|
|
||||||
|
```
|
||||||
|
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
|
||||||
|
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
|
||||||
|
...
|
||||||
|
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
|
||||||
|
```
|
||||||
|
|
||||||
|
This makes the placeholder tokens non-consecutive in the prompt.
|
||||||
|
Since vLLM requires the feature tokens to be consecutive, **we also treat the newline tokens as feature tokens**.
|
||||||
|
|
||||||
|
So overall, the total number of feature tokens is:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_num_image_tokens(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
image_width: int,
|
||||||
|
image_height: int,
|
||||||
|
) -> int:
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
target_width = image_processor.size["width"]
|
||||||
|
target_height = image_processor.size["height"]
|
||||||
|
patch_width = image_processor.patch_size["width"]
|
||||||
|
patch_height = image_processor.patch_size["height"]
|
||||||
|
|
||||||
|
if not (image_width <= target_width and image_height <= target_height):
|
||||||
|
height_scale_factor = target_height / image_height
|
||||||
|
width_scale_factor = target_width / image_width
|
||||||
|
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
|
||||||
|
|
||||||
|
image_height = int(image_height * optimal_scale_factor)
|
||||||
|
image_width = int(image_width * optimal_scale_factor)
|
||||||
|
|
||||||
|
ncols = math.ceil(image_width / patch_width)
|
||||||
|
nrows = math.ceil(image_height / patch_height)
|
||||||
|
return (ncols + 1) * nrows
|
||||||
|
```
|
||||||
|
|
||||||
|
To calculate the maximum number of image tokens, recall that input images are first resized
|
||||||
|
to fit within `image_processor.size`. The maximum possible dimensions of the image before
|
||||||
|
being converted into patches are therefore equal to `image_processor.size`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_image_size_with_most_features(self) -> ImageSize:
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
return ImageSize(width=image_processor.size["width"],
|
||||||
|
height=image_processor.size["height"])
|
||||||
|
|
||||||
|
def get_max_image_tokens(self) -> int:
|
||||||
|
target_width, target_height = self.get_image_size_with_most_features()
|
||||||
|
|
||||||
|
return self.get_num_image_tokens(
|
||||||
|
image_width=target_width,
|
||||||
|
image_height=target_height,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
And thus, we can override the method as:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> Mapping[str, int]:
|
||||||
|
return {"image": self.get_max_image_tokens()}
|
||||||
|
```
|
||||||
|
|
||||||
|
:::{note}
|
||||||
|
Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) returns `ncols` and `nrows` directly instead of the total token count.
|
||||||
|
This is because `ncols` and `nrows` are used to specify the layout of the feature tokens (as shown in Step 4 of this guide).
|
||||||
|
:::
|
||||||
|
|
||||||
::::
|
::::
|
||||||
:::::
|
:::::
|
||||||
|
|
||||||
@ -282,7 +531,8 @@ on the code for {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max
|
|||||||
::::{tab-set}
|
::::{tab-set}
|
||||||
:::{tab-item} Basic example: LLaVA
|
:::{tab-item} Basic example: LLaVA
|
||||||
:sync: llava
|
:sync: llava
|
||||||
Making use of the `get_image_size_with_most_features` method implemented in the previous section:
|
|
||||||
|
Making use of the `get_image_size_with_most_features` method implemented in Step 2:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def get_dummy_processor_inputs(
|
def get_dummy_processor_inputs(
|
||||||
@ -312,6 +562,39 @@ def get_dummy_processor_inputs(
|
|||||||
```
|
```
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
:::{tab-item} No input placeholders: Fuyu
|
||||||
|
:sync: fuyu
|
||||||
|
|
||||||
|
Fuyu does not expect image placeholders in the inputs to the HF processor, so
|
||||||
|
the dummy prompt text is empty regardless of the number of images.
|
||||||
|
Otherwise, the logic of this method is very similar to LLaVA:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_dummy_processor_inputs(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> ProcessorInputs:
|
||||||
|
target_width, target_height = \
|
||||||
|
self.info.get_image_size_with_most_features()
|
||||||
|
num_images = mm_counts.get("image", 0)
|
||||||
|
|
||||||
|
mm_data = {
|
||||||
|
"image":
|
||||||
|
self._get_dummy_images(width=target_width,
|
||||||
|
height=target_height,
|
||||||
|
num_images=num_images)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ProcessorInputs(
|
||||||
|
prompt_text="",
|
||||||
|
mm_data=mm_data,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
::::
|
::::
|
||||||
|
|
||||||
## 4. Specify processing details
|
## 4. Specify processing details
|
||||||
@ -325,40 +608,28 @@ to fill in the missing details about HF processing.
|
|||||||
|
|
||||||
### Multi-modal fields
|
### Multi-modal fields
|
||||||
|
|
||||||
Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` to
|
Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` to
|
||||||
return a schema of the tensors outputted by the HF processor that are related to the input multi-modal items.
|
return a schema of the tensors outputted by the HF processor that are related to the input multi-modal items.
|
||||||
|
|
||||||
:::::{tab-set}
|
:::::{tab-set}
|
||||||
::::{tab-item} Basic example: LLaVA
|
::::{tab-item} Basic example: LLaVA
|
||||||
:sync: llava
|
:sync: llava
|
||||||
|
|
||||||
Looking at the model's `forward` method:
|
The output of `CLIPImageProcessor` is a simple tensor with shape
|
||||||
|
`(num_images, num_channels, image_height, image_width)`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L387-L404
|
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/image_processing_clip.py#L339-L345
|
||||||
def forward(
|
images = [
|
||||||
self,
|
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||||
input_ids: torch.LongTensor = None,
|
for image in all_images
|
||||||
pixel_values: torch.FloatTensor = None,
|
]
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
data = {"pixel_values": images}
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
vision_feature_layer: Optional[int] = None,
|
|
||||||
vision_feature_select_strategy: Optional[str] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
num_logits_to_keep: int = 0,
|
|
||||||
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The only related keyword argument is `pixel_values` which directly corresponds to input images.
|
So, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` as follows:
|
||||||
The shape of `pixel_values` is `(N, C, H, W)` where `N` is the number of images.
|
|
||||||
So, we override the method as follows:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def _get_mm_fields_config(
|
def _get_mm_fields_config(
|
||||||
@ -377,11 +648,83 @@ pre-computed image embeddings, which can be passed to be model via the `image_em
|
|||||||
:::
|
:::
|
||||||
|
|
||||||
::::
|
::::
|
||||||
|
|
||||||
|
::::{tab-item} With postprocessing: Fuyu
|
||||||
|
:sync: fuyu
|
||||||
|
|
||||||
|
The `image_patches` output of `FuyuImageProcessor.preprocess_with_tokenizer_info` concatenates
|
||||||
|
the patches from each image belonging to an item in the batch:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L673-L679
|
||||||
|
image_input_ids.append(tensor_of_image_ids)
|
||||||
|
image_patches.append(patches)
|
||||||
|
else:
|
||||||
|
image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device))
|
||||||
|
|
||||||
|
batch_image_input_ids.append(image_input_ids)
|
||||||
|
batch_image_patches.append(image_patches)
|
||||||
|
```
|
||||||
|
|
||||||
|
The shape of `image_patches` outputted by `FuyuImageProcessor` is therefore
|
||||||
|
`(1, num_images, num_patches, patch_width * patch_height * num_channels)`.
|
||||||
|
|
||||||
|
In order to support the use of {func}`MultiModalFieldConfig.batched` like in LLaVA,
|
||||||
|
we remove the extra batch dimension by overriding {meth}`BaseMultiModalProcessor._call_hf_processor`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _call_hf_processor(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
mm_data: Mapping[str, object],
|
||||||
|
mm_kwargs: Mapping[str, object],
|
||||||
|
) -> BatchFeature:
|
||||||
|
processed_outputs = super()._call_hf_processor(
|
||||||
|
prompt=prompt,
|
||||||
|
mm_data=mm_data,
|
||||||
|
mm_kwargs=mm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_patches = processed_outputs.get("image_patches")
|
||||||
|
if image_patches is not None:
|
||||||
|
images = mm_data["images"]
|
||||||
|
assert isinstance(images, list)
|
||||||
|
|
||||||
|
# Original output: (1, num_images, Pn, Px * Py * C)
|
||||||
|
# New output: (num_images, Pn, Px * Py * C)
|
||||||
|
assert (isinstance(image_patches, list)
|
||||||
|
and len(image_patches) == 1)
|
||||||
|
assert (isinstance(image_patches[0], torch.Tensor)
|
||||||
|
and len(image_patches[0]) == len(images))
|
||||||
|
|
||||||
|
processed_outputs["image_patches"] = image_patches[0]
|
||||||
|
|
||||||
|
return processed_outputs
|
||||||
|
```
|
||||||
|
|
||||||
|
:::{note}
|
||||||
|
Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling
|
||||||
|
for text-only inputs to prevent unnecessary warnings from the HF processor.
|
||||||
|
:::
|
||||||
|
|
||||||
|
This lets us override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _get_mm_fields_config(
|
||||||
|
self,
|
||||||
|
hf_inputs: BatchFeature,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
|
return dict(image_patches=MultiModalFieldConfig.batched("image"))
|
||||||
|
```
|
||||||
|
|
||||||
|
::::
|
||||||
|
|
||||||
:::::
|
:::::
|
||||||
|
|
||||||
### Prompt replacements
|
### Prompt replacements
|
||||||
|
|
||||||
Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to
|
Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to
|
||||||
return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances.
|
return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances.
|
||||||
|
|
||||||
Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace
|
Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace
|
||||||
@ -402,7 +745,7 @@ for sample in text:
|
|||||||
```
|
```
|
||||||
|
|
||||||
It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
|
It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
|
||||||
Based on this, we override the method as follows:
|
Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` as follows:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def _get_prompt_replacements(
|
def _get_prompt_replacements(
|
||||||
@ -435,6 +778,159 @@ def _get_prompt_replacements(
|
|||||||
```
|
```
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
:::{tab-item} Handling additional tokens: Fuyu
|
||||||
|
:sync: fuyu
|
||||||
|
|
||||||
|
Recall the layout of feature tokens from Step 2:
|
||||||
|
|
||||||
|
```
|
||||||
|
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
|
||||||
|
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
|
||||||
|
...
|
||||||
|
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
|
||||||
|
```
|
||||||
|
|
||||||
|
We define a helper function to return `ncols` and `nrows` directly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_image_feature_grid_size(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
image_width: int,
|
||||||
|
image_height: int,
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
target_width = image_processor.size["width"]
|
||||||
|
target_height = image_processor.size["height"]
|
||||||
|
patch_width = image_processor.patch_size["width"]
|
||||||
|
patch_height = image_processor.patch_size["height"]
|
||||||
|
|
||||||
|
if not (image_width <= target_width and image_height <= target_height):
|
||||||
|
height_scale_factor = target_height / image_height
|
||||||
|
width_scale_factor = target_width / image_width
|
||||||
|
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
|
||||||
|
|
||||||
|
image_height = int(image_height * optimal_scale_factor)
|
||||||
|
image_width = int(image_width * optimal_scale_factor)
|
||||||
|
|
||||||
|
ncols = math.ceil(image_width / patch_width)
|
||||||
|
nrows = math.ceil(image_height / patch_height)
|
||||||
|
return ncols, nrows
|
||||||
|
```
|
||||||
|
|
||||||
|
Based on this, we can initially define our replacement tokens as:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_replacement(item_idx: int):
|
||||||
|
images = mm_items.get_items("image", ImageProcessorItems)
|
||||||
|
image_size = images.get_image_size(item_idx)
|
||||||
|
|
||||||
|
ncols, nrows = self.info.get_image_feature_grid_size(
|
||||||
|
image_width=image_size.width,
|
||||||
|
image_height=image_size.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
# `_IMAGE_TOKEN_ID` corresponds to `|SPEAKER|`
|
||||||
|
# `_NEWLINE_TOKEN_ID` corresponds to `|NEWLINE|`
|
||||||
|
return ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows
|
||||||
|
```
|
||||||
|
|
||||||
|
However, this is not entirely correct. After `FuyuImageProcessor.preprocess_with_tokenizer_info` is called,
|
||||||
|
a BOS token (`<s>`) is also added to the prompt:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435
|
||||||
|
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
|
||||||
|
image_input=tensor_batch_images,
|
||||||
|
image_present=image_present,
|
||||||
|
image_unpadded_h=image_unpadded_heights,
|
||||||
|
image_unpadded_w=image_unpadded_widths,
|
||||||
|
image_placeholder_id=image_placeholder_id,
|
||||||
|
image_newline_id=image_newline_id,
|
||||||
|
variable_sized=True,
|
||||||
|
)
|
||||||
|
prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
prompts=prompts,
|
||||||
|
scale_factors=scale_factors,
|
||||||
|
max_tokens_to_generate=self.max_tokens_to_generate,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
add_BOS=True,
|
||||||
|
add_beginning_of_answer_token=True,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
To accommodate this, instead of a string you can return an instance of `PromptReplacementDetails`
|
||||||
|
with different `full` and `features` attributes:
|
||||||
|
|
||||||
|
```python
|
||||||
|
hf_config = self.info.get_hf_config()
|
||||||
|
bos_token_id = hf_config.bos_token_id # `<s>`
|
||||||
|
assert isinstance(bos_token_id, int)
|
||||||
|
|
||||||
|
def get_replacement_fuyu(item_idx: int):
|
||||||
|
images = mm_items.get_items("image", ImageProcessorItems)
|
||||||
|
image_size = images.get_image_size(item_idx)
|
||||||
|
|
||||||
|
ncols, nrows = self.info.get_image_feature_grid_size(
|
||||||
|
image_width=image_size.width,
|
||||||
|
image_height=image_size.height,
|
||||||
|
)
|
||||||
|
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||||
|
[_NEWLINE_TOKEN_ID]) * nrows
|
||||||
|
|
||||||
|
return PromptReplacementDetails(
|
||||||
|
full=image_tokens + [bos_token_id],
|
||||||
|
features=image_tokens,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt,
|
||||||
|
we can search for it to conduct the replacement at the start of the string:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _get_prompt_replacements(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
out_mm_kwargs: MultiModalKwargs,
|
||||||
|
) -> list[PromptReplacement]:
|
||||||
|
hf_config = self.info.get_hf_config()
|
||||||
|
bos_token_id = hf_config.bos_token_id
|
||||||
|
assert isinstance(bos_token_id, int)
|
||||||
|
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
eot_token_id = tokenizer.bos_token_id
|
||||||
|
assert isinstance(eot_token_id, int)
|
||||||
|
|
||||||
|
def get_replacement_fuyu(item_idx: int):
|
||||||
|
images = mm_items.get_items("image", ImageProcessorItems)
|
||||||
|
image_size = images.get_image_size(item_idx)
|
||||||
|
|
||||||
|
ncols, nrows = self.info.get_image_feature_grid_size(
|
||||||
|
image_width=image_size.width,
|
||||||
|
image_height=image_size.height,
|
||||||
|
)
|
||||||
|
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||||
|
[_NEWLINE_TOKEN_ID]) * nrows
|
||||||
|
|
||||||
|
return PromptReplacementDetails(
|
||||||
|
full=image_tokens + [bos_token_id],
|
||||||
|
features=image_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
PromptReplacement(
|
||||||
|
modality="image",
|
||||||
|
target=[eot_token_id],
|
||||||
|
replacement=get_replacement_fuyu,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
::::
|
::::
|
||||||
|
|
||||||
## 5. Register processor-related classes
|
## 5. Register processor-related classes
|
||||||
|
|||||||
@ -104,6 +104,8 @@ class FuyuProcessingInfo(BaseProcessingInfo):
|
|||||||
image_processor = self.get_image_processor()
|
image_processor = self.get_image_processor()
|
||||||
target_width = image_processor.size["width"]
|
target_width = image_processor.size["width"]
|
||||||
target_height = image_processor.size["height"]
|
target_height = image_processor.size["height"]
|
||||||
|
patch_width = image_processor.patch_size["width"]
|
||||||
|
patch_height = image_processor.patch_size["height"]
|
||||||
|
|
||||||
if not (image_width <= target_width and image_height <= target_height):
|
if not (image_width <= target_width and image_height <= target_height):
|
||||||
height_scale_factor = target_height / image_height
|
height_scale_factor = target_height / image_height
|
||||||
@ -113,8 +115,8 @@ class FuyuProcessingInfo(BaseProcessingInfo):
|
|||||||
image_height = int(image_height * optimal_scale_factor)
|
image_height = int(image_height * optimal_scale_factor)
|
||||||
image_width = int(image_width * optimal_scale_factor)
|
image_width = int(image_width * optimal_scale_factor)
|
||||||
|
|
||||||
ncols = math.ceil(image_width / 30)
|
ncols = math.ceil(image_width / patch_width)
|
||||||
nrows = math.ceil(image_height / 30)
|
nrows = math.ceil(image_height / patch_height)
|
||||||
return ncols, nrows
|
return ncols, nrows
|
||||||
|
|
||||||
def get_image_size_with_most_features(self) -> ImageSize:
|
def get_image_size_with_most_features(self) -> ImageSize:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user