[BugFix] Fix mla cpu - missing 3 required positional arguments (#17494)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-05-01 02:36:52 -04:00 committed by GitHub
parent 13cf6b6236
commit 3c3d767201
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 1 deletion

View File

@@ -177,7 +177,7 @@ class ipex_ops:
out: torch.Tensor,
seqlen_q: torch.Tensor,
seqlen_k: torch.Tensor,
-alibi_slopes: torch.Tensor,
+alibi_slopes: Optional[torch.Tensor],
max_seqlen_q: int,
max_seqlen_k: int,
pdropout: float,
@@ -193,6 +193,8 @@ class ipex_ops:
if ipex.__version__.endswith("cpu"):
if logits_soft_cap != 0.0:
raise ValueError("IPEX CPU does not support logits_soft_cap")
+assert alibi_slopes is None
+assert window_size_left < 0 and window_size_right < 0
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,

View File

@@ -273,6 +273,9 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
return_softmax=False,
gen_=None,
logits_soft_cap=0.0,
+window_size_left=-1,
+window_size_right=-1,
+alibi_slopes=None,
)
# remove padding