diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 4c577c1c47e7..7bb01507ac2c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -158,8 +158,13 @@ def rotary_embedding(
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
-    torch.ops._C.rotary_embedding(positions, query, key, head_size,
-                                  cos_sin_cache, is_neox)
+    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+    query_contiguous = query.contiguous()
+    key_contiguous = key.contiguous()
+    torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
+                                  head_size, cos_sin_cache, is_neox)
+    query.copy_(query_contiguous)
+    key.copy_(key_contiguous)
 
 
 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@@ -167,9 +172,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                              cos_sin_cache: torch.Tensor, is_neox: bool,
                              rot_dim: int,
                              cos_sin_cache_offsets: torch.Tensor) -> None:
-    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
+    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+    query_contiguous = query.contiguous()
+    key_contiguous = key.contiguous()
+    torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
+                                          key_contiguous, head_size,
                                           cos_sin_cache, is_neox, rot_dim,
                                           cos_sin_cache_offsets)
+    query.copy_(query_contiguous)
+    key.copy_(key_contiguous)
 
 
 # layer norm ops
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index e6e483bae2bc..b032006d1ad1 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -938,8 +938,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
+                attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
 
         if has_prefill:
             assert attn_metadata.prefill is not None
@@ -948,8 +947,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                 prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
 
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                    attn_metadata.prefill.input_positions,
-                    prefill_q_pe.contiguous(), prefill_k_pe)
+                    attn_metadata.prefill.input_positions, prefill_q_pe,
+                    prefill_k_pe)
 
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
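
Outside the patch, a minimal sketch of the contiguous round trip that the updated wrappers perform: the in-place kernel is given contiguous copies, and the results are copied back into the original (possibly strided) views. fake_rotary_kernel and the tensor shapes below are illustrative stand-ins, not vLLM APIs.

import torch


def fake_rotary_kernel(query: torch.Tensor, key: torch.Tensor) -> None:
    # Hypothetical stand-in for torch.ops._C.rotary_embedding: an in-place
    # op that, like the real kernel, assumes contiguous query/key buffers.
    assert query.is_contiguous() and key.is_contiguous()
    query.mul_(2.0)  # placeholder for the actual rotation
    key.mul_(2.0)


# q_pe mimics a non-contiguous slice such as prefill_q[..., qk_nope_head_dim:].
q = torch.randn(4, 8, 16)
q_pe = q[..., 8:]            # a strided view, not contiguous
k_pe = torch.randn(4, 8, 8)

# Pattern applied by the patch: run the kernel on contiguous copies, then
# copy the results back into the original views.
q_contig = q_pe.contiguous()
k_contig = k_pe.contiguous()
fake_rotary_kernel(q_contig, k_contig)
q_pe.copy_(q_contig)         # writes land in the underlying storage of q
k_pe.copy_(k_contig)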