[Qwen3-Omni] fixed _get_feat_extract_output_lengths function (#31007)

Signed-off-by: Xiong Wang <wangxiongts@163.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Xiong Wang 2025-12-24 13:33:54 +08:00 committed by GitHub
parent 369f47aa0f
commit bb24592d13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -118,7 +118,7 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return feat_lengths, output_lengths
return output_lengths
class Qwen3_VisionPatchEmbed(nn.Module):
@ -921,13 +921,11 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if audio_feature_lengths is None and feature_attention_mask is None:
audio_output_lengths = []
elif audio_feature_lengths is not None:
_, audio_output_lens = _get_feat_extract_output_lengths(
audio_feature_lengths
)
audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
audio_output_lengths = audio_output_lens.tolist()
elif feature_attention_mask is not None:
assert isinstance(feature_attention_mask, torch.Tensor)
_, audio_output_lens = _get_feat_extract_output_lengths(
audio_output_lens = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1)
)
audio_output_lengths = audio_output_lens.tolist()
@ -1111,18 +1109,16 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
audio_input: Qwen2_5OmniAudioFeatureInputs,
audio_hashes: list[str] | None = None,
cached_audio_features: torch.Tensor | None = None,
) -> torch.Tensor:
) -> tuple[torch.Tensor, ...]:
input_features = audio_input["input_features"]
audio_feature_lengths = audio_input["audio_feature_lengths"]
audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
audio_feature_lengths
)
audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
audio_outputs = self.audio_tower(
input_features.to(self.audio_tower.dtype),
feature_lens=audio_feature_lengths,
aftercnn_lens=audio_feat_lengths,
aftercnn_lens=audio_output_lengths,
)
audio_features = audio_outputs.last_hidden_state
return audio_features.split(audio_output_lengths.tolist())
@ -1579,7 +1575,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
_, audio_len = _get_feat_extract_output_lengths(
audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
)
llm_pos_ids = (
@ -1700,7 +1696,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
llm_pos_ids_list.append(bos_block)
llm_pos_ids_list.append(bos_block)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
_, audio_len = _get_feat_extract_output_lengths(
audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
)
audio_llm_pos_ids = (