From ca7a2d5f28eac9621474563cdda0e08596222755 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Sat, 8 Mar 2025 01:18:53 -0500
Subject: [PATCH] Revert "[Perf] Reduce MLA CPU overheads in V1 (#14384)"
 (#14471)

---
 vllm/model_executor/layers/rotary_embedding.py |  9 ++-------
 vllm/v1/attention/backends/mla/common.py       | 15 ++++-----------
 2 files changed, 6 insertions(+), 18 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 48cdebee9161b..64c2dac524f2b 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -161,13 +161,8 @@ class RotaryEmbedding(CustomOp):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops
 
-        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
-        # is expensive, so avoid calling it if possible
-        if self.cos_sin_cache.device != query.device or \
-            self.cos_sin_cache.dtype != query.dtype:
-            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                       dtype=query.dtype)
-
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                   dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index f3fff585be646..886295ee895ca 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -222,8 +222,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-from vllm.platforms import current_platform
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.utils import cdiv, round_down
 
 try:
@@ -627,15 +627,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.v_head_dim = v_head_dim
 
         self.rotary_emb = rotary_emb
-
-        if current_platform.is_cuda():
-            # Hack for V1 for now to avoid torch library overhead (since we are
-            # already inside an attention custom op), pull out the forward
-            # method from the rotary embedding and call it directly (and avoid
-            # calling forward_native, when we can call forward_cuda)
-            # TODO(lucas): we should probably find a cleaner way to do this
-            self.rotary_emb = rotary_emb.forward_cuda
-
+        self.use_yarn_rope = isinstance(rotary_emb,
+                                        DeepseekScalingRotaryEmbedding)
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
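
Note (not part of the patch): the rotary_embedding.py hunk above reverts to an unconditional `.to()` reassignment on every forward call. Below is a minimal standalone sketch of the guarded pattern being removed; the class name, buffer shape, and setup are illustrative, not vLLM's actual code.

import torch
import torch.nn as nn


class CachedRope(nn.Module):
    """Illustrative stand-in for a module holding a cos/sin cache buffer."""

    def __init__(self) -> None:
        super().__init__()
        # Analogous to RotaryEmbedding.cos_sin_cache (shape is arbitrary here).
        self.register_buffer("cos_sin_cache",
                             torch.randn(16, 8),
                             persistent=False)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Behavior restored by this revert: always reassign, which goes
        # through nn.Module.__setattr__ on every call.
        #   self.cos_sin_cache = self.cos_sin_cache.to(query.device,
        #                                              dtype=query.dtype)

        # Pattern being reverted (from PR #14384): only reassign when the
        # cache actually needs to move, so the hot path skips __setattr__.
        if (self.cos_sin_cache.device != query.device
                or self.cos_sin_cache.dtype != query.dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)
        return query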