From dbb029cfe13c40a5725f62c8d4fedbd5539b3cc0 Mon Sep 17 00:00:00 2001
From: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com>
Date: Mon, 22 Sep 2025 21:20:53 -0400
Subject: [PATCH] [Performance] Remove input pads in cutlass_mla and optimize
 v_proj output handling (#25184)

Signed-off-by: Alexander Matveev
Signed-off-by: yewentao256
---
 vllm/v1/attention/backends/mla/common.py      | 45 ++++++++++++++++---
 vllm/v1/attention/backends/mla/cutlass_mla.py | 30 +++++++------
 2 files changed, 55 insertions(+), 20 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 5b307810de930..9bca87c81d179 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -942,6 +942,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             qk_head_dim: int,
             v_head_dim: int,
             kv_b_proj: ColumnParallelLinear,
+            q_pad_num_heads: Optional[int] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
             raise NotImplementedError("KV sharing is not supported for MLA")
@@ -959,6 +960,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
         self.kv_b_proj = kv_b_proj
+        self.q_pad_num_heads = q_pad_num_heads
 
         if use_flashinfer_prefill():
             logger.debug_once("Using FlashInfer prefill for MLA")
@@ -1134,7 +1136,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             True,  #Indicates actual_seq_lens are on GPU or CPU.
         )
 
-    def _v_up_proj(self, x):
+    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         if is_rocm_aiter_fp8bmm_enabled():
@@ -1146,12 +1148,23 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                                      transpose_bm=True)
             # Convert from (B, N, V) to (B, N * V)
             x = x.reshape(-1, self.num_heads * self.v_head_dim)
+            # Copy result
+            out.copy_(x)
         else:
+            # Convert from (B, N * V) to (N, B, V)
+            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
+
             # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            x = torch.bmm(x, self.W_UV)
+            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"
+
             # Convert from (N, B, V) to (B, N * V)
-            x = x.transpose(0, 1).reshape(-1,
-                                          self.num_heads * self.v_head_dim)
-        return x
+            out_new = out.transpose(0, 1).reshape(
+                -1, self.num_heads * self.v_head_dim)
+
+            # Adjust output buffer shape back to the original (B, N * V)
+            N, B, V = out.shape
+            out.resize_((B, N * V))
+            out.copy_(out_new)  # Copy result
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
@@ -1559,6 +1572,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         # Convert from (B, N, P) to (N, B, P)
         decode_q_nope = decode_q_nope.transpose(0, 1)
 
+        # Pads the head_dim if necessary (for the underlying kernel)
+        if self.q_pad_num_heads is not None:
+            B, N, L = decode_q_pe.shape
+            decode_pe_padded = decode_q_pe.new_empty(
+                (B, self.q_pad_num_heads, L))
+            decode_pe_padded.resize_((B, N, L))
+            decode_pe_padded.copy_(decode_q_pe)
+            decode_q_pe = decode_pe_padded
+
         if is_rocm_aiter_fp8bmm_enabled():
             # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
             decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
@@ -1567,8 +1589,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                                                   group_size=128,
                                                   transpose_bm=True)
         else:
+            # Pads the head_dim if necessary (for the underlying kernel)
+            N, B, P = decode_q_nope.shape
+            _, _, L = self.W_UK_T.shape
+            if self.q_pad_num_heads is not None:
+                decode_ql_nope = decode_q_nope.new_empty(
+                    (self.q_pad_num_heads, B, L))
+                decode_ql_nope.resize_((N, B, L))
+
+            else:
+                decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+
             # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
 
             # Convert from (N, B, L) to (B, N, L)
             decode_ql_nope = decode_ql_nope.transpose(0, 1)
@@ -1603,5 +1636,5 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
 
         # v_up projection
-        output[:num_decode_tokens] = self._v_up_proj(attn_out)
+        self._v_up_proj(attn_out, out=output[:num_decode_tokens])
         return output_padded
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index ae534f3207b51..d44e20f2cb6be 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -74,6 +74,8 @@ class SM100Workspace:
 
 g_sm100_workspace = SM100Workspace(128 * 1024 * 1024)  # 128MB
 
+MAX_HEADS = 128
+
 
 class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
     can_return_lse_for_decode: bool = True
@@ -92,10 +94,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
-        super().__init__(num_heads, head_size, scale, num_kv_heads,
-                         alibi_slopes, sliding_window, kv_cache_dtype,
-                         logits_soft_cap, attn_type,
-                         kv_sharing_target_layer_name, **mla_args)
+        super().__init__(num_heads,
+                         head_size,
+                         scale,
+                         num_kv_heads,
+                         alibi_slopes,
+                         sliding_window,
+                         kv_cache_dtype,
+                         logits_soft_cap,
+                         attn_type,
+                         kv_sharing_target_layer_name,
+                         q_pad_num_heads=MAX_HEADS,
+                         **mla_args)
 
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
@@ -157,14 +167,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         MAX_HEADS = 128
         assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
 
-        if H < MAX_HEADS:
-            q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
-            q_nope_padded[:, :H] = q_nope
-            q_nope = q_nope_padded
-
-            q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
-            q_pe_padded[:, :H] = q_pe
-            q_pe = q_pe_padded
 
         assert len(page_table.shape) == 2
         B_block_table, block_num = page_table.shape
@@ -206,9 +208,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         )
 
         if H < MAX_HEADS:
+            # Extract the subsets of the outputs
+            lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
             out = out[:, :H]
-            if self.need_to_return_lse_for_decode:
-                lse = lse[:, :H].contiguous()
 
         return out, lse
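
Note: the following is an illustrative sketch only, not part of the patch. The pattern used above is to over-allocate a buffer at the kernel's padded head count, shrink its logical shape with resize_() so the extra storage remains available to the padded kernel, and write the torch.bmm result directly into that buffer via out= instead of materializing a temporary and copying. A minimal standalone example of that pattern, with made-up names and sizes (num_heads, pad_heads, B, P, L), is:

    import torch

    # Hypothetical sizes, chosen only for illustration.
    num_heads, pad_heads, B, P, L = 16, 128, 4, 512, 64

    q = torch.randn(num_heads, B, P)  # (N, B, P) query-like input
    w = torch.randn(num_heads, P, L)  # (N, P, L), stands in for W_UK_T

    # Allocate at the padded head count, then shrink the logical shape;
    # the underlying storage still covers pad_heads heads.
    ql = q.new_empty((pad_heads, B, L))
    ql.resize_((num_heads, B, L))

    # Batched matmul written straight into the pre-sized buffer (no extra copy).
    torch.bmm(q, w, out=ql)
    assert ql.shape == (num_heads, B, L)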