From 7568a282b90f012e199ff71e7813f186a51addec Mon Sep 17 00:00:00 2001
From: JartX
Date: Wed, 29 Oct 2025 17:55:35 +0100
Subject: [PATCH] [FIXBUG] Qwen3VL hallucinations without Contiguous on
 Torch.SDPA (#27744)

Signed-off-by: JartX
Co-authored-by: Lukas Geiger
---
 vllm/model_executor/models/qwen2_5_vl.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index c68115729c425..41cb7084057dd 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -428,6 +428,14 @@ class Qwen2_5_VisionAttention(nn.Module):
             )
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+
+            # Never remove the next contiguous logic
+            # Without it, hallucinations occur with the backend
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]