diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu
index 229d9862fb67..27d1e990c611 100644
--- a/csrc/attention/merge_attn_states.cu
+++ b/csrc/attention/merge_attn_states.cu
@@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
     scalar_t* output, float* output_lse, const scalar_t* prefix_output,
     const float* prefix_lse, const scalar_t* suffix_output,
     const float* suffix_lse, const uint num_tokens, const uint num_heads,
-    const uint head_size) {
+    const uint head_size, const uint prefix_head_stride,
+    const uint output_head_stride) {
   using pack_128b_t = uint4;
   const uint pack_size = 16 / sizeof(scalar_t);
   const uint threads_per_head = head_size / pack_size;
@@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
   const uint head_idx = token_head_idx % num_heads;

   const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
-  const uint head_offset =
-      token_idx * num_heads * head_size + head_idx * head_size;
-  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
-  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
-  scalar_t* output_head_ptr = output + head_offset;
+  const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
+                               head_idx * prefix_head_stride;
+  const uint dst_head_offset = token_idx * num_heads * output_head_stride +
+                               head_idx * output_head_stride;
+  const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
+  const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
+  scalar_t* output_head_ptr = output + dst_head_offset;

   float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
   float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
@@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
           reinterpret_cast<float*>(prefix_lse.data_ptr()),                \
           reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),          \
           reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,    \
-          num_heads, head_size);                                          \
+          num_heads, head_size, prefix_head_stride, output_head_stride);  \
  }

 /*@brief Merges the attention states from prefix and suffix
@@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
   const uint num_tokens = output.size(0);
   const uint num_heads = output.size(1);
   const uint head_size = output.size(2);
+  const uint prefix_head_stride = prefix_output.stride(1);
+  const uint output_head_stride = output.stride(1);
   const uint pack_size = 16 / sizeof(scalar_t);
   TORCH_CHECK(head_size % pack_size == 0,
               "headsize must be multiple of pack_size:", pack_size);
-  TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
-              "output heads must be contiguous in memory");
-  TORCH_CHECK(
-      prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
-      "prefix_output heads must be contiguous in memory");
-  TORCH_CHECK(
-      suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
-      "suffix_output heads must be contiguous in memory");
   float* output_lse_ptr = nullptr;
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
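For reference (an illustrative sketch, not code from this patch): the combine step the kernel above implements, section 2.2 of https://www.arxiv.org/pdf/2501.01005 as noted in torch_bindings.cpp below, can be written in plain PyTorch as follows. The helper name `merge_attn_states_ref` and the layout comments are assumptions of the sketch, and it skips the special-casing of infinite LSE values that the real kernels perform.

```python
import torch

def merge_attn_states_ref(
    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse: torch.Tensor,     # [NUM_HEADS, NUM_TOKENS]
    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse: torch.Tensor,     # [NUM_HEADS, NUM_TOKENS]
):
    # Bring the LSEs to [NUM_TOKENS, NUM_HEADS, 1] so they broadcast over
    # the head dimension.
    p_lse = prefix_lse.transpose(0, 1).unsqueeze(-1).float()
    s_lse = suffix_lse.transpose(0, 1).unsqueeze(-1).float()
    # Rescale both partial results into a common softmax denominator.
    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)
    s_se = torch.exp(s_lse - max_lse)
    out_se = p_se + s_se
    out = (prefix_output.float() * p_se + suffix_output.float() * s_se) / out_se
    # Merged LSE, returned in the same [NUM_HEADS, NUM_TOKENS] layout.
    out_lse = (torch.log(out_se) + max_lse).squeeze(-1).transpose(0, 1)
    return out.to(prefix_output.dtype), out_lse
```

The stride change in the kernel does not alter this math; it only changes how each `[token, head]` slice is addressed in memory.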
diff --git a/csrc/ops.h b/csrc/ops.h
index f8bdc61aaa8e..4bb7857b1503 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -52,14 +52,13 @@ void paged_attention_v2(
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);

-#ifndef USE_ROCM
 void merge_attn_states(torch::Tensor& output,
                        std::optional<torch::Tensor> output_lse,
                        const torch::Tensor& prefix_output,
                        const torch::Tensor& prefix_lse,
                        const torch::Tensor& suffix_output,
                        const torch::Tensor& suffix_lse);
-
+#ifndef USE_ROCM
 void convert_vertical_slash_indexes(
     torch::Tensor& block_count,   // [BATCH, N_HEADS, NUM_ROWS]
     torch::Tensor& block_offset,  // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 14913bef1312..e9c96bb8b56c 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

-#ifndef USE_ROCM
   // Merge attn states
   // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
   // can be used to combine partial attention results (in the split-KV case)
@@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor suffix_output,"
       "    Tensor suffix_lse) -> ()");
   ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
-
+#ifndef USE_ROCM
   ops.def(
       "convert_vertical_slash_indexes("
      "    Tensor! block_count, Tensor! block_offset, "
diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py
index 3c87a24afd9c..74e4d778ded8 100644
--- a/vllm/attention/ops/triton_merge_attn_states.py
+++ b/vllm/attention/ops/triton_merge_attn_states.py
@@ -20,7 +20,11 @@ def merge_attn_states(
     num_query_heads = output.shape[1]
     head_size = output.shape[2]
     padded_head_size = triton.next_power_of_2(head_size)
-
+    # The head stride of `output` is not always the same as that of
+    # `prefix_output` and `suffix_output`, since they might be padded by
+    # the attention backend.
+    prefix_head_stride = prefix_output.stride(1)
+    output_head_stride = output.stride(1)
     # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
     merge_attn_states_kernel[(num_tokens, num_query_heads)](
         output,
@@ -29,6 +33,8 @@ def merge_attn_states(
         prefix_lse,
         suffix_output,
         suffix_lse,
+        prefix_head_stride,
+        output_head_stride,
         head_size,
         padded_head_size,
         output_lse is not None,
@@ -43,6 +49,8 @@ def merge_attn_states_kernel(
     prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
     suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
     suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
+    prefix_head_stride,
+    output_head_stride,
     HEAD_SIZE: tl.constexpr,
     PADDED_HEAD_SIZE: tl.constexpr,
     OUTPUT_LSE: tl.constexpr,
@@ -79,15 +87,15 @@ def merge_attn_states_kernel(
     head_mask = head_arange < HEAD_SIZE
     p_out = tl.load(
         prefix_output
-        + token_idx * num_heads * HEAD_SIZE
-        + head_idx * HEAD_SIZE
+        + token_idx * num_heads * prefix_head_stride
+        + head_idx * prefix_head_stride
         + head_arange,
         mask=head_mask,
     )
     s_out = tl.load(
         suffix_output
-        + token_idx * num_heads * HEAD_SIZE
-        + head_idx * HEAD_SIZE
+        + token_idx * num_heads * prefix_head_stride
+        + head_idx * prefix_head_stride
         + head_arange,
         mask=head_mask,
     )
@@ -99,7 +107,10 @@ def merge_attn_states_kernel(
     s_scale = s_se / out_se
     out = p_out * p_scale + s_out * s_scale
     tl.store(
-        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        output
+        + token_idx * num_heads * output_head_stride
+        + head_idx * output_head_stride
+        + head_arange,
         out,
         mask=head_mask,
     )
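The new `prefix_head_stride` / `output_head_stride` arguments, and the removal of the contiguity `TORCH_CHECK`s in the CUDA launcher, exist because the attention backend may hand over head-padded views whose last dimension has been sliced. A self-contained illustration with made-up sizes (not code from this patch):

```python
import torch

# Illustrative sizes only: 4 tokens, 16 heads, head_size 192 stored in a
# buffer whose last dimension is padded to 256 by the attention backend.
num_tokens, num_heads, head_size, padded_size = 4, 16, 192, 256
buf = torch.empty(num_tokens, num_heads, padded_size)
view = buf[..., :head_size]  # the kind of tensor MLA may pass in

assert view.shape == (num_tokens, num_heads, head_size)
assert view.stride(2) == 1             # elements within a head stay contiguous
assert view.stride(1) == padded_size   # 256, not 192: heads are NOT contiguous

# The old TORCH_CHECKs rejected `view`; the kernels now index with
# token_idx * num_heads * head_stride + head_idx * head_stride instead of
# assuming head_stride == head_size.
```

Note that the kernels compute the token offset as `token_idx * num_heads * head_stride`, which assumes the token stride equals `num_heads * head_stride`; that holds for padded buffers like `buf` above but not for arbitrarily strided tensors.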
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 87a3aac21d2c..d94ed9183f63 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -1238,15 +1238,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
     def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads,
                    self.kv_lora_rank).transpose(0, 1)
+        if self.is_aiter_triton_fp8_bmm_enabled:
+            out = out.view(-1, self.num_heads, self.v_head_dim)
             # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
             x = rocm_aiter_ops.triton_fp8_bmm(
-                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
+                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
             )
-            # Convert from (B, N, V) to (B, N * V)
-            x = x.reshape(-1, self.num_heads * self.v_head_dim)
-            # Copy result
-            out.copy_(x)
         else:
             # Convert from (B, N * V) to (N, B, V)
             out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
@@ -1824,7 +1822,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
-    ) -> torch.Tensor:
+        output: torch.Tensor,
+    ) -> None:
         # TODO (zyongye): Prefill function here
         assert attn_metadata.prefill is not None
         assert self.dcp_world_size is not None
@@ -1837,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):

         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

-        output = self._run_prefill_new_tokens(
+        output_prefill = self._run_prefill_new_tokens(
             prefill=attn_metadata.prefill,
             q=q,
             k=k,
@@ -1846,7 +1845,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )

         if has_context:
-            suffix_output, suffix_lse = output
+            suffix_output, suffix_lse = output_prefill
             if self.dcp_world_size > 1:
                 context_output, context_lse = (
                     self._context_parallel_compute_prefill_context(
@@ -1862,7 +1861,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                     q, kv_c_and_k_pe_cache, attn_metadata, k_scale
                 )

-            output = torch.empty_like(suffix_output)
+            # unpad if necessary
+            if self._pad_v:
+                context_output = context_output[..., : v.shape[-1]]
+                suffix_output = suffix_output[..., : v.shape[-1]]
+
+            output = output.view(-1, self.num_heads, self.v_head_dim)
             merge_attn_states(
                 output=output,
                 prefix_output=context_output,
@@ -1870,12 +1874,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 suffix_output=suffix_output,
                 suffix_lse=suffix_lse,
             )
-
-            # unpad if necessary
-            if self._pad_v:
-                output = output[..., : v.shape[-1]]
-
-            return output.flatten(start_dim=-2)
+        else:
+            output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
+            output.copy_(output_prefill)

     @abstractmethod
     def _forward_decode(
@@ -1970,13 +1971,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             kv_cache = kv_cache.view(current_platform.fp8_dtype())

         if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
+            self._forward_prefill(
                 prefill_q,
                 prefill_k_c_normed,
                 prefill_k_pe,
                 kv_cache,
                 attn_metadata,
                 layer._k_scale,
+                output=output[num_decode_tokens:],
             )

         if has_decode:
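Since `_forward_prefill` now merges directly into a strided slice of the caller's output buffer, one way to sanity-check the stride handling is to merge the same inputs through a padded, non-contiguous output view and through a contiguous one and compare. This is a hedged sketch, not a test from this patch: it assumes a CUDA device, that the Triton wrapper above is importable at the path shown, and that it accepts the keyword arguments used at the call site in common.py.

```python
import torch
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states

# Illustrative sizes; PAD mimics padding applied by the attention backend.
T, H, D, PAD = 8, 16, 128, 192
dev, dtype = "cuda", torch.bfloat16

prefix = torch.randn(T, H, PAD, device=dev, dtype=dtype)[..., :D]
suffix = torch.randn(T, H, PAD, device=dev, dtype=dtype)[..., :D]
p_lse = torch.randn(H, T, device=dev, dtype=torch.float32)
s_lse = torch.randn(H, T, device=dev, dtype=torch.float32)

# Path 1: merge directly into a padded, non-contiguous output view.
out_padded = torch.empty(T, H, PAD, device=dev, dtype=dtype)[..., :D]
merge_attn_states(output=out_padded, prefix_output=prefix, prefix_lse=p_lse,
                  suffix_output=suffix, suffix_lse=s_lse)

# Path 2: merge contiguous copies into a contiguous output.
out_contig = torch.empty(T, H, D, device=dev, dtype=dtype)
merge_attn_states(output=out_contig, prefix_output=prefix.contiguous(),
                  prefix_lse=p_lse, suffix_output=suffix.contiguous(),
                  suffix_lse=s_lse)

# Both paths run the same merge; only the memory layout differs.
torch.testing.assert_close(out_padded, out_contig)
```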