diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index d5b1de834a61..02aecfad8281 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -102,7 +102,7 @@ def _test_processing_correctness(
         partial(random_video,
                 rng,
                 min_frames=2,
-                max_frames=8,
+                max_frames=16,
                 min_wh=128,
                 max_wh=256),
         "audio":
diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py
index d3ddc47ea932..f8b30d8d98e5 100644
--- a/vllm/model_executor/models/hyperclovax_vision.py
+++ b/vllm/model_executor/models/hyperclovax_vision.py
@@ -53,6 +53,21 @@
 IMAGE_TOKEN: str = "<|dummy3|>"
 VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"
 
+# Based on combine_frames_into_images in
+# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
+def get_num_combined_frames(
+    num_frames: int,
+    max_grid_shape: tuple[int, int] = (3, 3),
+) -> int:
+    max_num_grids = max_grid_shape[0] * max_grid_shape[1]
+
+    # Calculate the number of canvases needed.
+    num_canvases = num_frames // max_num_grids
+    leftover_frames = num_frames % max_num_grids
+
+    return num_canvases + (leftover_frames > 0)
+
+
 class HCXVisionMultimodalPixelInputs(TypedDict):
     type: Literal["pixel_values"]
     pixel_values_images: list[torch.Tensor]
@@ -172,23 +187,20 @@ class HCXVisionMultiModalProcessor(
         def replace_multimodal_token(
             token_ids: torch.Tensor,
             target_token: int,
-            repeats: list,
+            repeats: list[int],
         ):
-            output = list()
+            output = list[int]()
             _repeats_idx = 0
             for token_id in token_ids:
                 if token_id == target_token:
-                    output += [
-                        token_id.item(),
-                    ] * repeats[_repeats_idx]
+                    output += [token_id.item()] * repeats[_repeats_idx]
                     _repeats_idx += 1
                 else:
-                    output += [
-                        token_id.item(),
-                    ]
+                    output += [token_id.item()]
+
             return torch.tensor(output, device=token_ids.device)
 
-        for video_idx, video_arr in enumerate(mm_data.get("videos", list())):
+        for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
             if video_arr.dtype == np.uint8:
                 continue
             mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
@@ -205,88 +217,68 @@ class HCXVisionMultiModalProcessor(
         if len(mm_data) > 0:
             # batchify input as a single item
             images = mm_data.get("images", None)
-            num_images = 0
-            if images is not None:
-                num_images = len(images)
-                images = [
-                    images,
-                ]  # batchify
+            batched_images = None if images is None else [images]
 
-            videos = mm_data.get("videos",
-                                 None)  # list of video in single conversation
-            num_videos = 0
-            if videos is not None:
-                num_videos = len(videos)
-                videos = [
-                    videos,
-                ]  # batchify
+            # list of video in single conversation
+            videos = mm_data.get("videos", None)
+            batched_videos = None if videos is None else [videos]
 
             _processed_outputs = self.info.ctx.call_hf_processor(
                 hf_processor=self.info.get_hf_processor(**mm_kwargs),
                 data=dict(
                     text=None,
-                    images=images,
-                    videos=videos,
+                    images=batched_images,
+                    videos=batched_videos,
                 ),
             )  # mm-only
 
             for k, v in _processed_outputs.items():
-                if len(v) < 1:
-                    continue
-                elif k.endswith("_images"):
-                    # list of list of 4D tensor -> list of 4D tensor
+                if isinstance(v, list) and len(v) > 0:
+                    assert len(v) == 1
                     _processed_outputs[k] = v[0]
-                elif k.endswith("_videos"):
-                    # list of list of 4D tensor -> list of 4D tensor
-                    v = v[0]
-                    if k == "pixel_values_videos":
-                        v = torch.cat(v, dim=0)
-                        _c, _w, _h = v.shape[-3:]
-                        v = v.reshape(num_videos, -1, _c, _w, _h)
-                        v = list(torch.unbind(v, dim=0))
-                    _processed_outputs[k] = v
 
-            if num_images > 0:
+            if images:
                 tokenizer = self.info.get_tokenizer()
+                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
                 processed_outputs["input_ids"] = torch.stack([
                     replace_multimodal_token(
                         token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            IMAGE_TOKEN),
+                        target_token=image_token_id,
                         repeats=_processed_outputs[
                             "vision_query_lengths_images"],
                     ) for _input_ids in processed_outputs["input_ids"]
                 ],
                                                              dim=0)
 
-            if num_videos > 0:
-                tokenizer = self.info.get_tokenizer()
-                processed_outputs["input_ids"] = torch.stack([
-                    replace_multimodal_token(
-                        token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            VIDEO_TOKEN),
-                        repeats=_processed_outputs[
-                            "vision_query_lengths_videos"],
-                    ) for _input_ids in processed_outputs["input_ids"]
-                ],
-                                                             dim=0)
-
-                _ratios = [
-                    len(_pixel_values) for _pixel_values in
-                    _processed_outputs["pixel_values_videos"]
-                ]
+            if videos:
                 _num_per_videos = [
-                    int(_e / sum(_ratios) *
-                        len(_processed_outputs["vision_query_lengths_videos"]))
-                    for _e in _ratios
+                    get_num_combined_frames(len(video)) for video in videos
+                ]
+                _processed_outputs["pixel_values_videos"] = [
+                    _processed_outputs["pixel_values_videos"]
+                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
+                    for _i in range(len(videos))
                 ]
                 _processed_outputs["vision_query_lengths_videos"] = [
                     _processed_outputs["vision_query_lengths_videos"]
                     [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(0, num_videos)
+                    for _i in range(len(videos))
                 ]
+                tokenizer = self.info.get_tokenizer()
+                video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
+                processed_outputs["input_ids"] = torch.stack([
+                    replace_multimodal_token(
+                        token_ids=_input_ids,
+                        target_token=video_token_id,
+                        repeats=[
+                            sum(lens) for lens in
+                            _processed_outputs["vision_query_lengths_videos"]
+                        ],
+                    ) for _input_ids in processed_outputs["input_ids"]
+                ],
+                                                             dim=0)
+
             processed_outputs.update(_processed_outputs)
 
         return processed_outputs
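
Reviewer note (not part of the diff): a minimal sketch of how the new get_num_combined_frames helper behaves with its default 3x3 grid, and how the per-video slice offsets over the flattened pixel_values_videos / vision_query_lengths_videos outputs follow from it. The per-video frame counts used below are hypothetical.

def get_num_combined_frames(num_frames, max_grid_shape=(3, 3)):
    # Frames are packed onto canvases holding at most 3 x 3 = 9 frames each.
    max_num_grids = max_grid_shape[0] * max_grid_shape[1]
    num_canvases = num_frames // max_num_grids
    leftover_frames = num_frames % max_num_grids
    return num_canvases + (leftover_frames > 0)

assert get_num_combined_frames(2) == 1   # 2 frames  -> 1 canvas
assert get_num_combined_frames(9) == 1   # 9 frames  -> 1 canvas
assert get_num_combined_frames(16) == 2  # 16 frames -> 2 canvases; the test bump
                                         # to max_frames=16 exercises this case.

# The HF processor returns one flat list across all videos in a conversation,
# so the per-video slice boundaries are running sums of the canvas counts:
frame_counts = [10, 16, 5]  # hypothetical per-video frame counts
nums = [get_num_combined_frames(n) for n in frame_counts]  # [2, 2, 1]
offsets = [sum(nums[:i]) for i in range(len(nums) + 1)]    # [0, 2, 4, 5]
# vision_query_lengths_videos[offsets[i]:offsets[i + 1]] belongs to video i.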