[ROCm][MLA] enable fp8 MLA decode on ROCm (#28032)

Signed-off-by: guanbao <gyu@amd.com>
Signed-off-by: Guanbao Yu <gyu@amd.com>
Signed-off-by: gbyu-amd <Guanbao.Yu@amd.com>
Co-authored-by: guanbao <gyu@amd.com>

parent 77e10c9cab
commit cb7214d8ea
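What the change does: the AITER MLA decode path now accepts per-tensor fp8
dequantization scales for the query and the KV cache (q_scale, kv_scale),
and the decode output dtype is made explicit instead of being inherited
from the (possibly fp8) query. A standalone sketch of the per-tensor fp8
scheme these scales belong to (tensor names and the fp8 flavor are my
assumptions, not taken from this diff):

import torch

# Quantize: one scale for the whole tensor, chosen so values span fp8's range.
fp8 = torch.float8_e4m3fnuz  # ROCm's fp8 variant; other platforms use e4m3fn
x = torch.randn(16, 128, dtype=torch.bfloat16)
scale = x.abs().max().float() / torch.finfo(fp8).max
x_fp8 = (x.float() / scale).to(fp8)

# Dequantize (what the kernel does internally given q_scale / kv_scale):
x_restored = x_fp8.float() * scale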
@@ -294,6 +294,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
     kv_last_page_lens: torch.Tensor | None = None,
     sm_scale: float = 1.0,
     logit_cap: float = 0.0,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     from aiter.mla import mla_decode_fwd

@@ -308,6 +310,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
         max_seqlen_qo,
         sm_scale=sm_scale,
         logit_cap=logit_cap,
+        q_scale=q_scale,
+        kv_scale=kv_scale,
     )

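The two hunks above thread the new scales through the real implementation
and into aiter.mla.mla_decode_fwd. Because both default to None, existing
bf16/fp16 call sites are untouched; only the fp8 path supplies them. A
minimal standalone sketch of that pattern (not vLLM code):

import torch

def decode_fwd(
    q: torch.Tensor,
    q_scale: torch.Tensor | None = None,
    kv_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    # No scales: q is already bf16/fp16 and is used as-is.
    if q_scale is not None:
        q = q.float() * q_scale  # per-tensor dequant of an fp8 query
    return q  # stand-in for the actual attention computation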
@@ -322,6 +326,8 @@ def _rocm_aiter_mla_decode_fwd_fake(
     kv_last_page_lens: torch.Tensor | None = None,
     sm_scale: float = 1.0,
     logit_cap: float = 0.0,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     pass

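The _fake variant is the shape/dtype stub registered alongside the real op
so torch.compile can trace it without running the kernel; its signature must
stay in sync with the impl, which is why it gains the same two arguments
even though the body is pass. The same idea expressed with plain
torch.library (demo names; vLLM uses its own registration helper):

from typing import Optional
import torch

@torch.library.custom_op("demo::scaled_decode", mutates_args=())
def scaled_decode(q: torch.Tensor, q_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Real implementation: dequantize with the per-tensor scale when given.
    scale = q_scale if q_scale is not None else q.new_ones(())
    return q.float() * scale.float()

@scaled_decode.register_fake
def _(q, q_scale=None):
    # Fake impl: only the output's shape/dtype matter during tracing.
    return torch.empty_like(q, dtype=torch.float32)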
@@ -806,6 +812,8 @@ class rocm_aiter_ops:
         kv_indices: torch.Tensor | None = None,
         kv_last_page_lens: torch.Tensor | None = None,
         logit_cap: float = 0.0,
+        q_scale: torch.Tensor | None = None,
+        kv_scale: torch.Tensor | None = None,
     ):
         torch.ops.vllm.rocm_aiter_mla_decode_fwd(
             q,
@@ -818,6 +826,8 @@ class rocm_aiter_ops:
             kv_last_page_lens,
             sm_scale=sm_scale,
             logit_cap=logit_cap,
+            q_scale=q_scale,
+            kv_scale=kv_scale,
         )

     @staticmethod

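The rocm_aiter_ops wrapper forwards everything through the torch.ops.vllm
namespace, so the scaled decode stays visible to torch.compile graphs
rather than being an opaque Python call. With the demo op registered above,
dispatch through the namespace looks like this (demo names, not vLLM's real
schema):

q = torch.randn(4, 16)
out = torch.ops.demo.scaled_decode(q, None)  # same entry torch.compile traces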
@@ -49,6 +49,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     paged_kv_last_page_len: torch.Tensor | None = None
     # The query indptr, shape : [num_decode + 1]
     qo_indptr: torch.Tensor | None = None
+    # The dtype of MLA out tensor
+    attn_out_dtype: torch.dtype = torch.bfloat16


 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
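This is where the fp8 path diverges from "output dtype follows q.dtype":
with an fp8 query, an output inheriting q.dtype would itself be fp8, which
is not a usable attention result. The metadata therefore records the dtype
to allocate the decode output in, defaulting to bfloat16. A standalone
illustration of the failure mode (the fp8 flavor is an assumption):

import torch

fp8 = torch.float8_e4m3fnuz
q = torch.randn(2, 16, 576).to(fp8)                   # fp8 decode query
bad = torch.zeros(2, 16, 512, dtype=q.dtype)          # fp8 output: lossy
good = torch.zeros(2, 16, 512, dtype=torch.bfloat16)  # what attn_out_dtype gives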
@@ -74,6 +76,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         )

         self.compilation_config = vllm_config.compilation_config
+        self.decode_attn_out_dtype = vllm_config.model_config.dtype
         # kernel block size is always 1.
         max_num_pages_per_req = vllm_config.model_config.max_model_len
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
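Note the dtype is taken from model_config.dtype, the model's activation
dtype, not from the KV-cache dtype, so decode output stays in full precision
even when the cache is quantized. Roughly (values illustrative for a bf16
model served with an fp8 KV cache on ROCm):

activation_dtype = torch.bfloat16       # vllm_config.model_config.dtype
kv_cache_dtype = torch.float8_e4m3fnuz  # what an fp8 KV-cache setting maps to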
@@ -162,6 +165,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+            attn_out_dtype=self.decode_attn_out_dtype,
         )

         return attn_metadata
@@ -242,7 +246,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
         o = torch.zeros(
-            B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
+            B,
+            self.num_heads,
+            self.kv_lora_rank,
+            dtype=attn_metadata.decode.attn_out_dtype,
+            device=q.device,
         )

         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
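This is the payoff of the metadata plumbing: the output buffer handed to the
AITER kernel is allocated in the builder-captured model dtype instead of
q.dtype, so an fp8 query no longer forces an fp8 output. The chain in plain
values (names follow the diff; the wiring shown is my summary):

decode_attn_out_dtype = torch.bfloat16   # builder, from model_config.dtype
attn_out_dtype = decode_attn_out_dtype   # stored on the decode metadata
o = torch.zeros(8, 16, 512, dtype=attn_out_dtype)  # allocation in the impl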
@@ -260,6 +268,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             attn_metadata.decode.paged_kv_indptr,
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len,
+            q_scale=layer._q_scale,
+            kv_scale=layer._k_scale,
         )

         return o, None
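layer._q_scale and layer._k_scale are the per-tensor scales the quantization
setup attaches to the attention layer; kv_scale is wired to the k-scale
because MLA keeps a single compressed latent KV cache rather than separate K
and V caches. A rough sketch of how such a scale is typically derived
(illustrative, not vLLM's calibration code):

def per_tensor_scale(t: torch.Tensor) -> torch.Tensor:
    # One scalar chosen so the tensor's extremes map onto fp8's finite range.
    return t.abs().max().float() / torch.finfo(torch.float8_e4m3fnuz).max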