diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 2ef66229b833d..363aa08ef0030 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1213,9 +1213,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): attn_output, attn_softmax_lse = \ self._flash_attn_varlen_diff_headdims( - q, - k, - v, + q=q, + k=k, + v=v, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], max_seqlen_q=prefill_metadata.max_query_len, @@ -1267,9 +1267,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) output = self._flash_attn_varlen_diff_headdims( - q, - k, - v, + q=q, + k=k, + v=v, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.query_start_loc, max_seqlen_q=prefill_metadata.max_prefill_seq_len, diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index b048220020f14..4936c82013998 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -53,7 +53,7 @@ class AiterMLABackend(MLACommonBackend): @dataclass class AiterMLAMetadata(MLACommonMetadata): - # The following 5 tensors are for current version of AITER MLA + # The following 4 tensors are for current version of AITER MLA block_table_bound: Optional[torch.Tensor] = None # The indptr of the paged kv cache, shape: [batch_size + 1] paged_kv_indptr: Optional[torch.Tensor] = None @@ -63,10 +63,6 @@ class AiterMLAMetadata(MLACommonMetadata): # the paged kv cache, shape: [batch_size] paged_kv_last_page_lens: Optional[torch.Tensor] = None - # This is just to make new AITER MLA API work - # -- MTP support is not added yet. - qo_indptr: Optional[torch.Tensor] = None - @property def prefill_metadata(self): prefill_metadata = super().prefill_metadata @@ -78,7 +74,6 @@ class AiterMLAMetadata(MLACommonMetadata): prefill_metadata\ .paged_kv_last_page_lens = self.paged_kv_last_page_lens prefill_metadata.block_table_bound = self.block_table_bound - prefill_metadata.qo_indptr = self.qo_indptr # update the cache self._cached_prefill_metadata = self.__class__( @@ -98,7 +93,6 @@ class AiterMLAMetadata(MLACommonMetadata): decode_metadata\ .paged_kv_last_page_lens = self.paged_kv_last_page_lens decode_metadata.block_table_bound = self.block_table_bound - decode_metadata.qo_indptr = self.qo_indptr # update the cache self._cached_decode_metadata = self.__class__( @@ -142,7 +136,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): self.paged_kv_indptr: list[int] = [0] self.paged_kv_last_page_lens: list[int] = [] self.total_blocks = 0 - self.qo_indptr: list[int] = [0] def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, prefix_cache_hit: bool): @@ -215,7 +208,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): self.paged_kv_indices.extend(block_table[:block_table_bound]) self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + block_table_bound) - self.qo_indptr.append(self.qo_indptr[-1] + 1) last_page_len = seq_len % self.block_size if last_page_len == 0: @@ -234,8 +226,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): self.paged_kv_indptr.extend([last_paged_kv_indptr] * cuda_graph_pad_size) self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) - last_qo_indptr = self.qo_indptr[-1] - self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size) # For current version of AITER MLA if len(self.paged_kv_indptr) > 0: @@ -255,22 +245,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): 1, device=device, dtype=torch.int) - - qo_indptr = torch.tensor(self.qo_indptr, - device=device, - dtype=torch.int) else: paged_kv_indices_tensor = None paged_kv_indptr_tensor = None paged_kv_last_page_lens_tensor = None block_table_bound_tensor = None - qo_indptr = None metadata.paged_kv_indptr = paged_kv_indptr_tensor metadata.paged_kv_indices = paged_kv_indices_tensor metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor metadata.block_table_bound = block_table_bound_tensor - metadata.qo_indptr = qo_indptr return metadata @@ -279,17 +263,14 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]): @contextmanager def graph_capture(self, max_batch_size: int): - kv_indices, kv_indptr, last_page_lens, qo_indptr = \ - get_aiter_mla_metadata( - max_batch_size=max_batch_size, - block_size=self.runner.block_size, - max_block_per_batch=\ - self.runner.get_max_block_per_batch(), - device=self.runner.device) + kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata( + max_batch_size=max_batch_size, + block_size=self.runner.block_size, + max_block_per_batch=self.runner.get_max_block_per_batch(), + device=self.runner.device) self._paged_kv_indices_tensor = kv_indices self._paged_kv_indptr_tensor = kv_indptr self._paged_kv_last_page_lens_tensor = last_page_lens - self._qo_indptr_tensor = qo_indptr with super().graph_capture(max_batch_size): yield @@ -297,7 +278,6 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]): del self._paged_kv_indices_tensor del self._paged_kv_indptr_tensor del self._paged_kv_last_page_lens_tensor - del self._qo_indptr_tensor def graph_capture_get_metadata_for_batch( self, @@ -311,12 +291,10 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]): paged_kv_indices = self._paged_kv_indices_tensor paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: batch_size] - qo_indptr = self._qo_indptr_tensor[:batch_size + 1] metadata.paged_kv_indptr = paged_kv_indptr metadata.paged_kv_indices = paged_kv_indices metadata.paged_kv_last_page_lens = paged_kv_last_page_lens - metadata.qo_indptr = qo_indptr return metadata @@ -333,7 +311,6 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]): input_buffers[ "paged_kv_last_page_lens"] = attn_metadata.\ decode_metadata.paged_kv_last_page_lens - input_buffers['qo_indptr'] = attn_metadata.qo_indptr return input_buffers @@ -353,8 +330,6 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]): input_buffers["paged_kv_last_page_lens"].copy_( attn_metadata.decode_metadata.paged_kv_last_page_lens, non_blocking=True) - input_buffers["qo_indptr"].copy_( - attn_metadata.decode_metadata.qo_indptr, non_blocking=True) class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): @@ -395,9 +370,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): softmax_scale: float, return_softmax_lse: bool, **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: output = self.flash_attn_varlen_func( - q, - k, - v, + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + return_lse=return_softmax_lse, **kwargs, ) @@ -417,7 +394,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): B = q_nope.shape[0] q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.empty(B, + o = torch.zeros(B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, @@ -426,8 +403,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, - attn_metadata.qo_indptr, - attn_metadata.max_query_len, attn_metadata.paged_kv_indptr, attn_metadata.paged_kv_indices, attn_metadata.paged_kv_last_page_lens) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index ce11ce12c81b4..3348d18804aab 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -20,8 +20,7 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int, paged_kv_last_page_lens = torch.full((max_batch_size, ), block_size, dtype=torch.int32) - qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) - return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr + return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens def aiter_mla_decode_fwd( @@ -29,8 +28,6 @@ def aiter_mla_decode_fwd( kv_buffer: torch.Tensor, o: torch.Tensor, sm_scale: float, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, kv_indptr: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None, kv_last_page_lens: Optional[torch.Tensor] = None, @@ -63,11 +60,9 @@ def mla_decode_fwd_impl( mla_decode_fwd(q, kv_buffer.view(-1, 1, 1, q.shape[-1]), o, - qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, - max_seqlen_qo, sm_scale=sm_scale, logit_cap=logit_cap) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index b31af95248e38..7d7bce9ec6abc 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -123,11 +123,10 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl( fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids, sorted_weight_buf, sorted_expert_ids, - num_valid_ids, topk, - a1_scale.t().contiguous(), - w1_scale.view(local_E, -1), - w2_scale.view(local_E, - -1), *block_shape, smooth_scale) + num_valid_ids, topk, w1_scale.view(local_E, -1), + w2_scale.view(local_E, -1), + a1_scale.t().contiguous(), *block_shape, + smooth_scale) return out_asm