[ROCm] [Bugfix] Fix torch sdpa hallucination (#30789)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
TJian 2025-12-17 07:32:43 +08:00 committed by GitHub
parent 0a1ab1e565
commit 2410132bb1

@@ -16,6 +16,7 @@ import einops
 import torch
 import torch.nn.functional as F
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
@@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
     v: torch.Tensor,
     cu_seqlens: torch.Tensor,
 ) -> torch.Tensor:
+    # Never remove the contiguous logic for ROCm
+    # Without it, hallucinations occur with the backend
+    if current_platform.is_rocm():
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
     outputs = []
     lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
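
For context, the hunk above sits inside a per-sequence SDPA loop over variable-length packed inputs. Below is a minimal, self-contained sketch of that pattern with the fix applied; `sdpa_varlen_sketch` and the `is_rocm` flag are illustrative stand-ins, not vLLM's actual API.

```python
# A minimal sketch of the pattern this commit patches, assuming the wrapper
# unpacks varlen-packed q/k/v by cu_seqlens and runs torch SDPA per sequence.
import torch
import torch.nn.functional as F


def sdpa_varlen_sketch(
    q: torch.Tensor,  # packed (total_tokens, num_heads, head_dim)
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,  # (num_seqs + 1,) cumulative sequence lengths
    is_rocm: bool = False,  # stands in for current_platform.is_rocm()
) -> torch.Tensor:
    if is_rocm:
        # The fix: on ROCm, torch SDPA can return wrong values for
        # non-contiguous inputs, so force a contiguous copy first.
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    outputs = []
    for qi, ki, vi in zip(q.split(lens), k.split(lens), v.split(lens)):
        # SDPA expects (..., seq_len, head_dim): move heads in front of tokens.
        out = F.scaled_dot_product_attention(
            qi.transpose(0, 1), ki.transpose(0, 1), vi.transpose(0, 1)
        )
        outputs.append(out.transpose(0, 1))
    return torch.cat(outputs)
```

Note that `Tensor.contiguous()` returns the tensor unchanged when it is already contiguous, so the ROCm guard only pays for a copy when the inputs are actually strided views.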