mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 20:37:12 +08:00
[ROCm] [Bugfix] Fix torch sdpa hallucination (#30789)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
(cherry picked from commit 2410132bb1f9faa5b252fad3f2b83dc926946b08)
This commit is contained in:
parent
4cd332f3cf
commit
f34eca5f01
@ -16,6 +16,7 @@ import einops
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# Never remove the contiguous logic for ROCm
|
||||||
|
# Without it, hallucinations occur with the backend
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
q = q.contiguous()
|
||||||
|
k = k.contiguous()
|
||||||
|
v = v.contiguous()
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user