[ROCm] [Bugfix] compute_attn_mask_seqlen for qwen3 omni (#29974)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in: parent 9aa33a74b0, commit 3f1b03739a
@@ -494,7 +494,10 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> torch.Tensor:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
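For context, the hunk computes the length of the longest sequence in a packed batch from the cumulative boundaries in cu_seqlens, and now does so for the ROCm AITER flash-attention backend as well. A minimal standalone sketch of that computation (example tensor values are illustrative, not taken from vLLM) is:

import torch

# cu_seqlens holds cumulative sequence-length boundaries of a packed batch;
# the values below are an assumed example: three sequences of lengths 3, 5, 2.
cu_seqlens = torch.tensor([0, 3, 8, 10])

# Adjacent differences recover the per-sequence lengths.
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([3, 5, 2])

# The maximum is what the flash-attention kernels need as max_seqlen.
max_seqlen = seq_lens.max()                   # tensor(5)
print(max_seqlen)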