mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 16:35:43 +08:00
474 lines
14 KiB
Python
474 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Any, TypedDict
|
|
|
|
import numpy.typing as npt
|
|
import pytest
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from vllm.multimodal.image import rescale_image_size
|
|
from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
|
|
|
|
from ....conftest import (
|
|
IMAGE_ASSETS,
|
|
VIDEO_ASSETS,
|
|
PromptImageInput,
|
|
PromptVideoInput,
|
|
VllmRunner,
|
|
)
|
|
from ...utils import check_logprobs_close
|
|
|
|
|
|
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Permit insecure serialization for each test in this module.

    `LLM.apply_model` requires pickling a function, so the environment
    variable below is set through `monkeypatch` (undone after each test).
    """
    env_var, enabled = "VLLM_ALLOW_INSECURE_SERIALIZATION", "1"
    monkeypatch.setenv(env_var, enabled)
|
|
|
|
|
|
# Checkpoints exercised by every test in this module.
models: list[str] = ["Qwen/Qwen2-VL-2B-Instruct"]
# dtype string passed to vLLM for every run.
target_dtype: str = "half"

# Prompt placeholders marking where image / video features are inserted.
IMAGE_PLACEHOLDER: str = "<|vision_start|><|image_pad|><|vision_end|>"
VIDEO_PLACEHOLDER: str = "<|vision_start|><|video_pad|><|vision_end|>"
# NOTE(review): not referenced in the visible part of this file; presumably
# the language-model hidden size of Qwen2-VL-2B — confirm before relying on it.
MODEL_HIDDEN_SIZE: int = 1536
|
|
|
|
|
|
def qwen2_vl_chat_template(*query):
    """Wrap query fragments in the Qwen2-VL chat prompt format.

    The fragments (media placeholders and text) are concatenated verbatim into
    a single user turn. That turn is preceded by the default system turn and
    followed by the opening of the assistant turn, so the model generates the
    reply. Every `<|im_end|>` is followed by a newline, consistent with the
    system turn (the assistant turn was previously glued to the user turn).
    """
    user_content = "".join(query)
    return (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{user_content}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
|
|
|
|
|
|
# Single-image prompts, keyed by image asset name (presumably the names used
# by IMAGE_ASSETS in conftest — verify there).
IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
    dict(
        stop_sign=qwen2_vl_chat_template(
            IMAGE_PLACEHOLDER,
            "What is the biggest text's content in this image?",
        ),
        cherry_blossom=qwen2_vl_chat_template(
            IMAGE_PLACEHOLDER,
            "What is the season shown in this image? ",
            "Reply with a short sentence (no more than 20 words)",
        ),
    )
)

# Single-video prompts, keyed by video asset name.
VIDEO_PROMPTS = VIDEO_ASSETS.prompts(
    dict(
        baby_reading=qwen2_vl_chat_template(
            VIDEO_PLACEHOLDER,
            "Describe this video with a short sentence ",
            "(no more than 20 words)",
        ),
    )
)

# One prompt carrying two image placeholders, used by the multi-image test.
MULTIIMAGE_PROMPT = qwen2_vl_chat_template(
    IMAGE_PLACEHOLDER,
    IMAGE_PLACEHOLDER,
    "Describe these two images separately. ",
    "For each image, reply with a short sentence ",
    "(no more than 10 words).",
)
|
|
|
|
|
|
class Qwen2VLPromptImageEmbeddingInput(TypedDict, total=True):
    """Precomputed image features for one prompt, passed in place of images."""

    # Vision-tower output rows for every image of the prompt, concatenated.
    image_embeds: torch.Tensor
    # One grid row per image as returned by the image processor
    # (presumably temporal/height/width patch counts).
    image_grid_thw: torch.Tensor
|
|
|
|
|
|
class Qwen2VLPromptVideoEmbeddingInput(TypedDict, total=True):
    """Precomputed video features for one prompt, passed in place of videos."""

    # Vision-tower output rows for every video of the prompt, concatenated.
    video_embeds: torch.Tensor
    # One grid row per video as returned by the image processor
    # (presumably temporal/height/width patch counts).
    video_grid_thw: torch.Tensor
|
|
|
|
|
|
def batch_make_image_embeddings(
    image_batches: list[Image.Image | list[Image.Image]],
    processor,
    llm: VllmRunner,
) -> list[Qwen2VLPromptImageEmbeddingInput]:
    """Batched image embeddings for Qwen2-VL.

    All images of all prompts are preprocessed and encoded in a single batch.
    The resulting embedding rows are then split back according to the input
    batches.

    Args:
        image_batches: One entry per prompt, either
            - a single image (`Image.Image`), or
            - a list of images (`list[Image.Image]`) for multi-image prompts.
        processor: Processor whose `image_processor` is used for
            preprocessing and provides `merge_size`.
        llm: Runner whose `apply_model` runs `model.visual` on the pixels.

    Returns:
        One `Qwen2VLPromptImageEmbeddingInput` per entry of `image_batches`,
        holding that prompt's embedding rows and `image_grid_thw` rows.
    """
    # Normalize single-image entries to one-element lists. This builds a new
    # list, so the caller's `image_batches` is left untouched.
    image_batches_: list[list[Image.Image]] = [
        batch if isinstance(batch, list) else [batch] for batch in image_batches
    ]

    # Flatten every prompt's images into one list so they are encoded together.
    images = [image for batch in image_batches_ for image in batch]

    # image to pixel values
    image_processor = processor.image_processor
    preprocess_result = image_processor.preprocess(
        images=images, return_tensors="pt"
    ).data
    pixel_values = preprocess_result["pixel_values"]
    image_grid_thw = preprocess_result["image_grid_thw"]

    # pixel values to embeddings
    def get_image_embeds(model):
        with torch.no_grad():
            visual = model.visual
            pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
            return visual(pixel_values_on_device, grid_thw=image_grid_thw).cpu()

    image_embeds = torch.concat(llm.apply_model(get_image_embeds))

    # Embedding rows contributed by each image:
    # prod(grid_thw) // merge_size // merge_size.
    # Computed once, as plain Python ints, so the slice bounds below are ints
    # rather than 0-d tensors.
    merge_size = image_processor.merge_size
    embed_lens: list[int] = (
        image_grid_thw.prod(-1) // merge_size // merge_size
    ).tolist()

    # split into original batches
    result: list[Qwen2VLPromptImageEmbeddingInput] = []
    image_counter = 0
    embed_counter = 0
    for image_batch in image_batches_:
        image_end = image_counter + len(image_batch)
        embed_end = embed_counter + sum(embed_lens[image_counter:image_end])

        result.append(
            {
                "image_embeds": image_embeds[embed_counter:embed_end],
                "image_grid_thw": image_grid_thw[image_counter:image_end],
            }
        )

        embed_counter = embed_end
        image_counter = image_end

    # ensure we don't lose any images or embeddings
    assert embed_counter == image_embeds.size(0)
    assert image_counter == image_grid_thw.size(0)
    assert len(image_batches) == len(result)

    return result
|
|
|
|
|
|
def batch_make_video_embeddings(
    video_batches: PromptVideoInput, processor, llm: VllmRunner
) -> list[Qwen2VLPromptVideoEmbeddingInput]:
    """Batched video embeddings for Qwen2-VL.

    Each `NDArray` holds all frames of one video. All videos of all prompts
    are preprocessed and encoded in a single batch. The resulting embedding
    rows are then split back according to the input batches.

    Args:
        video_batches: One entry per prompt, either
            - a single video (`NDArray`), or
            - a list of videos (`list[NDArray]`) for multi-video prompts.
        processor: Processor whose `image_processor` is used for
            preprocessing and provides `merge_size`.
        llm: Runner whose `apply_model` runs `model.visual` on the pixels.

    Returns:
        One `Qwen2VLPromptVideoEmbeddingInput` per entry of `video_batches`,
        holding that prompt's embedding rows and `video_grid_thw` rows.
    """
    # Normalize single-video entries to one-element lists. This builds a new
    # list, so the caller's `video_batches` is left untouched.
    video_batches_: list[list[npt.NDArray]] = [
        batch if isinstance(batch, list) else [batch] for batch in video_batches
    ]

    # Flatten every prompt's videos into one list so they are encoded together.
    videos: list[npt.NDArray] = [
        video for batch in video_batches_ for video in batch
    ]

    # video to pixel values
    image_processor = processor.image_processor
    preprocess_result = image_processor.preprocess(
        images=None, videos=videos, return_tensors="pt"
    ).data
    pixel_values = preprocess_result["pixel_values_videos"]
    video_grid_thw = preprocess_result["video_grid_thw"]

    # pixel values to embeddings
    def get_video_embeds(model):
        with torch.no_grad():
            visual = model.visual
            pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
            return visual(pixel_values_on_device, grid_thw=video_grid_thw).cpu()

    video_embeds = torch.concat(llm.apply_model(get_video_embeds))

    # Embedding rows contributed by each video:
    # prod(grid_thw) // merge_size // merge_size.
    # Computed once, as plain Python ints, so the slice bounds below are ints
    # rather than 0-d tensors.
    merge_size = image_processor.merge_size
    embed_lens: list[int] = (
        video_grid_thw.prod(-1) // merge_size // merge_size
    ).tolist()

    # split into original batches
    result: list[Qwen2VLPromptVideoEmbeddingInput] = []
    video_counter = 0
    embed_counter = 0
    for video_batch in video_batches_:
        video_end = video_counter + len(video_batch)
        embed_end = embed_counter + sum(embed_lens[video_counter:video_end])

        result.append(
            {
                "video_embeds": video_embeds[embed_counter:embed_end],
                "video_grid_thw": video_grid_thw[video_counter:video_end],
            }
        )

        embed_counter = embed_end
        video_counter = video_end

    # ensure we don't lose any videos or embeddings
    assert embed_counter == video_embeds.size(0)
    assert video_counter == video_grid_thw.size(0)
    assert len(video_batches) == len(result)

    return result
|
|
|
|
|
|
def run_embedding_input_test(
    vllm_runner: type[VllmRunner],
    inputs: list[tuple[list[str], PromptImageInput, PromptVideoInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: str | None = None,
):
    """Check that raw and embedding media inputs give matching results.

    Each case in `inputs` is a `(prompts, images, videos)` triple. Greedy
    logprobs are generated twice with the same vLLM instance:
    - once from the raw images/videos, and
    - once from embeddings precomputed by this model's `visual` module.
    The two runs must produce close logprobs.

    Args:
        vllm_runner: Runner class used as a context manager to build the model.
        inputs: Test cases. Empty image/video lists mean "no such media".
        model: Model name or path, also used to load the processor.
        dtype: Model dtype passed to vLLM.
        max_tokens: Number of tokens to generate per prompt.
        num_logprobs: Number of logprobs compared per generated token.
        mm_limit: Per-prompt limit applied to both images and videos.
        tensor_parallel_size: Tensor-parallel degree for the runner.
        distributed_executor_backend: Optional executor backend override.
    """
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(model)

    # max_model_len should be greater than image_feature_size
    with vllm_runner(
        model,
        runner="generate",
        max_model_len=4000,
        max_num_seqs=3,
        dtype=dtype,
        limit_mm_per_prompt={"image": mm_limit, "video": mm_limit},
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
        default_torch_num_threads=1,
        enable_mm_embeds=True,
    ) as vllm_model:
        # Baseline: raw media. Empty media lists are passed as None.
        outputs_per_case_for_original_input = [
            vllm_model.generate_greedy_logprobs(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=images or None,
                videos=videos or None,
            )
            for prompts, images, videos in inputs
        ]

        # Same prompts, with media replaced by embeddings computed by this
        # very model instance.
        outputs_per_case_for_embeddings_input = [
            vllm_model.generate_greedy_logprobs(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=batch_make_image_embeddings(images, processor, vllm_model)
                if images
                else None,
                videos=batch_make_video_embeddings(videos, processor, vllm_model)
                if videos
                else None,
            )
            for prompts, images, videos in inputs
        ]

    # strict=True: both lists are built from `inputs`, so a length mismatch
    # would indicate a bug rather than something to silently truncate.
    for outputs_for_original_input, outputs_for_embeddings_input in zip(
        outputs_per_case_for_original_input,
        outputs_per_case_for_embeddings_input,
        strict=True,
    ):
        check_logprobs_close(
            outputs_0_lst=outputs_for_original_input,
            outputs_1_lst=outputs_for_embeddings_input,
            name_0="original_input",
            name_1="embeddings_input",
        )
|
|
|
|
|
|
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [0.5],
        # Single-scale, batched
        [0.5, 0.5],
        # Multi-scale
        [0.25, 0.5, 0.5],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_image_embeddings_input(
    vllm_runner,
    image_assets,
    model,
    size_factors,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    monkeypatch,
) -> None:
    """Single-image prompts: raw images vs. precomputed image embeddings.

    Each (image asset, prompt) pair becomes one case. Its batch holds one copy
    of the prompt per size factor, paired with the image rescaled by that
    factor.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = []
    for image, prompt in zip(images, IMAGE_PROMPTS):
        batch_prompts = [prompt] * len(size_factors)
        batch_images = [rescale_image_size(image, f) for f in size_factors]
        inputs_per_case.append((batch_prompts, batch_images, []))

    run_embedding_input_test(
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )
|
|
|
|
|
|
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        [],
        # Single-scale
        [0.5],
        # Single-scale, batched
        [0.5, 0.5],
        # Multi-scale
        [0.25, 0.5, 0.5],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_multiple_image_embeddings_input(
    vllm_runner,
    image_assets,
    model,
    size_factors,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Multi-image prompts: raw images vs. precomputed image embeddings.

    There is a single case. Each size factor contributes one prompt that
    carries every image asset rescaled by that factor. The empty `[]` factor
    list exercises an empty batch.
    """
    images = [asset.pil_image for asset in image_assets]

    batch_prompts = [MULTIIMAGE_PROMPT] * len(size_factors)
    batch_images = [
        [rescale_image_size(image, factor) for image in images]
        for factor in size_factors
    ]
    inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [
        (batch_prompts, batch_images, [])
    ]

    run_embedding_input_test(
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )
|
|
|
|
|
|
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [0.5],
        # Single-scale, batched
        [0.5, 0.5],
        # Multi-scale
        [0.25, 0.25, 0.5],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_video_embeddings_input(
    vllm_runner,
    video_assets,
    model,
    size_factors,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Single-video prompts: raw videos vs. precomputed video embeddings.

    Every video asset is subsampled to a fixed number of frames. Each
    (video, prompt) pair then becomes one case, with one copy of the prompt
    per size factor and the video rescaled by that factor.
    """
    num_frames = 4
    sampled_vids = [
        sample_frames_from_video(asset.np_ndarrays, num_frames)
        for asset in video_assets
    ]

    inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = []
    for video, prompt in zip(sampled_vids, VIDEO_PROMPTS):
        batch_prompts = [prompt] * len(size_factors)
        batch_videos = [rescale_video_size(video, f) for f in size_factors]
        inputs_per_case.append((batch_prompts, [], batch_videos))

    run_embedding_input_test(
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )
|