diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index a8f472d147a0d..35920d826578e 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -283,6 +283,28 @@ def _rocm_aiter_grouped_topk_fake(
     pass
 
 
+# Cache whether aiter supports FP8 MLA parameters
+_AITER_MLA_SUPPORTS_FP8: bool | None = None
+
+
+def _check_aiter_mla_fp8_support() -> bool:
+    """Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
+    global _AITER_MLA_SUPPORTS_FP8
+    if _AITER_MLA_SUPPORTS_FP8 is None:
+        try:
+            import inspect
+
+            from aiter.mla import mla_decode_fwd
+
+            sig = inspect.signature(mla_decode_fwd)
+            _AITER_MLA_SUPPORTS_FP8 = (
+                "q_scale" in sig.parameters and "kv_scale" in sig.parameters
+            )
+        except Exception:
+            _AITER_MLA_SUPPORTS_FP8 = False
+    return _AITER_MLA_SUPPORTS_FP8
+
+
 def _rocm_aiter_mla_decode_fwd_impl(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
@@ -299,6 +321,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
 ) -> None:
     from aiter.mla import mla_decode_fwd
 
+    kwargs = {
+        "sm_scale": sm_scale,
+        "logit_cap": logit_cap,
+    }
+
+    # Only pass q_scale and kv_scale if the aiter library supports them
+    if _check_aiter_mla_fp8_support():
+        kwargs["q_scale"] = q_scale
+        kwargs["kv_scale"] = kv_scale
+
     mla_decode_fwd(
         q,
         kv_buffer.view(-1, 1, 1, q.shape[-1]),
@@ -308,10 +340,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
         kv_indices,
         kv_last_page_lens,
         max_seqlen_qo,
-        sm_scale=sm_scale,
-        logit_cap=logit_cap,
-        q_scale=q_scale,
-        kv_scale=kv_scale,
+        **kwargs,
     )
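
For reference, the compatibility pattern this patch applies (probe the callable's signature once, cache the answer, and forward the optional keyword arguments only when the installed library accepts them) can be sketched standalone. The snippet below is an illustrative sketch, not part of the change: decode_fwd is a hypothetical stand-in for aiter.mla.mla_decode_fwd, and functools.lru_cache replaces the module-level cache variable used in the patch.

import inspect
from functools import lru_cache


def decode_fwd(q, kv, out, *, sm_scale=1.0, logit_cap=0.0):
    """Hypothetical stand-in for an older mla_decode_fwd without FP8 scale arguments."""
    return out


@lru_cache(maxsize=1)
def _supports_fp8_scales() -> bool:
    # Probe the callable's signature once; lru_cache memoizes the result.
    try:
        params = inspect.signature(decode_fwd).parameters
        return "q_scale" in params and "kv_scale" in params
    except Exception:
        return False


def call_decode(q, kv, out, sm_scale, logit_cap, q_scale=None, kv_scale=None):
    kwargs = {"sm_scale": sm_scale, "logit_cap": logit_cap}
    # Forward the FP8 scales only when the installed library can accept them,
    # so the same caller works against both older and newer releases.
    if _supports_fp8_scales():
        kwargs["q_scale"] = q_scale
        kwargs["kv_scale"] = kv_scale
    return decode_fwd(q, kv, out, **kwargs)


if __name__ == "__main__":
    # With the stand-in above this prints False and the scales are dropped;
    # against a newer signature that lists q_scale/kv_scale they pass through.
    print(_supports_fp8_scales())
    call_decode(None, None, None, sm_scale=0.5, logit_cap=0.0, q_scale=1.0, kv_scale=1.0)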