[Model] Expose Phi3v num_crops as a mm_processor_kwarg (#8658)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent 3f06bae907
commit 8ff7ced996
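
For a quick sense of the user-facing change: after this commit, num_crops can be overridden per engine instance via mm_processor_kwargs. A minimal sketch, assuming only what the updated examples below show (model id, trust_remote_code, and the recommended num_crops values):

from vllm import LLM

# Override the HF image processor's num_crops for this engine instance.
# 16 is the recommendation for single-frame inputs; 4 for multi-frame.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    mm_processor_kwargs={"num_crops": 16},
)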
@@ -83,10 +83,24 @@ def run_phi3v(question, modality):
     # In this example, we override max_num_seqs to 5 while
     # keeping the original context length of 128k.
+
+    # num_crops is an override kwarg to the multimodal image processor;
+    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
+    # to use 16 for single frame scenarios, and 4 for multi-frame.
+    #
+    # Generally speaking, a larger value for num_crops results in more
+    # tokens per image instance, because it may scale the image more in
+    # the image preprocessing. Some references in the model docs and the
+    # formula for image tokens after the preprocessing
+    # transform can be found below.
+    #
+    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
+    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
     llm = LLM(
         model="microsoft/Phi-3-vision-128k-instruct",
         trust_remote_code=True,
         max_num_seqs=5,
+        mm_processor_kwargs={"num_crops": 16},
     )
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -67,11 +67,24 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:


 def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
+    # num_crops is an override kwarg to the multimodal image processor;
+    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
+    # to use 16 for single frame scenarios, and 4 for multi-frame.
+    #
+    # Generally speaking, a larger value for num_crops results in more
+    # tokens per image instance, because it may scale the image more in
+    # the image preprocessing. Some references in the model docs and the
+    # formula for image tokens after the preprocessing
+    # transform can be found below.
+    #
+    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
+    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
     llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
         max_model_len=4096,
         limit_mm_per_prompt={"image": len(image_urls)},
+        mm_processor_kwargs={"num_crops": 4},
     )
     placeholders = "\n".join(f"<|image_{i}|>"
                              for i, _ in enumerate(image_urls, start=1))
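
To pair the placeholders built above with actual images at request time, a usage sketch (the PIL stand-in images and prompt text are illustrative; the generate call with multi_modal_data follows the pattern used by these examples):

from PIL import Image

from vllm import LLM, SamplingParams

# Stand-ins for real images; in load_phi3v above they come from image_urls.
images = [Image.new("RGB", (800, 600)), Image.new("RGB", (800, 600))]

llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    limit_mm_per_prompt={"image": len(images)},
    mm_processor_kwargs={"num_crops": 4},  # multi-frame recommendation
)

# One <|image_i|> placeholder per image, as built by load_phi3v above.
placeholders = "\n".join(f"<|image_{i}|>" for i in range(1, len(images) + 1))
prompt = f"<|user|>\n{placeholders}\nDescribe the two images.<|end|>\n<|assistant|>\n"

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": images}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)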
@@ -1,16 +1,21 @@
 import os
 import re
-from typing import List, Optional, Tuple, Type
+from typing import Callable, List, Optional, Tuple, Type

 import pytest
-from transformers import AutoTokenizer
+import torch
+from transformers import AutoImageProcessor, AutoTokenizer

+from vllm.inputs import InputContext, LLMInputs
+from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
+from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 from vllm.utils import is_cpu, is_hip

-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
-from ...utils import check_logprobs_close
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                          _ImageAssets)
+from ...utils import build_model_context, check_logprobs_close

 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
@@ -71,7 +76,7 @@ def run_test(

     All the image fixtures for the test are from IMAGE_ASSETS.
     For huggingface runner, we provide the PIL images as input.
-    For vllm runner, we provide MultiModalDataDict objects
+    For vllm runner, we provide MultiModalDataDict objects
     and corresponding MultiModalConfig as input.
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
@@ -230,3 +235,174 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
         mm_limit=2,
         tensor_parallel_size=1,
     )
+
+
+### Fast tests for correctness in processor_kwarg override handling
+
+
+# Wrap lazy imports to avoid initializing CUDA during test collection
+@pytest.fixture()
+def input_processor_for_phi3v():
+    from vllm.model_executor.models.phi3v import input_processor_for_phi3v
+    return input_processor_for_phi3v
+
+
+@pytest.fixture()
+def dummy_data_for_phi3v():
+    from vllm.model_executor.models.phi3v import dummy_data_for_phi3v
+    return dummy_data_for_phi3v
+
+
+@pytest.fixture()
+def get_max_phi3v_image_tokens():
+    from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
+    return get_max_phi3v_image_tokens
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops", [4, 16, None])
+def test_input_mapper_override(model: str, image_assets: _ImageAssets,
+                               num_crops: Optional[int]):
+    """Ensure that the [default] input mapper handles num_crops properly."""
+    # We pass the processor kwargs here since for this model, we fall back to
+    # the default mapper; this will fall back to the HF mapper and forward
+    # mm_processor_kwargs to it.
+    mm_processor_kwargs = {
+        "num_crops": num_crops
+    } if num_crops is not None else {}
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=mm_processor_kwargs,
+    )
+
+    hf_processor = AutoImageProcessor.from_pretrained(model,
+                                                      trust_remote_code=True,
+                                                      **mm_processor_kwargs)
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+
+    image = image_assets[0].pil_image
+    hf_result = hf_processor.preprocess(
+        image,
+        return_tensors="pt",
+    )
+
+    vllm_result = mm_registry.map_input(
+        ctx.model_config,
+        {"image": image},
+    )
+
+    assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"])
+    assert torch.all(
+        hf_result["num_img_tokens"] == vllm_result["num_img_tokens"])
+
+    # For pixel values, the second axis should be the num_crops + 1
+    # for the rescaled original image. The default value in VLLM falls
+    # back to the HF config, which is why we compare to the processor num_crops
+    assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
+    assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops,expected_max_tokens", [
+    (4, 781),
+    (16, 2653),
+])
+def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
+                             num_crops: int, expected_max_tokens: int):
+    """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
+    # NOTE: mm_processor_kwargs on the context in this test is unused, since
+    # this is testing the mapper directly. In practice, the processor kwargs
+    # are wrapped in a closure when calling the max tokens func. We explicitly
+    # do NOT use the mm_processor_kwargs in the model context here to ensure
+    # that the max image tokens implementation is referencing a mix of the
+    # kwargs to the function and the original mm_processor_kwargs in case
+    # values are somehow updated and end up in a bad state.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    actual_max_tokens = get_max_phi3v_image_tokens(
+        InputContext(ctx.model_config),
+        num_crops=num_crops,
+    )
+
+    assert expected_max_tokens == actual_max_tokens
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [
+    (4, 781, 1),
+    (4, 781, 2),
+    (16, 2653, 1),
+    (16, 2653, 2),
+])
+def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
+                             num_crops: int, toks_per_img: int, num_imgs: int):
+    """Ensure dummy_data_for_phi3v handles num_crops properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the dummy data func.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    sequence_data, _, = dummy_data_for_phi3v(
+        ctx=ctx,
+        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
+        mm_counts={"image": num_imgs},
+        num_crops=num_crops,
+    )
+    # Ensure we have the right number of placeholders per num_crops size
+    img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
+    assert img_tok_count == toks_per_img * num_imgs
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [
+    (4, 757, 1),
+    (4, 757, 2),
+    (16, 1921, 1),
+    (16, 1921, 2),
+])
+def test_input_processor_override(input_processor_for_phi3v: Callable,
+                                  image_assets: _ImageAssets, model: str,
+                                  num_crops: int, expected_toks_per_img: int,
+                                  num_imgs: int):
+    """Ensure input_processor_for_phi3v handles num_crops properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the custom input processor.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    # Build the image str / prompt based on the number of images we pass
+    img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
+    prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
+    images = [image_assets[0].pil_image] * num_imgs
+
+    llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt),
+                           prompt=prompt,
+                           multi_modal_data={"image": images})
+
+    proc_llm_inputs = input_processor_for_phi3v(
+        ctx=ctx,
+        llm_inputs=llm_inputs,
+        num_crops=num_crops,
+    )
+
+    # Ensure we have the right number of placeholders per num_crops size
+    img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
+    assert img_tok_count == expected_toks_per_img * num_imgs
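
The test comments above mention that the processor kwargs are "wrapped in a closure" and "expanded by the partial". As an illustration of that mechanism only (a sketch of the idea, not vLLM's actual registry code): the user-supplied mm_processor_kwargs are bound onto the registered callable once, so the per-model function only has to declare a keyword-only override such as num_crops.

from functools import partial
from typing import Optional

def get_max_image_tokens(ctx, *, num_crops: Optional[int] = None) -> int:
    # Placeholder return values lifted from test_max_tokens_override above.
    return 2653 if num_crops == 16 else 781

mm_processor_kwargs = {"num_crops": 16}
# The registry binds the user-supplied kwargs once...
bound = partial(get_max_image_tokens, **mm_processor_kwargs)
# ...and later calls the bound function with only the input context.
assert bound(ctx=None) == 2653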
@@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):


 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
-def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
+def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
     transposed = False
     if width < height:
         width, height = height, width
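
For readers following the num_crops-to-token-count chain: hd_num here is the num_crops value, and it caps how many 336-pixel tiles the long side of the image may be scaled to, which is why a larger num_crops yields more image tokens. A condensed sketch of the transform (a paraphrase of the linked image_processing_phi3_v.py, not the verbatim source; the real code also computes centered padding offsets):

import math

def hd_transform_size(width: int, height: int, hd_num: int):
    transposed = False
    if width < height:                      # work in landscape orientation
        width, height = height, width
        transposed = True
    ratio = width / height
    scale = 1
    while scale * math.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1                              # largest tile count within hd_num
    new_width = int(scale * 336)
    new_height = int(new_width / ratio)
    new_height = math.ceil(new_height / 336) * 336   # pad height to a tile
    if transposed:
        new_width, new_height = new_height, new_width
    return new_width, new_height

# e.g. a tall 50x8000 image:
assert hd_transform_size(50, 8000, 4) == (336, 1344)
assert hd_transform_size(50, 8000, 16) == (336, 5376)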
@@ -337,8 +337,10 @@ def get_phi3v_image_feature_size(
     *,
     input_height: int,
     input_width: int,
+    num_crops: int,
 ) -> int:
-    num_crops = hf_config.get("num_crops", 16)
+    if num_crops is None:
+        num_crops = hf_config.get("num_crops", 16)
     new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                     height=input_height,
                                                     hd_num=num_crops)
@@ -347,20 +349,26 @@ def get_phi3v_image_feature_size(
         + (new_height // 336 + 1) * 12


-def get_max_phi3v_image_tokens(ctx: InputContext):
+def get_max_phi3v_image_tokens(ctx: InputContext,
+                               *,
+                               num_crops: Optional[int] = None):

     return get_phi3v_image_feature_size(
         ctx.get_hf_image_processor_config(),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+        num_crops=num_crops,
     )


-def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
-                         mm_counts: Mapping[str, int]):
+def dummy_data_for_phi3v(ctx: InputContext,
+                         seq_len: int,
+                         mm_counts: Mapping[str, int],
+                         *,
+                         num_crops: Optional[int] = None):
     num_images = mm_counts["image"]

-    image_feature_size = get_max_phi3v_image_tokens(ctx)
+    image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)

     seq_data = dummy_seq_data_for_clip(
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
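
The expected values in test_max_tokens_override above (781 and 2653) follow from the feature-size formula in this function. A worked check, assuming the max feature-size constants correspond to a 50x8000 image (a recollection of MAX_IMAGE_FEATURE_SIZE_WIDTH/HEIGHT, which are not shown in this diff) and with the leading term of the return expression, cut off by the hunk, reconstructed from the linked processing_phi3_v.py:

def phi3v_feature_size(new_width: int, new_height: int) -> int:
    # Token-count formula for an HD-transformed image of the given size.
    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
        + (new_height // 336 + 1) * 12

# HD transform of a 50x8000 image (see the sketch after _calc_hd_transform_size):
#   num_crops=4  -> (336, 1344)      num_crops=16 -> (336, 5376)
assert phi3v_feature_size(336, 1344) == 781    # (4*1+1)*144 + 1 + 5*12
assert phi3v_feature_size(336, 5376) == 2653   # (16*1+1)*144 + 1 + 17*12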
@@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig,
     return image_placeholder_token_ids


-def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
+def input_processor_for_phi3v(ctx: InputContext,
+                              llm_inputs: LLMInputs,
+                              *,
+                              num_crops: Optional[int] = None):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return llm_inputs
@@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         image_feature_size = [
             get_phi3v_image_feature_size(hf_config,
                                          input_width=w,
-                                         input_height=h)
+                                         input_height=h,
+                                         num_crops=num_crops)
         ]
         image_data = [image_data]
     elif is_list_of(image_data, Image.Image):
@@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
             image_feature_size.append(
                 get_phi3v_image_feature_size(hf_config,
                                              input_width=w,
-                                             input_height=h))
+                                             input_height=h,
+                                             num_crops=num_crops))
     elif isinstance(image_data, torch.Tensor):
         num_images, image_feature_size, hidden_size = image_data.shape
     elif is_list_of(image_data, torch.Tensor):