[Bugfix] Add contiguous call inside rope kernel wrapper (#17091)
Signed-off-by: 苏政渊 <suzhengyuan@moonshot.cn>
Co-authored-by: 苏政渊 <suzhengyuan@moonshot.cn>
parent 165cb56329
commit 17eb306fcc
@@ -158,8 +158,13 @@ def rotary_embedding(
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
-    torch.ops._C.rotary_embedding(positions, query, key, head_size,
-                                  cos_sin_cache, is_neox)
+    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+    query_contiguous = query.contiguous()
+    key_contiguous = key.contiguous()
+    torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
+                                  head_size, cos_sin_cache, is_neox)
+    query.copy_(query_contiguous)
+    key.copy_(key_contiguous)
 
 
 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@@ -167,9 +172,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                              cos_sin_cache: torch.Tensor, is_neox: bool,
                              rot_dim: int,
                              cos_sin_cache_offsets: torch.Tensor) -> None:
-    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
+    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+    query_contiguous = query.contiguous()
+    key_contiguous = key.contiguous()
+    torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
+                                          key_contiguous, head_size,
                                           cos_sin_cache, is_neox, rot_dim,
                                           cos_sin_cache_offsets)
+    query.copy_(query_contiguous)
+    key.copy_(key_contiguous)
 
 
 # layer norm ops
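
The wrapper change in the two hunks above leans on standard PyTorch aliasing behavior: calling .contiguous() on an already contiguous tensor returns the same tensor object, but on a strided view (such as the rope slice of a larger q/k tensor) it returns a fresh copy. Because the underlying torch.ops._C kernels rotate query and key in place, anything written into such a copy has to be propagated back into the caller's tensor with copy_. The following is a minimal, illustrative sketch of that behavior; the shapes and variable names are invented for the example and are not taken from this commit.

import torch

# Illustrative only (not part of the commit): why the wrapper pairs
# contiguous() with a copy_() back into the original tensor.
q = torch.zeros(4, 2, 192)       # hypothetical [num_tokens, num_heads, head_size]
q_pe = q[..., 128:]              # rope slice: a non-contiguous view into q
assert not q_pe.is_contiguous()

q_pe_contig = q_pe.contiguous()  # fresh buffer, no longer aliasing q's storage
q_pe_contig.add_(1.0)            # stands in for the in-place rope kernel

assert q_pe.sum() == 0           # the original slice never saw the update
q_pe.copy_(q_pe_contig)          # the wrapper's copy_ step writes it back
assert torch.equal(q[..., 128:], q_pe_contig)

For inputs that are already contiguous, contiguous() returns the same tensor object, so no new buffer is allocated on that path.
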
@@ -938,8 +938,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
+                attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
 
         if has_prefill:
             assert attn_metadata.prefill is not None
@@ -948,8 +947,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
 
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(), prefill_k_pe)
+                attn_metadata.prefill.input_positions, prefill_q_pe,
+                prefill_k_pe)
 
             # write the latent and rope to kv cache
             if kv_cache.numel() > 0:
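
With the contiguity handling moved into the wrapper, the MLA call sites in the two hunks above can pass the rope slices (decode_q_pe, prefill_q_pe) to self.rotary_emb directly instead of calling .contiguous() first. The surrounding tensor[...] = ... assignment still writes the returned values through the view into the parent tensor's storage. A small sketch of that assignment semantics follows; the names and shapes are illustrative and not taken from the vLLM source.

import torch

# Illustrative only: "view[...] = value" assigns element-wise through the
# view, so the rope result lands back inside the parent tensor even if the
# callee returned a brand-new tensor.
prefill_q = torch.zeros(8, 16, 192)      # hypothetical [tokens, heads, head_dim]
prefill_q_pe = prefill_q[..., 128:]      # non-contiguous rope slice

rotated = torch.ones_like(prefill_q_pe)  # stands in for rotary_emb's output
prefill_q_pe[...] = rotated              # writes through the view

assert torch.equal(prefill_q[..., 128:], rotated)
assert prefill_q[..., :128].sum() == 0   # the nope portion is untouched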