[Bugfix] Ensure correctness of HCXVision processing (#23254)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung, 2025-08-20 22:19:30 +08:00 (committed by GitHub)
Parent: 38217877aa
Commit: 4449235843
2 changed files with 55 additions and 63 deletions

tests/models/multimodal/processing/test_common.py

@@ -102,7 +102,7 @@ def _test_processing_correctness(
         partial(random_video,
                 rng,
                 min_frames=2,
-                max_frames=8,
+                max_frames=16,
                 min_wh=128,
                 max_wh=256),
     "audio":
vllm/model_executor/models/hyperclovax_vision.py

@@ -53,6 +53,21 @@
 IMAGE_TOKEN: str = "<|dummy3|>"
 VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"
 
+
+# Based on combine_frames_into_images in
+# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
+def get_num_combined_frames(
+    num_frames: int,
+    max_grid_shape: tuple[int, int] = (3, 3),
+) -> int:
+    max_num_grids = max_grid_shape[0] * max_grid_shape[1]
+
+    # Calculate the number of canvases needed.
+    num_canvases = num_frames // max_num_grids
+    leftover_frames = num_frames % max_num_grids
+
+    return num_canvases + (leftover_frames > 0)
+
+
 class HCXVisionMultimodalPixelInputs(TypedDict):
     type: Literal["pixel_values"]
     pixel_values_images: list[torch.Tensor]
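
Note (not part of the diff): a quick check of the canvas arithmetic above, assuming the default 3x3 grid (at most 9 frames combined into one canvas image). This is also why the test hunk raises max_frames from 8 to 16: with at most 8 frames, every random test video fit on a single canvas, so the multi-canvas path was never exercised.

# Illustration only: canvases needed per frame count with a 3x3 grid.
def get_num_combined_frames(num_frames, max_grid_shape=(3, 3)):
    max_num_grids = max_grid_shape[0] * max_grid_shape[1]  # 9 frames per canvas
    return num_frames // max_num_grids + (num_frames % max_num_grids > 0)

assert get_num_combined_frames(8) == 1   # old test bound: always one canvas
assert get_num_combined_frames(9) == 1   # exactly fills one canvas
assert get_num_combined_frames(10) == 2  # spills onto a second canvas
assert get_num_combined_frames(16) == 2  # new test bound reaches this path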
@@ -172,23 +187,20 @@ class HCXVisionMultiModalProcessor(
         def replace_multimodal_token(
             token_ids: torch.Tensor,
             target_token: int,
-            repeats: list,
+            repeats: list[int],
         ):
-            output = list()
+            output = list[int]()
             _repeats_idx = 0
             for token_id in token_ids:
                 if token_id == target_token:
-                    output += [
-                        token_id.item(),
-                    ] * repeats[_repeats_idx]
+                    output += [token_id.item()] * repeats[_repeats_idx]
                     _repeats_idx += 1
                 else:
-                    output += [
-                        token_id.item(),
-                    ]
+                    output += [token_id.item()]
             return torch.tensor(output, device=token_ids.device)
 
-        for video_idx, video_arr in enumerate(mm_data.get("videos", list())):
+        for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
             if video_arr.dtype == np.uint8:
                 continue
             mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
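
Note (not part of the diff): what replace_multimodal_token does, as a standalone sketch. Each occurrence of the placeholder token is expanded in place to its per-item query length, consuming one entry of repeats per occurrence; all other tokens pass through unchanged.

import torch

def replace_multimodal_token(token_ids, target_token, repeats):
    # Expand each matching placeholder by its repeat count, in order.
    output = list[int]()
    _repeats_idx = 0
    for token_id in token_ids:
        if token_id == target_token:
            output += [token_id.item()] * repeats[_repeats_idx]
            _repeats_idx += 1
        else:
            output += [token_id.item()]
    return torch.tensor(output, device=token_ids.device)

ids = torch.tensor([1, 99, 2, 99, 3])  # 99 stands in for the image token id
out = replace_multimodal_token(ids, 99, repeats=[2, 4])
assert out.tolist() == [1, 99, 99, 2, 99, 99, 99, 99, 3]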
@@ -205,88 +217,68 @@ class HCXVisionMultiModalProcessor(
         if len(mm_data) > 0:
             # batchify input as a single item
             images = mm_data.get("images", None)
-            num_images = 0
-            if images is not None:
-                num_images = len(images)
-                images = [
-                    images,
-                ]  # batchify
+            batched_images = None if images is None else [images]
 
-            videos = mm_data.get("videos",
-                                 None)  # list of video in single conversation
-            num_videos = 0
-            if videos is not None:
-                num_videos = len(videos)
-                videos = [
-                    videos,
-                ]  # batchify
+            # list of video in single conversation
+            videos = mm_data.get("videos", None)
+            batched_videos = None if videos is None else [videos]
 
             _processed_outputs = self.info.ctx.call_hf_processor(
                 hf_processor=self.info.get_hf_processor(**mm_kwargs),
                 data=dict(
                     text=None,
-                    images=images,
-                    videos=videos,
+                    images=batched_images,
+                    videos=batched_videos,
                 ),
             )  # mm-only
 
             for k, v in _processed_outputs.items():
-                if len(v) < 1:
-                    continue
-                elif k.endswith("_images"):
-                    # list of list of 4D tensor -> list of 4D tensor
+                if isinstance(v, list) and len(v) > 0:
+                    assert len(v) == 1
                     _processed_outputs[k] = v[0]
-                elif k.endswith("_videos"):
-                    # list of list of 4D tensor -> list of 4D tensor
-                    v = v[0]
-                    if k == "pixel_values_videos":
-                        v = torch.cat(v, dim=0)
-                        _c, _w, _h = v.shape[-3:]
-                        v = v.reshape(num_videos, -1, _c, _w, _h)
-                        v = list(torch.unbind(v, dim=0))
-                    _processed_outputs[k] = v
 
-            if num_images > 0:
+            if images:
                 tokenizer = self.info.get_tokenizer()
+                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
                 processed_outputs["input_ids"] = torch.stack([
                     replace_multimodal_token(
                         token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            IMAGE_TOKEN),
+                        target_token=image_token_id,
                         repeats=_processed_outputs[
                             "vision_query_lengths_images"],
                     ) for _input_ids in processed_outputs["input_ids"]
                 ],
                                                             dim=0)
 
-            if num_videos > 0:
-                tokenizer = self.info.get_tokenizer()
-                processed_outputs["input_ids"] = torch.stack([
-                    replace_multimodal_token(
-                        token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            VIDEO_TOKEN),
-                        repeats=_processed_outputs[
-                            "vision_query_lengths_videos"],
-                    ) for _input_ids in processed_outputs["input_ids"]
-                ],
-                                                            dim=0)
-
-                _ratios = [
-                    len(_pixel_values) for _pixel_values in
-                    _processed_outputs["pixel_values_videos"]
-                ]
+            if videos:
                 _num_per_videos = [
-                    int(_e / sum(_ratios) *
-                        len(_processed_outputs["vision_query_lengths_videos"]))
-                    for _e in _ratios
+                    get_num_combined_frames(len(video)) for video in videos
                 ]
                 _processed_outputs["pixel_values_videos"] = [
                     _processed_outputs["pixel_values_videos"]
                     [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
                     for _i in range(len(videos))
                 ]
                 _processed_outputs["vision_query_lengths_videos"] = [
                     _processed_outputs["vision_query_lengths_videos"]
                     [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(0, num_videos)
+                    for _i in range(len(videos))
                 ]
+
+                tokenizer = self.info.get_tokenizer()
+                video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
+                processed_outputs["input_ids"] = torch.stack([
+                    replace_multimodal_token(
+                        token_ids=_input_ids,
+                        target_token=video_token_id,
+                        repeats=[
+                            sum(lens) for lens in
+                            _processed_outputs["vision_query_lengths_videos"]
+                        ],
+                    ) for _input_ids in processed_outputs["input_ids"]
+                ],
+                                                            dim=0)
 
             processed_outputs.update(_processed_outputs)
 
         return processed_outputs
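
Note (not part of the diff): the heart of the fix, in plain form. The old code apportioned the flat per-canvas outputs between videos by a length ratio with int() truncation, which can misallocate canvases; the new code derives the exact canvas count for each video from get_num_combined_frames and slices by cumulative offsets, then expands each video placeholder by the per-video total sum(lens) rather than per-canvas lengths. A minimal sketch with dummy data:

# Assume two videos of 4 and 12 frames; with the 3x3 grid,
# get_num_combined_frames gives 1 and 2 canvases respectively.
num_per_video = [1, 2]

# Flat per-canvas list, as returned by the HF processor.
flat = ["v0_canvas0", "v1_canvas0", "v1_canvas1"]

per_video = [
    flat[sum(num_per_video[:i]):sum(num_per_video[:i + 1])]
    for i in range(len(num_per_video))
]
assert per_video == [["v0_canvas0"], ["v1_canvas0", "v1_canvas1"]]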