[Bugfix] Fix Qwen2.5-Omni M-RoPE position ids generation (#16878)

Signed-off-by: imkero <kerorek@outlook.com>
This commit is contained in:
Kero Liang 2025-04-27 01:41:35 +08:00 committed by GitHub
parent fd11a325b8
commit de7eb10ce4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1209,6 +1209,7 @@ class MRotaryEmbedding(RotaryEmbedding):
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
vision_start_token_id = thinker_config.vision_start_token_id
vision_end_token_id = thinker_config.vision_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
@@ -1238,8 +1239,15 @@ class MRotaryEmbedding(RotaryEmbedding):
if src_item[idx] not in [
audio_token_id, video_token_id, image_token_id
]:
if src_item[idx] == vision_end_token_id and use_audio_in_video:
start_idx -= 1
if use_audio_in_video and idx > 0:
if src_item[idx] == vision_end_token_id and \
src_item[idx - 1] == audio_end_token_id:
# processing the <|audio_eos|> before <|vision_eos|>
start_idx -= 1
elif src_item[idx] == audio_start_token_id and \
src_item[idx - 1] == vision_start_token_id:
# processing the <|audio_bos|> after <|vision_bos|>
start_idx -= 1
new_src_item.append(src_item[idx])
llm_pos_ids = torch.tensor([start_idx],
dtype=torch.long).expand(3, -1)
@@ -1297,11 +1305,6 @@ class MRotaryEmbedding(RotaryEmbedding):
tokens_per_second).long()
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
new_src_item.extend([audio_start_token_id])
start_idx -= 1
llm_pos_ids_list.extend([
torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
] * 1)
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
@@ -1312,7 +1315,7 @@ class MRotaryEmbedding(RotaryEmbedding):
new_src_item.extend([video_token_id] *
vision_ntoken_per_chunk)
vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
start_idx + 1, video_idx, spatial_merge_size, t_chunk,
start_idx, video_idx, spatial_merge_size, t_chunk,
grid_hs, grid_ws).split(1, dim=1)
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
new_src_item.extend(
@@ -1320,13 +1323,13 @@ class MRotaryEmbedding(RotaryEmbedding):
added_audio_len) * [audio_token_id])
audio_start_idx = start_idx if len(
audio_llm_pos_ids_list
) == 0 else audio_llm_pos_ids_list[-1][0].item()
) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
if min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len) > 0:
audio_llm_pos_ids_list = (torch.arange(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len)).expand(3, -1) +
audio_start_idx + 1).split(
audio_start_idx).split(
1, dim=1)
else:
audio_llm_pos_ids_list = []
@@ -1341,11 +1344,6 @@ class MRotaryEmbedding(RotaryEmbedding):
3, -1) + llm_pos_ids_list[-1].max() + 1).split(
1, dim=1)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
llm_pos_ids_list.extend([
torch.tensor(
[llm_pos_ids_list[-1].max() + 1] * 3).unsqueeze(1)
] * 1)
new_src_item.extend([audio_end_token_id])
audio_idx += 1
video_idx += 1
# move to the next token