mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 04:35:01 +08:00
[Bugfix] Enable PP with AITER+V1 (#19822)
Signed-off-by: Qiang Li <qiang.li2@amd.com>
This commit is contained in:
parent
e41bf15cd0
commit
e3a3e4db46
@ -45,7 +45,6 @@ def fused_add_rms_norm(
|
|||||||
|
|
||||||
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
|
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
|
||||||
variance_epsilon: float) -> torch.Tensor:
|
variance_epsilon: float) -> torch.Tensor:
|
||||||
|
|
||||||
import aiter as rocm_aiter
|
import aiter as rocm_aiter
|
||||||
if x.dim() > 2:
|
if x.dim() > 2:
|
||||||
x_original_shape = x.shape
|
x_original_shape = x.shape
|
||||||
|
|||||||
@ -201,16 +201,9 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
|||||||
|
|
||||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||||
|
|
||||||
if self.num_heads == 16:
|
# max_seqlen_qo must be 1 except for MTP
|
||||||
# AITER MLA decode kernel only supports
|
# TODO: Find the best value for MTP
|
||||||
# max_seqlen_q=1 when using 16 heads.
|
|
||||||
max_seqlen_qo = 1
|
max_seqlen_qo = 1
|
||||||
else:
|
|
||||||
# AITER MLA decode Kernel handles arbitrary
|
|
||||||
# max_seqlen_q values when using 128 heads.
|
|
||||||
assert attn_metadata.prefill is not None
|
|
||||||
max_seqlen_qo = attn_metadata.prefill.max_query_len
|
|
||||||
|
|
||||||
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
|
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
|
||||||
attn_metadata.decode.qo_indptr, max_seqlen_qo,
|
attn_metadata.decode.qo_indptr, max_seqlen_qo,
|
||||||
attn_metadata.decode.paged_kv_indptr,
|
attn_metadata.decode.paged_kv_indptr,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user