[Model][VLM] Add multi-video support for LLaVA-Onevision (#8905)

Co-authored-by: litianjian <litianjian@bytedance.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
litianjian 2024-10-29 02:04:10 +08:00 committed by GitHub
parent 8b0e4f2ad7
commit 5f8d8075f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 123 additions and 162 deletions

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Type, overload
from typing import List, Optional, Tuple, Type
import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
@ -9,9 +9,8 @@ from vllm.multimodal.utils import (rescale_image_size, rescale_video_size,
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_VideoAssets)
from ....utils import large_gpu_test
from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput,
PromptVideoInput, VllmRunner)
from ...utils import check_logprobs_close
# Video test
@ -20,7 +19,7 @@ HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
"<|im_start|>user\n<video>\nwhy is this video funny?<|im_end|>\n<|im_start|>assistant\n" # noqa: E501
})
models = ["llava-hf/llava-onevision-qwen2-7b-ov-hf"]
models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@ -47,50 +46,16 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
return hf_output_ids, hf_output_str, out_logprobs
@overload
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
@overload
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
# Video test
_LIMIT_VIDEO_PER_PROMPT = 4
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
inputs: List[Tuple[List[str], PromptVideoInput]],
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
@ -99,38 +64,20 @@ def run_video_test(
distributed_executor_backend: Optional[str] = None,
):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
videos = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]
if size_factors is not None:
inputs_per_video = [(
[prompt for _ in size_factors],
[rescale_video_size(video, factor) for factor in size_factors],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
elif sizes is not None:
inputs_per_video = [(
[prompt for _ in sizes],
[resize_video(video, size) for size in sizes],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
max_model_len=16384,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_video = [
enforce_eager=True,
limit_mm_per_prompt={"video": _LIMIT_VIDEO_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_input = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
for prompts, videos in inputs
]
def process(hf_inputs: BatchEncoding):
@ -142,16 +89,16 @@ def run_video_test(
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_video = [
hf_outputs_per_input = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
for prompts, videos in inputs
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_video,
vllm_outputs_per_video):
for hf_outputs, vllm_outputs in zip(hf_outputs_per_input,
vllm_outputs_per_input):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
@ -165,74 +112,51 @@ def run_video_test(
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No video
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
dtype, max_tokens, num_logprobs, num_frames) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/videos.
For huggingface runner, we provide the np.ndarray as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
run_video_test(
hf_runner,
vllm_runner,
video_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models_fixed_sizes(hf_runner, vllm_runner, video_assets, model, sizes,
dtype, max_tokens, num_logprobs,
def test_models_multiple_video_inputs(hf_runner, vllm_runner, video_assets,
model, dtype, max_tokens, num_logprobs,
num_frames) -> None:
video = sample_frames_from_video(video_assets[0].np_ndarrays, num_frames)
inputs = [(
[
"<|im_start|>user <video><video>\nDescribe 2 videos. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <video><video>\nDescribe 2 videos. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <video><video><video><video>\nDescribe 4 videos. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <video>\nwhy is this video funny? \
<|im_end|><|im_start|>assistant\n",
],
[
[video, video],
# Images with different sizes and aspect-ratios
[
rescale_video_size(video, 0.1),
video,
],
[
video,
rescale_video_size(video, 0.25),
resize_video(video, (183, 488)),
resize_video(video, (488, 183))
],
video,
])]
run_video_test(
hf_runner,
vllm_runner,
video_assets,
inputs,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
num_frames=num_frames,
)
@ -303,7 +227,6 @@ def run_image_test(
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])

View File

@ -88,6 +88,7 @@ def dummy_image_for_clip(
def dummy_video_for_clip(
hf_config: CLIPVisionConfig,
num_frames: int,
num_videos: int = 1,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
@ -99,7 +100,8 @@ def dummy_video_for_clip(
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
mm_data = {"video": mm_data_per_video}
video_data = [mm_data_per_video] * num_videos
mm_data = {"video": video_data}
return mm_data

View File

@ -43,19 +43,17 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
_MAX_NUM_VIDEOS = 1
class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, num_frames, num_channels, height, width)`
Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
Note that `num_frames` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
Note that it only supports one video input for one batch.
Note that `num_videos` may be different for each batch, and 'num_frames'
may be different for each video, in which case the data is passed as a
list instead of a batched tensor.
"""
@ -213,11 +211,7 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
# TODO: support multiple videos
num_videos = mm_counts["video"]
if num_videos > _MAX_NUM_VIDEOS:
raise NotImplementedError(
f"Only {_MAX_NUM_VIDEOS} videos are supported")
# TODO: support configuring the number of frames
num_frames = _MAX_FRAMES_PER_VIDEO
@ -232,7 +226,9 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
image_feature_size_override=video_feature_size,
)
mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames)
mm_data = dummy_video_for_clip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
@ -243,7 +239,9 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
image_feature_size_override=video_feature_size,
)
mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames)
mm_data = dummy_video_for_siglip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
@ -315,7 +313,6 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
@ -336,10 +333,27 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
multi_modal_data=multi_modal_data)
elif is_list_of(video_data, np.ndarray):
raise NotImplementedError(
"Processing multiple videos is not supported")
video_feature_size = []
for video in video_data:
num_frames = video.shape[0]
video_feature_size.append(
get_llava_onevision_video_tokens(ctx, num_frames))
msg = f"Unsupported vision config: {type(vision_config)}"
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
else:
raise TypeError(f"Invalid video type: {type(video_data)}")
msg = f"Unsupported video type: {type(video_data)}"
raise NotImplementedError(msg)
@ -723,6 +737,22 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings)
]
def _add_image_newline(
self,
video_features: torch.Tensor,
videos: int = 1,
frames: int = 1,
strategy: str = "one_token",
) -> torch.Tensor:
if strategy == "one_token":
video_features = video_features.reshape(
videos, frames * video_features.shape[1], -1)
image_newline = self.image_newline[None, None, :].repeat(
videos, 1, 1).to(video_features.device)
video_features = torch.cat((video_features, image_newline), dim=1)
return video_features
raise ValueError(f"Unexpected video newline strategy: {strategy}")
def _video_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
@ -731,9 +761,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
b, num_videos, frames, c, h, w = pixel_values.shape
assert (num_videos == _MAX_NUM_VIDEOS)
pixel_values = pixel_values.reshape(b * num_videos * frames, c, h, w)
video_features = vision_tower(pixel_values)
video_features = self._select_image_features(
video_features,
@ -741,13 +768,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
)
video_features = self.multi_modal_projector(video_features)
video_features = self.apply_pooling(video_features)
video_features = video_features.reshape(
b, frames * video_features.shape[1], -1)
image_newline = self.image_newline[None, None, :].repeat(b, 1, 1).to(
video_features.device)
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)
return video_features
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
@ -755,10 +775,28 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
video_pixels = inputs["data"]
# TODO: support multiple videos per input
if isinstance(video_pixels, torch.Tensor):
b, num_videos, frames, c, h, w = video_pixels.shape
pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, video_pixels)
self.vision_tower, pixel_values)
stacked_embeddings = self._add_image_newline(stacked_embeddings,
videos=b * num_videos,
frames=frames,
strategy="one_token")
return stacked_embeddings
elif is_list_of(video_pixels, torch.Tensor):
stacked_embeddings = []
for video_pixel in video_pixels:
num_videos, frames, c, h, w = video_pixel.shape
pixel_values = video_pixel.view(num_videos * frames, c, h, w)
embeddings = self._video_pixels_to_features(
self.vision_tower, pixel_values)
embeddings = self._add_image_newline(embeddings,
videos=num_videos,
frames=frames,
strategy="one_token")
stacked_embeddings.append(embeddings)
return stacked_embeddings
else:
raise ValueError(

View File

@ -93,6 +93,7 @@ def dummy_image_for_siglip(
def dummy_video_for_siglip(
hf_config: SiglipVisionConfig,
num_frames: int,
num_videos: int = 1,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
@ -104,7 +105,8 @@ def dummy_video_for_siglip(
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
mm_data = {"video": mm_data_per_video}
video_data = [mm_data_per_video] * num_videos
mm_data = {"video": video_data}
return mm_data

View File

@ -56,15 +56,14 @@ class VideoPlugin(ImagePlugin):
) -> MultiModalInputs:
model_config = ctx.model_config
# single video input as np.ndarray
if isinstance(data, np.ndarray):
if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
video_processor = self._get_hf_video_processor(
model_config,
mm_processor_kwargs,
)
if video_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
"to process the video object")
try:
# NOTE: Similar to image; it may be a good idea to filter and
# pass mm_processor_kwargs here too, but for now we don't to
@ -72,13 +71,10 @@ class VideoPlugin(ImagePlugin):
# signatures of the processor don't align
batch_data = video_processor(data, return_tensors="pt").data
except Exception:
logger.error("Failed to process image (%s)", data)
logger.error("Failed to process video (%s)", data)
raise
return MultiModalInputs(batch_data)
elif is_list_of(data, np.ndarray):
raise NotImplementedError(
"Multi video for a prompt is not supported yet")
raise TypeError(f"Invalid video type: {type(data)}")