mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-06 17:17:31 +08:00
[ROCm] add fallback for aiter fp8 decode mla (#30005)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
parent
899e2ef558
commit
b8a6ae4158
@ -283,6 +283,28 @@ def _rocm_aiter_grouped_topk_fake(
|
||||
pass
|
||||
|
||||
|
||||
# Cache whether aiter supports FP8 MLA parameters
|
||||
_AITER_MLA_SUPPORTS_FP8: bool | None = None
|
||||
|
||||
|
||||
def _check_aiter_mla_fp8_support() -> bool:
|
||||
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
|
||||
global _AITER_MLA_SUPPORTS_FP8
|
||||
if _AITER_MLA_SUPPORTS_FP8 is None:
|
||||
try:
|
||||
import inspect
|
||||
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
sig = inspect.signature(mla_decode_fwd)
|
||||
_AITER_MLA_SUPPORTS_FP8 = (
|
||||
"q_scale" in sig.parameters and "kv_scale" in sig.parameters
|
||||
)
|
||||
except Exception:
|
||||
_AITER_MLA_SUPPORTS_FP8 = False
|
||||
return _AITER_MLA_SUPPORTS_FP8
|
||||
|
||||
|
||||
def _rocm_aiter_mla_decode_fwd_impl(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
@ -299,6 +321,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
||||
) -> None:
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
kwargs = {
|
||||
"sm_scale": sm_scale,
|
||||
"logit_cap": logit_cap,
|
||||
}
|
||||
|
||||
# Only pass q_scale and kv_scale if the aiter library supports them
|
||||
if _check_aiter_mla_fp8_support():
|
||||
kwargs["q_scale"] = q_scale
|
||||
kwargs["kv_scale"] = kv_scale
|
||||
|
||||
mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||
@ -308,10 +340,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
max_seqlen_qo,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap,
|
||||
q_scale=q_scale,
|
||||
kv_scale=kv_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user