[ROCm] [Bugfix] Fix torch sdpa hallucination (#30789)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
(cherry picked from commit 2410132bb1f9faa5b252fad3f2b83dc926946b08)
This commit is contained in:
TJian 2025-12-17 07:32:43 +08:00 committed by Kevin H. Luu
parent 4cd332f3cf
commit f34eca5f01

View File

@ -16,6 +16,7 @@ import einops
import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()