mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-03 16:00:55 +08:00
[ROCm] [Bugfix] Fix torch sdpa hallucination (#30789)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
0a1ab1e565
commit
2410132bb1
@ -16,6 +16,7 @@ import einops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Never remove the contiguous logic for ROCm
|
||||
# Without it, hallucinations occur with the backend
|
||||
if current_platform.is_rocm():
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
|
||||
outputs = []
|
||||
|
||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user