Mirror of https://git.datalinker.icu/vllm-project/vllm.git
fix test_phi3v (#15321)

Signed-off-by: pansicheng <sicheng.pan.chn@gmail.com>

Parent: 44c3a5abc3
Commit: 7fd8c0f85c
@@ -3,6 +3,9 @@
 import openai
 import pytest
 import pytest_asyncio
+import requests
+from PIL import Image
+from transformers import AutoProcessor

 from vllm.multimodal.utils import encode_image_base64, fetch_image

@@ -53,11 +56,31 @@ def base64_encoded_image() -> dict[str, str]:
     }


+def get_hf_prompt_tokens(model_name, content, image_url):
+    processor = AutoProcessor.from_pretrained(model_name,
+                                               trust_remote_code=True,
+                                               num_crops=4)
+
+    placeholder = "<|image_1|>\n"
+    messages = [{
+        "role": "user",
+        "content": f"{placeholder}{content}",
+    }]
+    images = [Image.open(requests.get(image_url, stream=True).raw)]
+
+    prompt = processor.tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(prompt, images, return_tensors="pt")
+
+    return inputs.input_ids.shape[1]
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
 async def test_single_chat_session_image(client: openai.AsyncOpenAI,
                                          model_name: str, image_url: str):
+    content_text = "What's in this image?"
     messages = [{
         "role":
         "user",
@@ -70,16 +93,17 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
             },
             {
                 "type": "text",
-                "text": "What's in this image?"
+                "text": content_text
             },
         ],
     }]

+    max_completion_tokens = 10
     # test single completion
     chat_completion = await client.chat.completions.create(
         model=model_name,
         messages=messages,
-        max_completion_tokens=10,
+        max_completion_tokens=max_completion_tokens,
         logprobs=True,
         temperature=0.0,
         top_logprobs=5)
@@ -87,8 +111,12 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,

     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
+    hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
+                                            image_url)
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=774, total_tokens=784)
+        completion_tokens=max_completion_tokens,
+        prompt_tokens=hf_prompt_tokens,
+        total_tokens=hf_prompt_tokens + max_completion_tokens)

     message = choice.message
     message = chat_completion.choices[0].message
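
Net effect on this test: instead of pinning prompt_tokens to 774 and total_tokens to 784, the expected counts are recomputed from the HF processor at test time, so the assertion tracks the processor's image-placeholder expansion. A minimal sketch of the resulting check, reusing the names from the test above (illustrative only, not part of the diff):

    expected_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
                                                  image_url)
    assert chat_completion.usage.prompt_tokens == expected_prompt_tokens
    assert chat_completion.usage.total_tokens == (expected_prompt_tokens +
                                                  max_completion_tokens)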
@@ -150,6 +178,7 @@ async def test_single_chat_session_image_base64encoded(
         client: openai.AsyncOpenAI, model_name: str, image_url: str,
         base64_encoded_image: dict[str, str]):

+    content_text = "What's in this image?"
     messages = [{
         "role":
         "user",
@@ -163,16 +192,17 @@ async def test_single_chat_session_image_base64encoded(
             },
             {
                 "type": "text",
-                "text": "What's in this image?"
+                "text": content_text
             },
         ],
     }]

+    max_completion_tokens = 10
     # test single completion
     chat_completion = await client.chat.completions.create(
         model=model_name,
         messages=messages,
-        max_completion_tokens=10,
+        max_completion_tokens=max_completion_tokens,
         logprobs=True,
         temperature=0.0,
         top_logprobs=5)
@@ -180,8 +210,12 @@ async def test_single_chat_session_image_base64encoded(

     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
+    hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
+                                            image_url)
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=774, total_tokens=784)
+        completion_tokens=max_completion_tokens,
+        prompt_tokens=hf_prompt_tokens,
+        total_tokens=hf_prompt_tokens + max_completion_tokens)

     message = choice.message
     message = chat_completion.choices[0].message
@@ -2,6 +2,8 @@

 import pytest
 import requests
+from PIL import Image
+from transformers import AutoProcessor

 from vllm.entrypoints.openai.protocol import EmbeddingResponse
 from vllm.multimodal.utils import encode_image_base64, fetch_image
@@ -52,11 +54,24 @@ def base64_encoded_image() -> dict[str, str]:
     }


+def get_hf_prompt_tokens(model_name, content, image_url):
+    processor = AutoProcessor.from_pretrained(model_name,
+                                               trust_remote_code=True,
+                                               num_crops=4)
+
+    placeholder = "<|image_1|> "
+    prompt = f"{placeholder}{content}"
+    images = [Image.open(requests.get(image_url, stream=True).raw)]
+    inputs = processor(prompt, images, return_tensors="pt")
+    return inputs.input_ids.shape[1]
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
 async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
                                image_url: str):
+    content_text = "Represent the given image."
     messages = [{
         "role":
         "user",
@@ -69,7 +84,7 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
             },
             {
                 "type": "text",
-                "text": "Represent the given image."
+                "text": content_text
             },
         ],
     }]
@@ -85,9 +100,12 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
     response.raise_for_status()
     embeddings = EmbeddingResponse.model_validate(response.json())

+    hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
+                                            image_url)
+
     assert embeddings.id is not None
     assert len(embeddings.data) == 1
     assert len(embeddings.data[0].embedding) == 3072
     assert embeddings.usage.completion_tokens == 0
-    assert embeddings.usage.prompt_tokens == 763
-    assert embeddings.usage.total_tokens == 763
+    assert embeddings.usage.prompt_tokens == hf_prompt_tokens
+    assert embeddings.usage.total_tokens == hf_prompt_tokens
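
The same pattern is applied to the embedding test, with two differences visible in the hunks above: the helper builds the prompt directly ("<|image_1|> " followed by the text, with no chat template), and total_tokens equals prompt_tokens because no completion tokens are generated for an embedding request. A minimal sketch of the resulting check (names from the test above; illustrative only):

    expected_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
                                                  image_url)
    assert embeddings.usage.prompt_tokens == expected_prompt_tokens
    assert embeddings.usage.total_tokens == expected_prompt_tokens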
@@ -2,6 +2,10 @@

 import pytest
 import torch.nn.functional as F
+from PIL import Image
+
+from vllm.assets.base import get_vllm_public_assets
+from vllm.assets.image import VLM_IMAGES_DIR

 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
 from ....utils import large_gpu_test
@@ -112,6 +116,15 @@ def test_models_image(
         (text, asset.pil_image)
         for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
     ]
+    # add cases for special_tokens
+    input_texts_images.append((
+        "\n<s><|user|>\n <|image_1|>\n\t <s>"
+        "Represent the given image for classification<|end|>"
+        "\n<|assistant|>\n",
+        Image.open(
+            get_vllm_public_assets(filename="cherry_blossom.jpg",
+                                   s3_prefix=VLM_IMAGES_DIR)),
+    ))
     input_texts = [text for text, _ in input_texts_images]
     input_images = [image for _, image in input_texts_images]

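
The appended case deliberately wraps the <|image_1|> placeholder in raw special tokens, spaces, and a tab, which exercises the special-token handling added on the processor side. A standalone sketch of how such placeholders are located, using the same regex the Phi3VMultiModalProcessor hunks below rely on (plain Python, no vLLM imports; illustrative only):

    import re

    pattern = r"<\|image_\d+\|>"
    prompt = ("\n<s><|user|>\n <|image_1|>\n\t <s>"
              "Represent the given image for classification<|end|>"
              "\n<|assistant|>\n")

    # text chunks around the placeholder, and the placeholder itself
    chunks = re.split(pattern, prompt)   # ['\n<s><|user|>\n ', '\n\t <s>Represent ...']
    tags = re.findall(pattern, prompt)   # ['<|image_1|>']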
@@ -14,6 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
 from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
@@ -428,10 +429,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

-        tokenizer = self.info.get_tokenizer()
-        bos_token_id = tokenizer.bos_token_id
-        assert isinstance(bos_token_id, int)
-
         def get_replacement_phi3v(item_idx: int):
             images = mm_items.get_items(
                 "image", (ImageEmbeddingItems, ImageProcessorItems))
@@ -449,7 +446,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
                 image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens

             return PromptUpdateDetails(
-                full=image_tokens + [bos_token_id],
+                full=image_tokens,
                 features=image_tokens,
             )

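
Taken together, these two hunks stop an extra BOS id from being appended after each expanded image: the replacement for an image placeholder is now just its feature tokens. A tiny before/after sketch of the payload (the values below are hypothetical stand-ins, not taken from the model config):

    _IMAGE_TOKEN_ID = 32044    # hypothetical image token id for illustration
    num_image_tokens = 5       # hypothetical expansion length
    bos_token_id = 1           # hypothetical BOS id

    full_before = [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]  # old behavior
    full_after = [_IMAGE_TOKEN_ID] * num_image_tokens                    # new behavior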
@@ -469,6 +466,40 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
         mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
         mm_item_counts: Mapping[str, int],
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
+        # align to hf behavior when there are images
+        if len(mm_item_counts):
+            tokenizer = self.info.get_tokenizer()
+            # to decode token_ids to the original text, we need to
+            # 1. remove the first bos token
+            # 2. remove space after each special token
+            #    introduced by the tokenizer
+            if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
+                token_ids = token_ids[1:]
+            text = tokenizer.decode(token_ids)
+            for special_tokens in tokenizer.special_tokens_map.values():
+                if isinstance(special_tokens, str):
+                    text = text.replace(f"{special_tokens} ", special_tokens)
+                elif isinstance(special_tokens, list):
+                    for special_token in special_tokens:
+                        text = text.replace(f"{special_token} ", special_token)
+            # perform hf behavior
+            # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
+            pattern = r"<\|image_\d+\|>"
+            prompt_chunks = [
+                tokenizer(chunk).input_ids
+                for chunk in re.split(pattern, text)
+            ]
+            image_tags = [
+                tokenizer(chunk, add_special_tokens=False).input_ids
+                for chunk in re.findall(pattern, text)
+            ]
+            if len(prompt_chunks) > len(image_tags):
+                image_tags.append([])
+            token_ids = [
+                e for sublist in zip(prompt_chunks, image_tags)
+                for ele in sublist for e in ele
+            ]
+
         token_ids, text, placeholders = super()._apply_prompt_updates(
             token_ids=token_ids,
             mm_prompt_updates=mm_prompt_updates,
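
The override mirrors what the HF Phi-3-vision processor does before prompt updates are applied: decode the ids, strip the leading BOS and the space the tokenizer inserts after special tokens, split the text on the image-tag pattern, re-encode each piece, and interleave text chunks with image tags. A self-contained sketch of that interleaving step, using word lists in place of tokenizer ids so it runs without downloading a model (illustrative only):

    import re

    pattern = r"<\|image_\d+\|>"
    text = "<|user|>\n<|image_1|>\nWhat's in this image?<|end|>\n<|assistant|>\n"

    # stand-ins for tokenizer(chunk).input_ids: encode each chunk as a word list
    prompt_chunks = [chunk.split() for chunk in re.split(pattern, text)]
    image_tags = [[tag] for tag in re.findall(pattern, text)]

    # pad so zip() does not drop the trailing text chunk
    if len(prompt_chunks) > len(image_tags):
        image_tags.append([])

    # interleave chunk, tag, chunk, tag, ... and flatten into one sequence
    interleaved = [
        e for pair in zip(prompt_chunks, image_tags)
        for part in pair for e in part
    ]
    # ['<|user|>', '<|image_1|>', "What's", 'in', 'this', 'image?<|end|>', '<|assistant|>']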