[CI/Build] Fix OOM issue in Jina-VL test (#20907)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-07-14 18:32:35 +08:00 committed by GitHub
parent 1e9438e0b0
commit dcf2a5e208
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,9 +1,15 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
import pytest import pytest
from transformers import AutoModel from transformers import AutoModel
from vllm.entrypoints.chat_utils import ChatCompletionContentPartImageParam
from vllm.entrypoints.score_utils import ScoreMultiModalParam
from ....conftest import HfRunner, VllmRunner
model_name = "jinaai/jina-reranker-m0" model_name = "jinaai/jina-reranker-m0"
mm_processor_kwargs = { mm_processor_kwargs = {
@ -14,73 +20,90 @@ mm_processor_kwargs = {
limit_mm_per_prompt = {"image": 2} limit_mm_per_prompt = {"image": 2}
def vllm_reranker(model_name, def vllm_reranker(
query, vllm_runner: type[VllmRunner],
documents, model_name: str,
query_type="text", dtype: str,
doc_type="text"): query_strs: list[str],
from vllm import LLM document_strs: list[str],
query_type: str = "text",
doc_type: str = "text",
):
model = LLM( def create_image_param(url: str) -> ChatCompletionContentPartImageParam:
model=model_name,
task="score",
max_model_len=32768,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt=limit_mm_per_prompt,
)
def create_image_param(url: str):
return {"type": "image_url", "image_url": {"url": f"{url}"}} return {"type": "image_url", "image_url": {"url": f"{url}"}}
if query_type == "image": query: Union[list[str], ScoreMultiModalParam]
query = {"content": [create_image_param(url) for url in query]} if query_type == "text":
query = query_strs
elif query_type == "image":
query = ScoreMultiModalParam(
content=[create_image_param(url) for url in query_strs])
if doc_type == "image": documents: Union[list[str], ScoreMultiModalParam]
documents = {"content": [create_image_param(url) for url in documents]} if doc_type == "text":
documents = document_strs
elif doc_type == "image":
documents = ScoreMultiModalParam(
content=[create_image_param(url) for url in document_strs])
outputs = model.score(query, documents) with vllm_runner(
model_name,
task="score",
dtype=dtype,
max_num_seqs=2,
max_model_len=2048,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt=limit_mm_per_prompt,
) as vllm_model:
outputs = vllm_model.model.score(query, documents)
return [output.outputs.score for output in outputs] return [output.outputs.score for output in outputs]
def hf_reranker(
    hf_runner: type[HfRunner],
    model_name: str,
    dtype: str,
    query_strs: list[str],
    document_strs: list[str],
    query_type: str = "text",
    doc_type: str = "text",
):
    """Score each document against the query with the HF reference model.

    Args:
        hf_runner: Context-manager fixture wrapping an HF ``AutoModel``.
        model_name: HF checkpoint to load.
        dtype: Model dtype string passed straight through to the runner.
        query_strs: Query text or image URLs; only the first entry is used.
        document_strs: Document texts or image URLs.
        query_type: Either ``"text"`` or ``"image"``.
        doc_type: Either ``"text"`` or ``"image"``.

    Returns:
        The scores produced by the checkpoint's ``compute_score`` method,
        one per document.
    """
    # The checkpoint's weight names differ from the HF module layout;
    # remap them on load.
    checkpoint_to_hf_mapper = {
        "visual.": "model.visual.",
        "model.": "model.language_model.",
    }

    # compute_score expects a list of [query, document] pairs.
    data_pairs = [[query_strs[0], d] for d in document_strs]

    with hf_runner(
            model_name,
            dtype=dtype,
            trust_remote_code=True,
            auto_cls=AutoModel,
            model_kwargs={"key_mapping": checkpoint_to_hf_mapper},
    ) as hf_model:
        return hf_model.model.compute_score(data_pairs,
                                            max_length=2048,
                                            query_type=query_type,
                                            doc_type=doc_type)
# Visual Documents Reranking
@pytest.mark.parametrize("model_name", [model_name])
@pytest.mark.parametrize("dtype", ["half"])
def test_model_text_image(hf_runner, vllm_runner, model_name, dtype):
    """Text query reranking image documents: HF and vLLM scores must agree."""
    query = ["slm markdown"]
    documents = [
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
    ]

    hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
                             "text", "image")
    vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
                                 documents, "text", "image")

    # Backends differ numerically (dtype, kernels); allow 2% relative error.
    assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02)
    assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02)
@ -88,8 +111,8 @@ def test_model_text_image(model_name):
# Textual Documents Reranking # Textual Documents Reranking
@pytest.mark.parametrize("model_name", [model_name]) @pytest.mark.parametrize("model_name", [model_name])
def test_model_text_text(model_name): @pytest.mark.parametrize("dtype", ["half"])
def test_model_text_text(hf_runner, vllm_runner, model_name, dtype):
query = ["slm markdown"] query = ["slm markdown"]
documents = [ documents = [
"""We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient """We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient
@ -104,9 +127,10 @@ def test_model_text_text(model_name):
lower computational requirements.""", # noqa: E501 lower computational requirements.""", # noqa: E501
"数据提取么?为什么不用正则啊,你用正则不就全解决了么?", "数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
] ]
hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
hf_outputs = hf_reranker(model_name, query, documents, "text", "text") "text", "text")
vllm_outputs = vllm_reranker(model_name, query, documents, "text", "text") vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
documents, "text", "text")
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02)
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02)
@ -114,8 +138,8 @@ def test_model_text_text(model_name):
# Image Querying for Textual Documents # Image Querying for Textual Documents
@pytest.mark.parametrize("model_name", [model_name]) @pytest.mark.parametrize("model_name", [model_name])
def test_model_image_text(model_name): @pytest.mark.parametrize("dtype", ["half"])
def test_model_image_text(hf_runner, vllm_runner, model_name, dtype):
query = [ query = [
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
] ]
@ -133,8 +157,10 @@ def test_model_image_text(model_name):
"数据提取么?为什么不用正则啊,你用正则不就全解决了么?", "数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
] ]
hf_outputs = hf_reranker(model_name, query, documents, "image", "text") hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
vllm_outputs = vllm_reranker(model_name, query, documents, "image", "text") "image", "text")
vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
documents, "image", "text")
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02)
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02)
@ -142,8 +168,8 @@ def test_model_image_text(model_name):
# Image Querying for Image Documents # Image Querying for Image Documents
@pytest.mark.parametrize("model_name", [model_name]) @pytest.mark.parametrize("model_name", [model_name])
def test_model_image_image(model_name): @pytest.mark.parametrize("dtype", ["half"])
def test_model_image_image(hf_runner, vllm_runner, model_name, dtype):
query = [ query = [
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
] ]
@ -152,9 +178,10 @@ def test_model_image_image(model_name):
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png", "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
] ]
hf_outputs = hf_reranker(model_name, query, documents, "image", "image") hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
vllm_outputs = vllm_reranker(model_name, query, documents, "image", "image", "image")
"image") vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
documents, "image", "image")
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02)
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02)