mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 13:42:18 +08:00
[ROCm] add fallback for aiter fp8 decode mla (#30005)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
parent
899e2ef558
commit
b8a6ae4158
@ -283,6 +283,28 @@ def _rocm_aiter_grouped_topk_fake(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Cache whether aiter supports FP8 MLA parameters
|
||||||
|
_AITER_MLA_SUPPORTS_FP8: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _check_aiter_mla_fp8_support() -> bool:
|
||||||
|
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
|
||||||
|
global _AITER_MLA_SUPPORTS_FP8
|
||||||
|
if _AITER_MLA_SUPPORTS_FP8 is None:
|
||||||
|
try:
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from aiter.mla import mla_decode_fwd
|
||||||
|
|
||||||
|
sig = inspect.signature(mla_decode_fwd)
|
||||||
|
_AITER_MLA_SUPPORTS_FP8 = (
|
||||||
|
"q_scale" in sig.parameters and "kv_scale" in sig.parameters
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
_AITER_MLA_SUPPORTS_FP8 = False
|
||||||
|
return _AITER_MLA_SUPPORTS_FP8
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_mla_decode_fwd_impl(
|
def _rocm_aiter_mla_decode_fwd_impl(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
kv_buffer: torch.Tensor,
|
kv_buffer: torch.Tensor,
|
||||||
@ -299,6 +321,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
|||||||
) -> None:
|
) -> None:
|
||||||
from aiter.mla import mla_decode_fwd
|
from aiter.mla import mla_decode_fwd
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"sm_scale": sm_scale,
|
||||||
|
"logit_cap": logit_cap,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Only pass q_scale and kv_scale if the aiter library supports them
|
||||||
|
if _check_aiter_mla_fp8_support():
|
||||||
|
kwargs["q_scale"] = q_scale
|
||||||
|
kwargs["kv_scale"] = kv_scale
|
||||||
|
|
||||||
mla_decode_fwd(
|
mla_decode_fwd(
|
||||||
q,
|
q,
|
||||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||||
@ -308,10 +340,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
kv_last_page_lens,
|
kv_last_page_lens,
|
||||||
max_seqlen_qo,
|
max_seqlen_qo,
|
||||||
sm_scale=sm_scale,
|
**kwargs,
|
||||||
logit_cap=logit_cap,
|
|
||||||
q_scale=q_scale,
|
|
||||||
kv_scale=kv_scale,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user