[BugFix] Fix FlashInfer (FI) accuracy issue when used for MLA prefill (#26063)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Lucas Wilkinson 2025-10-02 13:18:13 -04:00 committed by GitHub
parent d00d652998
commit decf7f794b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1211,13 +1211,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k, v, return_softmax_lse):
assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None
return prefill.prefill_main.run(
ret = prefill.prefill_main.run(
q=q,
k=k,
v=v,
return_lse=return_softmax_lse,
)
if isinstance(ret, tuple):
# Convert from (q_len, num_heads) to (num_heads, q_len)
return ret[0], ret[1].transpose(0, 1).contiguous()
return ret
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
q, k, v, return_softmax_lse):
assert isinstance(prefill, CudnnPrefillMetadata)
@@ -1260,12 +1265,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
chunk_idx: int, q, k, v):
assert isinstance(prefill, FlashInferPrefillMetadata)
return prefill.prefill_chunks[chunk_idx].run(
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q,
k=k,
v=v,
return_lse=True,
)
# Convert from (q_len, num_heads) to (num_heads, q_len)
return attn_out, lse.transpose(0, 1).contiguous()
def _run_prefill_context_chunk_cudnn(self,
prefill: MLACommonPrefillMetadata,