mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 18:25:48 +08:00
[Bugfix] Ensure correctness of HCXVision processing (#23254)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
38217877aa
commit
4449235843
@ -102,7 +102,7 @@ def _test_processing_correctness(
|
||||
partial(random_video,
|
||||
rng,
|
||||
min_frames=2,
|
||||
max_frames=8,
|
||||
max_frames=16,
|
||||
min_wh=128,
|
||||
max_wh=256),
|
||||
"audio":
|
||||
|
||||
@ -53,6 +53,21 @@ IMAGE_TOKEN: str = "<|dummy3|>"
|
||||
VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"
|
||||
|
||||
|
||||
# Based on combine_frames_into_images in
|
||||
# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
|
||||
def get_num_combined_frames(
|
||||
num_frames: int,
|
||||
max_grid_shape: tuple[int, int] = (3, 3),
|
||||
) -> int:
|
||||
max_num_grids = max_grid_shape[0] * max_grid_shape[1]
|
||||
|
||||
# Calculate the number of canvases needed.
|
||||
num_canvases = num_frames // max_num_grids
|
||||
leftover_frames = num_frames % max_num_grids
|
||||
|
||||
return num_canvases + (leftover_frames > 0)
|
||||
|
||||
|
||||
class HCXVisionMultimodalPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values_images: list[torch.Tensor]
|
||||
@ -172,23 +187,20 @@ class HCXVisionMultiModalProcessor(
|
||||
def replace_multimodal_token(
|
||||
token_ids: torch.Tensor,
|
||||
target_token: int,
|
||||
repeats: list,
|
||||
repeats: list[int],
|
||||
):
|
||||
output = list()
|
||||
output = list[int]()
|
||||
_repeats_idx = 0
|
||||
for token_id in token_ids:
|
||||
if token_id == target_token:
|
||||
output += [
|
||||
token_id.item(),
|
||||
] * repeats[_repeats_idx]
|
||||
output += [token_id.item()] * repeats[_repeats_idx]
|
||||
_repeats_idx += 1
|
||||
else:
|
||||
output += [
|
||||
token_id.item(),
|
||||
]
|
||||
output += [token_id.item()]
|
||||
|
||||
return torch.tensor(output, device=token_ids.device)
|
||||
|
||||
for video_idx, video_arr in enumerate(mm_data.get("videos", list())):
|
||||
for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
|
||||
if video_arr.dtype == np.uint8:
|
||||
continue
|
||||
mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
|
||||
@ -205,88 +217,68 @@ class HCXVisionMultiModalProcessor(
|
||||
if len(mm_data) > 0:
|
||||
# batchify input as a single item
|
||||
images = mm_data.get("images", None)
|
||||
num_images = 0
|
||||
if images is not None:
|
||||
num_images = len(images)
|
||||
images = [
|
||||
images,
|
||||
] # batchify
|
||||
batched_images = None if images is None else [images]
|
||||
|
||||
videos = mm_data.get("videos",
|
||||
None) # list of video in single conversation
|
||||
num_videos = 0
|
||||
if videos is not None:
|
||||
num_videos = len(videos)
|
||||
videos = [
|
||||
videos,
|
||||
] # batchify
|
||||
# list of video in single conversation
|
||||
videos = mm_data.get("videos", None)
|
||||
batched_videos = None if videos is None else [videos]
|
||||
|
||||
_processed_outputs = self.info.ctx.call_hf_processor(
|
||||
hf_processor=self.info.get_hf_processor(**mm_kwargs),
|
||||
data=dict(
|
||||
text=None,
|
||||
images=images,
|
||||
videos=videos,
|
||||
images=batched_images,
|
||||
videos=batched_videos,
|
||||
),
|
||||
) # mm-only
|
||||
|
||||
for k, v in _processed_outputs.items():
|
||||
if len(v) < 1:
|
||||
continue
|
||||
elif k.endswith("_images"):
|
||||
# list of list of 4D tensor -> list of 4D tensor
|
||||
if isinstance(v, list) and len(v) > 0:
|
||||
assert len(v) == 1
|
||||
_processed_outputs[k] = v[0]
|
||||
elif k.endswith("_videos"):
|
||||
# list of list of 4D tensor -> list of 4D tensor
|
||||
v = v[0]
|
||||
if k == "pixel_values_videos":
|
||||
v = torch.cat(v, dim=0)
|
||||
_c, _w, _h = v.shape[-3:]
|
||||
v = v.reshape(num_videos, -1, _c, _w, _h)
|
||||
v = list(torch.unbind(v, dim=0))
|
||||
_processed_outputs[k] = v
|
||||
|
||||
if num_images > 0:
|
||||
if images:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
processed_outputs["input_ids"] = torch.stack([
|
||||
replace_multimodal_token(
|
||||
token_ids=_input_ids,
|
||||
target_token=tokenizer.convert_tokens_to_ids(
|
||||
IMAGE_TOKEN),
|
||||
target_token=image_token_id,
|
||||
repeats=_processed_outputs[
|
||||
"vision_query_lengths_images"],
|
||||
) for _input_ids in processed_outputs["input_ids"]
|
||||
],
|
||||
dim=0)
|
||||
|
||||
if num_videos > 0:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
processed_outputs["input_ids"] = torch.stack([
|
||||
replace_multimodal_token(
|
||||
token_ids=_input_ids,
|
||||
target_token=tokenizer.convert_tokens_to_ids(
|
||||
VIDEO_TOKEN),
|
||||
repeats=_processed_outputs[
|
||||
"vision_query_lengths_videos"],
|
||||
) for _input_ids in processed_outputs["input_ids"]
|
||||
],
|
||||
dim=0)
|
||||
|
||||
_ratios = [
|
||||
len(_pixel_values) for _pixel_values in
|
||||
_processed_outputs["pixel_values_videos"]
|
||||
]
|
||||
if videos:
|
||||
_num_per_videos = [
|
||||
int(_e / sum(_ratios) *
|
||||
len(_processed_outputs["vision_query_lengths_videos"]))
|
||||
for _e in _ratios
|
||||
get_num_combined_frames(len(video)) for video in videos
|
||||
]
|
||||
_processed_outputs["pixel_values_videos"] = [
|
||||
_processed_outputs["pixel_values_videos"]
|
||||
[sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
|
||||
for _i in range(len(videos))
|
||||
]
|
||||
_processed_outputs["vision_query_lengths_videos"] = [
|
||||
_processed_outputs["vision_query_lengths_videos"]
|
||||
[sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
|
||||
for _i in range(0, num_videos)
|
||||
for _i in range(len(videos))
|
||||
]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
|
||||
processed_outputs["input_ids"] = torch.stack([
|
||||
replace_multimodal_token(
|
||||
token_ids=_input_ids,
|
||||
target_token=video_token_id,
|
||||
repeats=[
|
||||
sum(lens) for lens in
|
||||
_processed_outputs["vision_query_lengths_videos"]
|
||||
],
|
||||
) for _input_ids in processed_outputs["input_ids"]
|
||||
],
|
||||
dim=0)
|
||||
|
||||
processed_outputs.update(_processed_outputs)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user