[ROCm] add fallback for aiter fp8 decode mla (#30005)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
Ye (Charlotte) Qi 2025-12-04 00:45:57 -08:00 committed by GitHub
parent 899e2ef558
commit b8a6ae4158
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -283,6 +283,28 @@ def _rocm_aiter_grouped_topk_fake(
pass
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None
def _check_aiter_mla_fp8_support() -> bool:
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
global _AITER_MLA_SUPPORTS_FP8
if _AITER_MLA_SUPPORTS_FP8 is None:
try:
import inspect
from aiter.mla import mla_decode_fwd
sig = inspect.signature(mla_decode_fwd)
_AITER_MLA_SUPPORTS_FP8 = (
"q_scale" in sig.parameters and "kv_scale" in sig.parameters
)
except Exception:
_AITER_MLA_SUPPORTS_FP8 = False
return _AITER_MLA_SUPPORTS_FP8
def _rocm_aiter_mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
@ -299,6 +321,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
) -> None:
from aiter.mla import mla_decode_fwd
kwargs = {
"sm_scale": sm_scale,
"logit_cap": logit_cap,
}
# Only pass q_scale and kv_scale if the aiter library supports them
if _check_aiter_mla_fp8_support():
kwargs["q_scale"] = q_scale
kwargs["kv_scale"] = kv_scale
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
@ -308,10 +340,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
q_scale=q_scale,
kv_scale=kv_scale,
**kwargs,
)