[ROCm] [Bugfix] compute_attn_mask_seqlen for qwen3 omni (#29974)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
TJian 2025-12-04 16:20:24 +08:00 committed by GitHub
parent 9aa33a74b0
commit 3f1b03739a


@@ -494,7 +494,10 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> torch.Tensor:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
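
What the patch changes, in brief: on ROCm the vision transformer can select the ROCM_AITER_FA attention backend, which, like FLASH_ATTN, needs the real max_seqlen; the old equality check only matched FLASH_ATTN, so max_seqlen stayed at zero on that backend. Below is a minimal sketch of the computation the guarded branch performs; the standalone function name is illustrative, not part of the patch.

import torch

def compute_max_seqlen(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # cu_seqlens holds cumulative sequence lengths, e.g. [0, 3, 7, 12]
    # for three sequences of lengths 3, 4, and 5. Adjacent differences
    # recover the per-sequence lengths; the flash-attention-style
    # kernels need the maximum of these.
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max()

cu_seqlens = torch.tensor([0, 3, 7, 12])
print(compute_max_seqlen(cu_seqlens))  # tensor(5)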