[Bugfix] Ensure correctness of HCXVision processing (#23254)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent: 38217877aa
commit: 4449235843
@@ -102,7 +102,7 @@ def _test_processing_correctness(
         partial(random_video,
                 rng,
                 min_frames=2,
-                max_frames=8,
+                max_frames=16,
                 min_wh=128,
                 max_wh=256),
     "audio":
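Note: raising max_frames from 8 to 16 matters because, with the model's default 3x3 grid, a single canvas holds up to 9 frames, so test videos capped at 8 frames would never exercise the multi-canvas path this commit fixes. A quick sanity check of that arithmetic (a sketch; canvases is a hypothetical stand-in for the get_num_combined_frames helper added below):

def canvases(num_frames: int, grid: tuple[int, int] = (3, 3)) -> int:
    per_canvas = grid[0] * grid[1]  # 9 frames fit on one 3x3 canvas
    return num_frames // per_canvas + (num_frames % per_canvas > 0)

assert canvases(8) == 1   # old cap: always a single canvas
assert canvases(16) == 2  # new cap: can span two canvases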
@@ -53,6 +53,21 @@ IMAGE_TOKEN: str = "<|dummy3|>"
 VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"
 
 
+# Based on combine_frames_into_images in
+# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
+def get_num_combined_frames(
+    num_frames: int,
+    max_grid_shape: tuple[int, int] = (3, 3),
+) -> int:
+    max_num_grids = max_grid_shape[0] * max_grid_shape[1]
+
+    # Calculate the number of canvases needed.
+    num_canvases = num_frames // max_num_grids
+    leftover_frames = num_frames % max_num_grids
+
+    return num_canvases + (leftover_frames > 0)
+
+
 class HCXVisionMultimodalPixelInputs(TypedDict):
     type: Literal["pixel_values"]
     pixel_values_images: list[torch.Tensor]
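The new helper is just ceiling division by the grid capacity: with the default max_grid_shape of (3, 3) it returns ceil(num_frames / 9). A quick check, assuming get_num_combined_frames from the hunk above is in scope:

import math

for n in (1, 9, 10, 16, 27):
    assert get_num_combined_frames(n) == math.ceil(n / 9)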
@@ -172,23 +187,20 @@ class HCXVisionMultiModalProcessor(
         def replace_multimodal_token(
             token_ids: torch.Tensor,
             target_token: int,
-            repeats: list,
+            repeats: list[int],
         ):
-            output = list()
+            output = list[int]()
             _repeats_idx = 0
             for token_id in token_ids:
                 if token_id == target_token:
-                    output += [
-                        token_id.item(),
-                    ] * repeats[_repeats_idx]
+                    output += [token_id.item()] * repeats[_repeats_idx]
                     _repeats_idx += 1
                 else:
-                    output += [
-                        token_id.item(),
-                    ]
+                    output += [token_id.item()]
             return torch.tensor(output, device=token_ids.device)
 
-        for video_idx, video_arr in enumerate(mm_data.get("videos", list())):
+        for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
             if video_arr.dtype == np.uint8:
                 continue
             mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
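For reference, replace_multimodal_token expands each occurrence of the placeholder token in place, consuming one entry of repeats per occurrence. A toy run with made-up token ids (the helper is a nested function in the real code, so this assumes it has been lifted to module scope):

import torch

token_ids = torch.tensor([1, 7, 2, 7, 3])  # 7 is the placeholder here
out = replace_multimodal_token(token_ids, target_token=7, repeats=[2, 4])
# The first 7 expands to 2 copies, the second to 4; others pass through.
assert out.tolist() == [1, 7, 7, 2, 7, 7, 7, 7, 3]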
@@ -205,88 +217,68 @@ class HCXVisionMultiModalProcessor(
         if len(mm_data) > 0:
             # batchify input as a single item
             images = mm_data.get("images", None)
-            num_images = 0
-            if images is not None:
-                num_images = len(images)
-                images = [
-                    images,
-                ]  # batchify
+            batched_images = None if images is None else [images]
 
-            videos = mm_data.get("videos",
-                                 None)  # list of video in single conversation
-            num_videos = 0
-            if videos is not None:
-                num_videos = len(videos)
-                videos = [
-                    videos,
-                ]  # batchify
+            # list of video in single conversation
+            videos = mm_data.get("videos", None)
+            batched_videos = None if videos is None else [videos]
 
             _processed_outputs = self.info.ctx.call_hf_processor(
                 hf_processor=self.info.get_hf_processor(**mm_kwargs),
                 data=dict(
                     text=None,
-                    images=images,
-                    videos=videos,
+                    images=batched_images,
+                    videos=batched_videos,
                 ),
             )  # mm-only
 
             for k, v in _processed_outputs.items():
-                if len(v) < 1:
-                    continue
-                elif k.endswith("_images"):
-                    # list of list of 4D tensor -> list of 4D tensor
+                if isinstance(v, list) and len(v) > 0:
+                    assert len(v) == 1
                     _processed_outputs[k] = v[0]
-                elif k.endswith("_videos"):
-                    # list of list of 4D tensor -> list of 4D tensor
-                    v = v[0]
-                    if k == "pixel_values_videos":
-                        v = torch.cat(v, dim=0)
-                        _c, _w, _h = v.shape[-3:]
-                        v = v.reshape(num_videos, -1, _c, _w, _h)
-                        v = list(torch.unbind(v, dim=0))
-                    _processed_outputs[k] = v
 
-            if num_images > 0:
+            if images:
                 tokenizer = self.info.get_tokenizer()
+                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
                 processed_outputs["input_ids"] = torch.stack([
                     replace_multimodal_token(
                         token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            IMAGE_TOKEN),
+                        target_token=image_token_id,
                         repeats=_processed_outputs[
                             "vision_query_lengths_images"],
                     ) for _input_ids in processed_outputs["input_ids"]
                 ],
                                                               dim=0)
 
-            if num_videos > 0:
-                tokenizer = self.info.get_tokenizer()
-                processed_outputs["input_ids"] = torch.stack([
-                    replace_multimodal_token(
-                        token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            VIDEO_TOKEN),
-                        repeats=_processed_outputs[
-                            "vision_query_lengths_videos"],
-                    ) for _input_ids in processed_outputs["input_ids"]
-                ],
-                                                              dim=0)
-
-                _ratios = [
-                    len(_pixel_values) for _pixel_values in
-                    _processed_outputs["pixel_values_videos"]
-                ]
+            if videos:
                 _num_per_videos = [
-                    int(_e / sum(_ratios) *
-                        len(_processed_outputs["vision_query_lengths_videos"]))
-                    for _e in _ratios
+                    get_num_combined_frames(len(video)) for video in videos
+                ]
+                _processed_outputs["pixel_values_videos"] = [
+                    _processed_outputs["pixel_values_videos"]
+                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
+                    for _i in range(len(videos))
                 ]
                 _processed_outputs["vision_query_lengths_videos"] = [
                     _processed_outputs["vision_query_lengths_videos"]
                     [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(0, num_videos)
+                    for _i in range(len(videos))
                 ]
 
+                tokenizer = self.info.get_tokenizer()
+                video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
+                processed_outputs["input_ids"] = torch.stack([
+                    replace_multimodal_token(
+                        token_ids=_input_ids,
+                        target_token=video_token_id,
+                        repeats=[
+                            sum(lens) for lens in
+                            _processed_outputs["vision_query_lengths_videos"]
+                        ],
+                    ) for _input_ids in processed_outputs["input_ids"]
+                ],
+                                                              dim=0)
+
             processed_outputs.update(_processed_outputs)
 
         return processed_outputs
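The core of the fix is the regrouping step: the old code apportioned the flat per-canvas outputs proportionally via _ratios, which can round onto wrong boundaries; the new code computes the exact number of canvases per video with get_num_combined_frames and slices by cumulative sums. A toy illustration with strings standing in for the canvas tensors (hypothetical sizes: two videos of 16 and 4 frames under the default 3x3 grid):

num_per_video = [2, 1]  # get_num_combined_frames(16) == 2, (4) == 1
flat = ["v0_canvas0", "v0_canvas1", "v1_canvas0"]  # flattened HF output
per_video = [
    flat[sum(num_per_video[:i]):sum(num_per_video[:i + 1])]
    for i in range(len(num_per_video))
]
assert per_video == [["v0_canvas0", "v0_canvas1"], ["v1_canvas0"]]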