[Bugfix][Hardware][CPU] Enable Gemma2 with SDPA on CPU backend (#11169)

This commit is contained in:
Jani Monoses 2024-12-13 20:00:40 +02:00 committed by GitHub
parent 0920ab9131
commit 0a56bcc03d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import make_tensor_with_pad
from vllm.utils import make_tensor_with_pad, print_warning_once
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
@ -395,7 +395,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
raise ValueError(
"Torch SPDA does not support block-sparse attention.")
if logits_soft_cap is not None:
raise ValueError("Torch SPDA does not support logits soft cap.")
print_warning_once("Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@ -619,7 +620,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
value[None, :, start_kv:end_kv, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=causal_attn and not self.need_mask,
is_causal=causal_attn and mask is None,
scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv