mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 03:55:02 +08:00
[Model][VLM] Add multi-video support for LLaVA-Onevision (#8905)
Co-authored-by: litianjian <litianjian@bytedance.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
8b0e4f2ad7
commit
5f8d8075f9
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Tuple, Type, overload
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
|
||||
@ -9,9 +9,8 @@ from vllm.multimodal.utils import (rescale_image_size, rescale_video_size,
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
||||
_VideoAssets)
|
||||
from ....utils import large_gpu_test
|
||||
from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput,
|
||||
PromptVideoInput, VllmRunner)
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
# Video test
|
||||
@ -20,7 +19,7 @@ HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
|
||||
"<|im_start|>user\n<video>\nwhy is this video funny?<|im_end|>\n<|im_start|>assistant\n" # noqa: E501
|
||||
})
|
||||
|
||||
models = ["llava-hf/llava-onevision-qwen2-7b-ov-hf"]
|
||||
models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
@ -47,50 +46,16 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
return hf_output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
@overload
|
||||
def run_video_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
video_assets: _VideoAssets,
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
num_frames: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def run_video_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
video_assets: _VideoAssets,
|
||||
model: str,
|
||||
*,
|
||||
sizes: List[Tuple[int, int]],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
num_frames: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
...
|
||||
# Video test
|
||||
_LIMIT_VIDEO_PER_PROMPT = 4
|
||||
|
||||
|
||||
def run_video_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
video_assets: _VideoAssets,
|
||||
inputs: List[Tuple[List[str], PromptVideoInput]],
|
||||
model: str,
|
||||
*,
|
||||
size_factors: Optional[List[float]] = None,
|
||||
sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
@ -99,38 +64,20 @@ def run_video_test(
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
|
||||
videos = [
|
||||
sample_frames_from_video(asset.np_ndarrays, num_frames)
|
||||
for asset in video_assets
|
||||
]
|
||||
|
||||
if size_factors is not None:
|
||||
inputs_per_video = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_video_size(video, factor) for factor in size_factors],
|
||||
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
|
||||
elif sizes is not None:
|
||||
inputs_per_video = [(
|
||||
[prompt for _ in sizes],
|
||||
[resize_video(video, size) for size in sizes],
|
||||
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
|
||||
else:
|
||||
raise ValueError("You must provide either `size_factors` or `sizes`")
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
max_model_len=16384,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_outputs_per_video = [
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"video": _LIMIT_VIDEO_PER_PROMPT
|
||||
}) as vllm_model:
|
||||
vllm_outputs_per_input = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
videos=videos)
|
||||
for prompts, videos in inputs_per_video
|
||||
for prompts, videos in inputs
|
||||
]
|
||||
|
||||
def process(hf_inputs: BatchEncoding):
|
||||
@ -142,16 +89,16 @@ def run_video_test(
|
||||
dtype=dtype,
|
||||
postprocess_inputs=process,
|
||||
auto_cls=AutoModelForVision2Seq) as hf_model:
|
||||
hf_outputs_per_video = [
|
||||
hf_outputs_per_input = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
videos=videos)
|
||||
for prompts, videos in inputs_per_video
|
||||
for prompts, videos in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_video,
|
||||
vllm_outputs_per_video):
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_input,
|
||||
vllm_outputs_per_input):
|
||||
# TODO: Check whether using original CLIPVisionModel can improve
|
||||
# consistency against HF
|
||||
check_logprobs_close(
|
||||
@ -165,74 +112,51 @@ def run_video_test(
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No video
|
||||
[],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("num_frames", [16])
|
||||
def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
|
||||
dtype, max_tokens, num_logprobs, num_frames) -> None:
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/videos.
|
||||
For huggingface runner, we provide the np.ndarray as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
run_video_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
video_assets,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
num_frames=num_frames,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"sizes",
|
||||
[[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("num_frames", [16])
|
||||
def test_models_fixed_sizes(hf_runner, vllm_runner, video_assets, model, sizes,
|
||||
dtype, max_tokens, num_logprobs,
|
||||
def test_models_multiple_video_inputs(hf_runner, vllm_runner, video_assets,
|
||||
model, dtype, max_tokens, num_logprobs,
|
||||
num_frames) -> None:
|
||||
video = sample_frames_from_video(video_assets[0].np_ndarrays, num_frames)
|
||||
inputs = [(
|
||||
[
|
||||
"<|im_start|>user <video><video>\nDescribe 2 videos. \
|
||||
<|im_end|><|im_start|>assistant\n",
|
||||
"<|im_start|>user <video><video>\nDescribe 2 videos. \
|
||||
<|im_end|><|im_start|>assistant\n",
|
||||
"<|im_start|>user <video><video><video><video>\nDescribe 4 videos. \
|
||||
<|im_end|><|im_start|>assistant\n",
|
||||
"<|im_start|>user <video>\nwhy is this video funny? \
|
||||
<|im_end|><|im_start|>assistant\n",
|
||||
],
|
||||
[
|
||||
[video, video],
|
||||
# Images with different sizes and aspect-ratios
|
||||
[
|
||||
rescale_video_size(video, 0.1),
|
||||
video,
|
||||
],
|
||||
[
|
||||
video,
|
||||
rescale_video_size(video, 0.25),
|
||||
resize_video(video, (183, 488)),
|
||||
resize_video(video, (488, 183))
|
||||
],
|
||||
video,
|
||||
])]
|
||||
run_video_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
video_assets,
|
||||
inputs,
|
||||
model,
|
||||
sizes=sizes,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
num_frames=num_frames,
|
||||
tensor_parallel_size=1,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
|
||||
@ -303,7 +227,6 @@ def run_image_test(
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
|
||||
@ -88,6 +88,7 @@ def dummy_image_for_clip(
|
||||
def dummy_video_for_clip(
|
||||
hf_config: CLIPVisionConfig,
|
||||
num_frames: int,
|
||||
num_videos: int = 1,
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
@ -99,7 +100,8 @@ def dummy_video_for_clip(
|
||||
image_height_override=image_height_override)
|
||||
np_frame = np.array(pil_frame["image"])
|
||||
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
|
||||
mm_data = {"video": mm_data_per_video}
|
||||
video_data = [mm_data_per_video] * num_videos
|
||||
mm_data = {"video": video_data}
|
||||
return mm_data
|
||||
|
||||
|
||||
|
||||
@ -43,19 +43,17 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 16
|
||||
_MAX_NUM_VIDEOS = 1
|
||||
|
||||
|
||||
class LlavaOnevisionVideoPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values_videos"]
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size, num_frames, num_channels, height, width)`
|
||||
Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
|
||||
|
||||
Note that `num_frames` may be different for each batch, in which case
|
||||
the data is passed as a list instead of a batched tensor.
|
||||
|
||||
Note that it only supports one video input for one batch.
|
||||
Note that `num_videos` may be different for each batch, and 'num_frames'
|
||||
may be different for each video, in which case the data is passed as a
|
||||
list instead of a batched tensor.
|
||||
"""
|
||||
|
||||
|
||||
@ -213,11 +211,7 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
|
||||
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
# TODO: support multiple videos
|
||||
num_videos = mm_counts["video"]
|
||||
if num_videos > _MAX_NUM_VIDEOS:
|
||||
raise NotImplementedError(
|
||||
f"Only {_MAX_NUM_VIDEOS} videos are supported")
|
||||
|
||||
# TODO: support configuring the number of frames
|
||||
num_frames = _MAX_FRAMES_PER_VIDEO
|
||||
@ -232,7 +226,9 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
|
||||
image_feature_size_override=video_feature_size,
|
||||
)
|
||||
|
||||
mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames)
|
||||
mm_data = dummy_video_for_clip(vision_config,
|
||||
num_frames=num_frames,
|
||||
num_videos=num_videos)
|
||||
return seq_data, mm_data
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
seq_data = dummy_seq_data_for_siglip(
|
||||
@ -243,7 +239,9 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
|
||||
image_feature_size_override=video_feature_size,
|
||||
)
|
||||
|
||||
mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames)
|
||||
mm_data = dummy_video_for_siglip(vision_config,
|
||||
num_frames=num_frames,
|
||||
num_videos=num_videos)
|
||||
return seq_data, mm_data
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
@ -315,7 +313,6 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(video_data, np.ndarray):
|
||||
# Supports both CLIP and Siglip
|
||||
@ -336,10 +333,27 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
elif is_list_of(video_data, np.ndarray):
|
||||
raise NotImplementedError(
|
||||
"Processing multiple videos is not supported")
|
||||
video_feature_size = []
|
||||
for video in video_data:
|
||||
num_frames = video.shape[0]
|
||||
video_feature_size.append(
|
||||
get_llava_onevision_video_tokens(ctx, num_frames))
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=hf_config.video_token_index,
|
||||
repeat_count=video_feature_size,
|
||||
)
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
else:
|
||||
raise TypeError(f"Invalid video type: {type(video_data)}")
|
||||
|
||||
msg = f"Unsupported video type: {type(video_data)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@ -723,6 +737,22 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
for i, patch_features_batch in enumerate(patch_embeddings)
|
||||
]
|
||||
|
||||
def _add_image_newline(
|
||||
self,
|
||||
video_features: torch.Tensor,
|
||||
videos: int = 1,
|
||||
frames: int = 1,
|
||||
strategy: str = "one_token",
|
||||
) -> torch.Tensor:
|
||||
if strategy == "one_token":
|
||||
video_features = video_features.reshape(
|
||||
videos, frames * video_features.shape[1], -1)
|
||||
image_newline = self.image_newline[None, None, :].repeat(
|
||||
videos, 1, 1).to(video_features.device)
|
||||
video_features = torch.cat((video_features, image_newline), dim=1)
|
||||
return video_features
|
||||
raise ValueError(f"Unexpected video newline strategy: {strategy}")
|
||||
|
||||
def _video_pixels_to_features(
|
||||
self,
|
||||
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
|
||||
@ -731,9 +761,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
b, num_videos, frames, c, h, w = pixel_values.shape
|
||||
assert (num_videos == _MAX_NUM_VIDEOS)
|
||||
pixel_values = pixel_values.reshape(b * num_videos * frames, c, h, w)
|
||||
video_features = vision_tower(pixel_values)
|
||||
video_features = self._select_image_features(
|
||||
video_features,
|
||||
@ -741,13 +768,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
)
|
||||
video_features = self.multi_modal_projector(video_features)
|
||||
video_features = self.apply_pooling(video_features)
|
||||
video_features = video_features.reshape(
|
||||
b, frames * video_features.shape[1], -1)
|
||||
image_newline = self.image_newline[None, None, :].repeat(b, 1, 1).to(
|
||||
video_features.device)
|
||||
video_features = torch.cat((video_features, image_newline), dim=1)
|
||||
video_features = video_features.flatten(0, 1)
|
||||
|
||||
return video_features
|
||||
|
||||
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
|
||||
@ -755,10 +775,28 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
video_pixels = inputs["data"]
|
||||
|
||||
# TODO: support multiple videos per input
|
||||
if isinstance(video_pixels, torch.Tensor):
|
||||
b, num_videos, frames, c, h, w = video_pixels.shape
|
||||
pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
|
||||
stacked_embeddings = self._video_pixels_to_features(
|
||||
self.vision_tower, video_pixels)
|
||||
self.vision_tower, pixel_values)
|
||||
stacked_embeddings = self._add_image_newline(stacked_embeddings,
|
||||
videos=b * num_videos,
|
||||
frames=frames,
|
||||
strategy="one_token")
|
||||
return stacked_embeddings
|
||||
elif is_list_of(video_pixels, torch.Tensor):
|
||||
stacked_embeddings = []
|
||||
for video_pixel in video_pixels:
|
||||
num_videos, frames, c, h, w = video_pixel.shape
|
||||
pixel_values = video_pixel.view(num_videos * frames, c, h, w)
|
||||
embeddings = self._video_pixels_to_features(
|
||||
self.vision_tower, pixel_values)
|
||||
embeddings = self._add_image_newline(embeddings,
|
||||
videos=num_videos,
|
||||
frames=frames,
|
||||
strategy="one_token")
|
||||
stacked_embeddings.append(embeddings)
|
||||
return stacked_embeddings
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@ -93,6 +93,7 @@ def dummy_image_for_siglip(
|
||||
def dummy_video_for_siglip(
|
||||
hf_config: SiglipVisionConfig,
|
||||
num_frames: int,
|
||||
num_videos: int = 1,
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
@ -104,7 +105,8 @@ def dummy_video_for_siglip(
|
||||
image_height_override=image_height_override)
|
||||
np_frame = np.array(pil_frame["image"])
|
||||
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
|
||||
mm_data = {"video": mm_data_per_video}
|
||||
video_data = [mm_data_per_video] * num_videos
|
||||
mm_data = {"video": video_data}
|
||||
return mm_data
|
||||
|
||||
|
||||
|
||||
@ -56,15 +56,14 @@ class VideoPlugin(ImagePlugin):
|
||||
) -> MultiModalInputs:
|
||||
model_config = ctx.model_config
|
||||
|
||||
# single video input as np.ndarray
|
||||
if isinstance(data, np.ndarray):
|
||||
if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
|
||||
video_processor = self._get_hf_video_processor(
|
||||
model_config,
|
||||
mm_processor_kwargs,
|
||||
)
|
||||
if video_processor is None:
|
||||
raise RuntimeError("No HuggingFace processor is available "
|
||||
"to process the image object")
|
||||
"to process the video object")
|
||||
try:
|
||||
# NOTE: Similar to image; it may be a good idea to filter and
|
||||
# pass mm_processor_kwargs here too, but for now we don't to
|
||||
@ -72,13 +71,10 @@ class VideoPlugin(ImagePlugin):
|
||||
# signatures of the processor don't align
|
||||
batch_data = video_processor(data, return_tensors="pt").data
|
||||
except Exception:
|
||||
logger.error("Failed to process image (%s)", data)
|
||||
logger.error("Failed to process video (%s)", data)
|
||||
raise
|
||||
|
||||
return MultiModalInputs(batch_data)
|
||||
elif is_list_of(data, np.ndarray):
|
||||
raise NotImplementedError(
|
||||
"Multi video for a prompt is not supported yet")
|
||||
|
||||
raise TypeError(f"Invalid video type: {type(data)}")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user