mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 10:46:05 +08:00
[BugFix] Fix FI accuracy issue when used for MLA prefill (#26063)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
d00d652998
commit
decf7f794b
@ -1211,13 +1211,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
k, v, return_softmax_lse):
|
||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||
assert prefill.prefill_main is not None
|
||||
return prefill.prefill_main.run(
|
||||
ret = prefill.prefill_main.run(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
return_lse=return_softmax_lse,
|
||||
)
|
||||
|
||||
if isinstance(ret, tuple):
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return ret[0], ret[1].transpose(0, 1).contiguous()
|
||||
return ret
|
||||
|
||||
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
|
||||
q, k, v, return_softmax_lse):
|
||||
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||
@ -1260,12 +1265,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
|
||||
chunk_idx: int, q, k, v):
|
||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||
return prefill.prefill_chunks[chunk_idx].run(
|
||||
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
return_lse=True,
|
||||
)
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return attn_out, lse.transpose(0, 1).contiguous()
|
||||
|
||||
def _run_prefill_context_chunk_cudnn(self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user