fix test_phi3v (#15321)

Signed-off-by: pansicheng <sicheng.pan.chn@gmail.com>
pansicheng 2025-03-30 17:01:34 +08:00 committed by GitHub
parent 44c3a5abc3
commit 7fd8c0f85c
4 changed files with 110 additions and 14 deletions


@@ -3,6 +3,9 @@
import openai
import pytest
import pytest_asyncio
import requests
from PIL import Image
from transformers import AutoProcessor
from vllm.multimodal.utils import encode_image_base64, fetch_image
@@ -53,11 +56,31 @@ def base64_encoded_image() -> dict[str, str]:
}
def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained(model_name,
trust_remote_code=True,
num_crops=4)
placeholder = "<|image_1|>\n"
messages = [{
"role": "user",
"content": f"{placeholder}{content}",
}]
images = [Image.open(requests.get(image_url, stream=True).raw)]
prompt = processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
inputs = processor(prompt, images, return_tensors="pt")
return inputs.input_ids.shape[1]
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image(client: openai.AsyncOpenAI,
model_name: str, image_url: str):
content_text = "What's in this image?"
messages = [{
"role":
"user",
@@ -70,16 +93,17 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
},
{
"type": "text",
"text": "What's in this image?"
"text": content_text
},
],
}]
max_completion_tokens = 10
# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
max_completion_tokens=max_completion_tokens,
logprobs=True,
temperature=0.0,
top_logprobs=5)
@@ -87,8 +111,12 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
image_url)
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=774, total_tokens=784)
completion_tokens=max_completion_tokens,
prompt_tokens=hf_prompt_tokens,
total_tokens=hf_prompt_tokens + max_completion_tokens)
message = choice.message
message = chat_completion.choices[0].message
@@ -150,6 +178,7 @@ async def test_single_chat_session_image_base64encoded(
client: openai.AsyncOpenAI, model_name: str, image_url: str,
base64_encoded_image: dict[str, str]):
content_text = "What's in this image?"
messages = [{
"role":
"user",
@@ -163,16 +192,17 @@ async def test_single_chat_session_image_base64encoded(
},
{
"type": "text",
"text": "What's in this image?"
"text": content_text
},
],
}]
max_completion_tokens = 10
# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
max_completion_tokens=max_completion_tokens,
logprobs=True,
temperature=0.0,
top_logprobs=5)
@@ -180,8 +210,12 @@ async def test_single_chat_session_image_base64encoded(
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
image_url)
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=774, total_tokens=784)
completion_tokens=max_completion_tokens,
prompt_tokens=hf_prompt_tokens,
total_tokens=hf_prompt_tokens + max_completion_tokens)
message = choice.message
message = chat_completion.choices[0].message
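
The helper and the updated assertions together mean the expected usage numbers now come from the HF processor at test time instead of the previously hardcoded 774/784. A minimal standalone sketch of that computation, assuming a Phi-3.5-vision checkpoint and a placeholder image URL (neither value is spelled out in this hunk):

import requests
from PIL import Image
from transformers import AutoProcessor

model_name = "microsoft/Phi-3.5-vision-instruct"  # assumed model under test
image_url = "https://example.com/image.jpg"  # placeholder, any reachable image
max_completion_tokens = 10

processor = AutoProcessor.from_pretrained(model_name,
                                          trust_remote_code=True,
                                          num_crops=4)
messages = [{"role": "user", "content": "<|image_1|>\nWhat's in this image?"}]
prompt = processor.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)
image = Image.open(requests.get(image_url, stream=True).raw)
inputs = processor(prompt, [image], return_tensors="pt")

hf_prompt_tokens = inputs.input_ids.shape[1]
# the server-reported usage is then expected to satisfy:
#   prompt_tokens == hf_prompt_tokens
#   total_tokens  == hf_prompt_tokens + max_completion_tokens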


@@ -2,6 +2,8 @@
import pytest
import requests
from PIL import Image
from transformers import AutoProcessor
from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.multimodal.utils import encode_image_base64, fetch_image
@@ -52,11 +54,24 @@ def base64_encoded_image() -> dict[str, str]:
}
def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained(model_name,
trust_remote_code=True,
num_crops=4)
placeholder = "<|image_1|> "
prompt = f"{placeholder}{content}"
images = [Image.open(requests.get(image_url, stream=True).raw)]
inputs = processor(prompt, images, return_tensors="pt")
return inputs.input_ids.shape[1]
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
image_url: str):
content_text = "Represent the given image."
messages = [{
"role":
"user",
@@ -69,7 +84,7 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
},
{
"type": "text",
"text": "Represent the given image."
"text": content_text
},
],
}]
@@ -85,9 +100,12 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
response.raise_for_status()
embeddings = EmbeddingResponse.model_validate(response.json())
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
image_url)
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 3072
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 763
assert embeddings.usage.total_tokens == 763
assert embeddings.usage.prompt_tokens == hf_prompt_tokens
assert embeddings.usage.total_tokens == hf_prompt_tokens


@@ -2,6 +2,10 @@
import pytest
import torch.nn.functional as F
from PIL import Image
from vllm.assets.base import get_vllm_public_assets
from vllm.assets.image import VLM_IMAGES_DIR
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
@@ -112,6 +116,15 @@ def test_models_image(
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
# add cases for special_tokens
input_texts_images.append((
"\n<s><|user|>\n <|image_1|>\n\t <s>"
"Represent the given image for classification<|end|>"
"\n<|assistant|>\n",
Image.open(
get_vllm_public_assets(filename="cherry_blossom.jpg",
s3_prefix=VLM_IMAGES_DIR)),
))
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
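
The extra case feeds a prompt that already contains literal special tokens (<s>, <|end|>) and stray whitespace around the image placeholder, which is exactly the kind of input the Phi3VMultiModalProcessor change below has to re-split correctly. A quick, self-contained illustration of how that prompt is carved up by the image-tag pattern used there:

import re

pattern = r"<\|image_\d+\|>"
prompt = ("\n<s><|user|>\n <|image_1|>\n\t <s>"
          "Represent the given image for classification<|end|>"
          "\n<|assistant|>\n")

chunks = re.split(pattern, prompt)
# ['\n<s><|user|>\n ', '\n\t <s>Represent ... <|end|>\n<|assistant|>\n']
image_tags = re.findall(pattern, prompt)
# ['<|image_1|>']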


@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
@@ -428,10 +429,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
tokenizer = self.info.get_tokenizer()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
def get_replacement_phi3v(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
@@ -449,7 +446,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
full=image_tokens,
features=image_tokens,
)
@@ -469,6 +466,40 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
# align to hf behavior when there are images
if len(mm_item_counts):
tokenizer = self.info.get_tokenizer()
# to decode token_ids to the original text, we need to
# 1. remove the first bos token
# 2. remove space after each special token
# introduced by the tokenizer
if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
token_ids = token_ids[1:]
text = tokenizer.decode(token_ids)
for special_tokens in tokenizer.special_tokens_map.values():
if isinstance(special_tokens, str):
text = text.replace(f"{special_tokens} ", special_tokens)
elif isinstance(special_tokens, list):
for special_token in special_tokens:
text = text.replace(f"{special_token} ", special_token)
# perform hf behavior
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
pattern = r"<\|image_\d+\|>"
prompt_chunks = [
tokenizer(chunk).input_ids
for chunk in re.split(pattern, text)
]
image_tags = [
tokenizer(chunk, add_special_tokens=False).input_ids
for chunk in re.findall(pattern, text)
]
if len(prompt_chunks) > len(image_tags):
image_tags.append([])
token_ids = [
e for sublist in zip(prompt_chunks, image_tags)
for ele in sublist for e in ele
]
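# e.g. two prompt chunks and one image tag interleave back together as
# tokenizer(chunk_0) + tokenizer("<|image_1|>") + tokenizer(chunk_1),
# with the empty list appended above padding the zip for the final chunk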
token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids,
mm_prompt_updates=mm_prompt_updates,