diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 0100c082aa213..363aa08ef0030 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -211,8 +211,6 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
@@ -377,7 +375,6 @@ class MLACommonState(AttentionState, Generic[T]):
             seq_start_loc=None,
             context_lens_tensor=None,
             block_tables=self._graph_block_tables[:batch_size],
-            input_positions=self._positions[:batch_size],
             head_dim=self.runner.model_config.get_head_size())

         if is_encoder_decoder_model:
@@ -393,7 +390,6 @@ class MLACommonState(AttentionState, Generic[T]):
             "slot_mapping": attn_metadata.slot_mapping,
             "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
             "block_tables": attn_metadata.decode_metadata.block_tables,
-            "input_positions": attn_metadata.decode_metadata.input_positions,
         }
         if is_encoder_decoder_model:
             raise NotImplementedError(
@@ -405,16 +401,10 @@ class MLACommonState(AttentionState, Generic[T]):
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
-        input_positions = attn_metadata.input_positions
-        num_positions = input_positions.shape[0]
         input_buffers["seq_lens_tensor"].copy_(
             attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
         input_buffers["block_tables"].copy_(
             attn_metadata.decode_metadata.block_tables, non_blocking=True)
-        # CUDA graph buffer is padded so only perform a partial copy based on
-        # num_positions
-        input_buffers["input_positions"][:num_positions].copy_(
-            input_positions, non_blocking=True)
         if is_encoder_decoder_model:
             raise NotImplementedError(
                 "TritonMLAState does not support encoder/decoder yet")
@@ -456,11 +446,6 @@ class MLACommonMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool

-    # New for MLA (compared to FlashAttention)
-    # Input positions for rotrary embeddings since for MLA the rotary
-    # position embeddings are applied inside the attention backend
-    input_positions: torch.Tensor
-
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
@@ -563,8 +548,6 @@ class MLACommonMetadata(AttentionMetadata):
                                self.context_lens_tensor[:self.num_prefills])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[:self.num_prefills])
-        input_positions = (None if self.input_positions is None else
-                           self.input_positions[:self.num_prefill_tokens])

         self._cached_prefill_metadata = self.__class__(
             # Required by ModelRunner
@@ -578,7 +561,6 @@ class MLACommonMetadata(AttentionMetadata):
             multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=False,
             # MLACommonMetadata
-            input_positions=input_positions,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,
@@ -615,8 +597,6 @@ class MLACommonMetadata(AttentionMetadata):
                            self.seq_lens_tensor[self.num_prefills:])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[self.num_prefills:])
-        input_positions = (None if self.input_positions is None else
-                           self.input_positions[self.num_prefill_tokens:])

         self._cached_decode_metadata = self.__class__(
             # Required by ModelRunner
@@ -646,7 +626,6 @@ class MLACommonMetadata(AttentionMetadata):
             if self.seq_start_loc is not None else None,
             context_lens_tensor=None,
             block_tables=block_tables,
-            input_positions=input_positions,
             head_dim=self.head_dim,
             is_profile_run=self.is_profile_run)
         return self._cached_decode_metadata
@@ -765,7 +744,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
-        self.input_positions: List[int] = []
         self.multimodal_placeholder_maps: Dict[
             str,
             MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
@@ -786,13 +764,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
         block_tables = inter_data.block_tables

         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
-             curr_sliding_window_block, input_positions) in zip(
+             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
-                inter_data.curr_sliding_window_blocks,
-                inter_data.input_positions):
-            self.input_positions.extend(input_positions)
+                inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
                 self.num_prefills += 1
@@ -912,8 +888,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
                                                device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        input_positions = async_tensor_h2d(self.input_positions, torch.long,
-                                           device, self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
         query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
@@ -987,7 +961,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
             multi_modal_placeholder_index_maps=None,  # Not Attention Related
             enable_kv_scales_calculation=False,
             # MLACommonMetadata
-            input_positions=input_positions,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
@@ -1033,7 +1006,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             qk_rope_head_dim: int,
             qk_head_dim: int,
             v_head_dim: int,
-            rotary_emb: RotaryEmbedding,
             kv_b_proj: ColumnParallelLinear,
     ) -> None:
         self.num_heads = num_heads
@@ -1048,10 +1020,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
-
-        self.rotary_emb = rotary_emb
-        self.use_yarn_rope = isinstance(rotary_emb,
-                                        DeepseekScalingRotaryEmbedding)
         self.kv_b_proj = kv_b_proj

         self.triton_fa_func = triton_attention
@@ -1367,41 +1335,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         has_decode = attn_metadata.decode_metadata is not None
         has_prefill = attn_metadata.prefill_metadata is not None

-        # Restore head dim (for rotary embedding)
-        k_pe = k_pe.unsqueeze(1)
-        assert hasattr(attn_metadata, "input_positions")
-
         num_prefill_tokens: int = attn_metadata.num_prefill_tokens

         q = q.view(-1, self.num_heads, self.qk_head_dim)
         decode_q = q[num_prefill_tokens:]
-        decode_k_pe = k_pe[num_prefill_tokens:]
-        decode_input_positions = \
-            attn_metadata.input_positions[num_prefill_tokens:]

         prefill_q = q[:num_prefill_tokens]
         prefill_k_pe = k_pe[:num_prefill_tokens]
-        prefill_input_positions = \
-            attn_metadata.input_positions[:num_prefill_tokens]
         prefill_k_c_normed = k_c_normed[:num_prefill_tokens]

-        if has_decode:
-            decode_q_nope, decode_q_pe = decode_q.split(
-                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-            # Convert from (B, N, P) to (N, B, P)
-            decode_q_nope = decode_q_nope.transpose(0, 1)
-            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
-            # Convert from (N, B, L) to (B, N, L)
-            decode_ql_nope = decode_ql_nope.transpose(0, 1)
-            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                decode_input_positions, decode_q_pe, decode_k_pe)
-
-        if has_prefill:
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                prefill_input_positions, prefill_q_pe, prefill_k_pe)
-
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
@@ -1424,6 +1366,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 attn_metadata)

         if has_decode:
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
+
             output[num_prefill_tokens:] = self._forward_decode(
                 decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py
index 2984bc1dad64a..4936c82013998 100644
--- a/vllm/attention/backends/rocm_aiter_mla.py
+++ b/vllm/attention/backends/rocm_aiter_mla.py
@@ -148,13 +148,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         block_tables = inter_data.block_tables

         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
-             curr_sliding_window_block, input_positions) in zip(
+             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
-                inter_data.curr_sliding_window_blocks,
-                inter_data.input_positions):
-            self.input_positions.extend(input_positions)
+                inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
                 self.num_prefills += 1
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 32c2a2859b49f..f8392eb679d22 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -808,8 +808,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         query_pass = query[..., self.rotary_dim:]
         key_pass = key[..., self.rotary_dim:]

-        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-            positions.device)
+        if self.cos_sin_cache.device != positions.device:
+            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+                positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index ce86b9b2c4f04..0366895ef02e0 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -453,7 +453,6 @@ class DeepseekV2MLAAttention(nn.Module):
             qk_rope_head_dim=self.qk_rope_head_dim,
             qk_head_dim=self.qk_head_dim,
             v_head_dim=self.v_head_dim,
-            rotary_emb=self.rotary_emb,
             kv_b_proj=self.kv_b_proj,
         )

@@ -475,6 +474,13 @@ class DeepseekV2MLAAttention(nn.Module):
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())

+        q = q.view(-1, self.num_local_heads, self.qk_head_dim)
+        # Add head dim of 1 to k_pe
+        k_pe = k_pe.unsqueeze(1)
+
+        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
+            positions, q[..., self.qk_nope_head_dim:], k_pe)
+
         attn_out = self.mla_attn(
             q,
             kv_c_normed,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 0040abeb183a7..0c740fbcc6b78 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -204,7 +204,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -269,9 +268,6 @@ class MLACommonPrefillMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor

-    # Input positions for rotrary embeddings since for MLA the rotary
-    # position embeddings are applied inside the attention backend
-    input_positions: torch.Tensor
     block_table: torch.Tensor
     query_start_loc: torch.Tensor
     max_query_len: int
@@ -280,9 +276,6 @@ class MLACommonPrefillMetadata:

 @dataclass
 class MLACommonDecodeMetadata:
-    # Input positions for rotrary embeddings since for MLA the rotary
-    # position embeddings are applied inside the attention backend
-    input_positions: torch.Tensor
     block_table: torch.Tensor
     seq_lens: torch.Tensor

@@ -443,10 +436,8 @@ class MLACommonMetadataBuilder(Generic[M]):

         return modified_batch

-    def _build_decode(self, input_positions: torch.Tensor,
-                      block_table: torch.Tensor, seq_lens: torch.Tensor):
+    def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(
-            input_positions=input_positions,
             block_table=block_table,
             seq_lens=seq_lens,
         )
@@ -464,8 +455,6 @@ class MLACommonMetadataBuilder(Generic[M]):
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
-        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()

         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
@@ -473,7 +462,6 @@ class MLACommonMetadataBuilder(Generic[M]):
         prefill_metadata = None
         if self._num_prefills > 0:
             reqs_start = self._num_decodes  # prefill_start
-            tokens_start = self._num_decode_tokens

             context_lens_cpu = self.runner.input_batch.\
                 num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
@@ -542,7 +530,6 @@ class MLACommonMetadataBuilder(Generic[M]):
                     self.chunked_prefill_workspace_size

             prefill_metadata = MLACommonPrefillMetadata(
-                input_positions=input_positions[tokens_start:],
                 block_table=block_table[reqs_start:, ...],
                 query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
@@ -552,7 +539,6 @@ class MLACommonMetadataBuilder(Generic[M]):
         decode_metadata = None
         if self._num_decodes > 0:
             decode_metadata = self._build_decode(
-                input_positions=input_positions[:self._num_decode_tokens],
                 block_table=block_table[:self._num_decodes, ...],
                 seq_lens=seq_lens[:self._num_decodes],
             )
@@ -599,7 +585,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             qk_rope_head_dim: int,
             qk_head_dim: int,
             v_head_dim: int,
-            rotary_emb: RotaryEmbedding,
             kv_b_proj: ColumnParallelLinear,
     ) -> None:
         self.num_heads = num_heads
@@ -614,15 +599,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
-
-        # Hack for V1 for now to avoid torch library overhead (since we are
-        # already inside an attention custom op), pull out the forward
-        # method from the rotary embedding and call it directly
-        # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb.forward_native
-        if current_platform.is_cuda():
-            self.rotary_emb = rotary_emb.forward_cuda
-
         self.kv_b_proj = kv_b_proj

         self.vllm_flash_attn_version = get_flash_attn_version()
@@ -894,9 +870,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]

-        # Restore head dim (for rotary embedding)
-        k_pe = k_pe.unsqueeze(1)
-
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
@@ -905,35 +878,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens

-        q = q.view(-1, self.num_heads, self.qk_head_dim)
         decode_q = q[:num_decode_tokens]
-        decode_k_pe = k_pe[:num_decode_tokens]

         prefill_q = q[num_decode_tokens:]
         prefill_k_pe = k_pe[num_decode_tokens:]
         prefill_k_c_normed = k_c_normed[num_decode_tokens:]

-        if has_decode:
-            assert attn_metadata.decode is not None
-            decode_q_nope, decode_q_pe = decode_q.split(
-                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-            # Convert from (B, N, P) to (N, B, P)
-            decode_q_nope = decode_q_nope.transpose(0, 1)
-            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
-            # Convert from (N, B, L) to (B, N, L)
-            decode_ql_nope = decode_ql_nope.transpose(0, 1)
-            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
-
-        if has_prefill:
-            assert attn_metadata.prefill is not None
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-
-            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                attn_metadata.prefill.input_positions, prefill_q_pe,
-                prefill_k_pe)
-
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
@@ -951,6 +901,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                 attn_metadata)

         if has_decode:
+            assert attn_metadata.decode is not None
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
+
             output[:num_decode_tokens] = self._forward_decode(
                 decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index f18c9c8b6462c..2f35f9b0a54f0 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -58,8 +58,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.num_q_heads = self.runner.model_config.get_num_attention_heads(
             self.runner.parallel_config)

-    def _build_decode(self, input_positions: torch.Tensor,
-                      block_table: torch.Tensor,
+    def _build_decode(self, block_table: torch.Tensor,
                       seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
@@ -69,7 +68,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
             )

         return FlashMLADecodeMetadata(
-            input_positions=input_positions,
             block_table=block_table,
             seq_lens=seq_lens,
             tile_scheduler_metadata=tile_scheduler_metadata,
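
Note: the diff above moves rotary-embedding application out of the MLA attention backends and into DeepseekV2MLAAttention.forward, so the backends and their metadata no longer carry input_positions. A minimal, runnable sketch of that new call order follows; toy_rope and the tensor sizes are illustrative stand-ins, not vLLM's RotaryEmbedding or DeepSeek's real dimensions.

import torch


def toy_rope(positions: torch.Tensor, x: torch.Tensor,
             base: float = 10000.0) -> torch.Tensor:
    # NeoX-style rotation of x with shape (num_tokens, num_heads, rotary_dim);
    # a stand-in used only to show where the rotation now happens.
    d = x.shape[-1]
    inv_freq = 1.0 / (base**(torch.arange(0, d, 2, dtype=torch.float32) / d))
    freqs = positions.to(torch.float32)[:, None] * inv_freq[None, :]
    cos, sin = freqs.cos()[:, None, :], freqs.sin()[:, None, :]
    x1, x2 = x[..., :d // 2], x[..., d // 2:]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


# Mirrors the new ordering in DeepseekV2MLAAttention.forward: reshape q, add a
# head dim of 1 to k_pe, rotate only the rope slices, then hand everything to
# the attention backend (omitted here). Sizes below are made up.
num_tokens, num_heads, nope_dim, rope_dim = 4, 2, 8, 4
q = torch.randn(num_tokens, num_heads * (nope_dim + rope_dim))
k_pe = torch.randn(num_tokens, rope_dim)
positions = torch.arange(num_tokens)

q = q.view(-1, num_heads, nope_dim + rope_dim)
k_pe = k_pe.unsqueeze(1)

q[..., nope_dim:] = toy_rope(positions, q[..., nope_dim:])
k_pe = toy_rope(positions, k_pe)

print(q.shape, k_pe.shape)  # torch.Size([4, 2, 12]) torch.Size([4, 1, 4])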