[Performance][MLA][ROCm] Remove redundant D2D copy in deepseek (#27457)

Signed-off-by: ganyi <ygan@amd.com>

parent 53d7f1f601
commit d9d342d214
@@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
     scalar_t* output, float* output_lse, const scalar_t* prefix_output,
     const float* prefix_lse, const scalar_t* suffix_output,
     const float* suffix_lse, const uint num_tokens, const uint num_heads,
-    const uint head_size) {
+    const uint head_size, const uint prefix_head_stride,
+    const uint output_head_stride) {
   using pack_128b_t = uint4;
   const uint pack_size = 16 / sizeof(scalar_t);
   const uint threads_per_head = head_size / pack_size;
@@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
   const uint head_idx = token_head_idx % num_heads;

   const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
-  const uint head_offset =
-      token_idx * num_heads * head_size + head_idx * head_size;
-  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
-  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
-  scalar_t* output_head_ptr = output + head_offset;
+  const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
+                               head_idx * prefix_head_stride;
+  const uint dst_head_offset = token_idx * num_heads * output_head_stride +
+                               head_idx * output_head_stride;
+  const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
+  const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
+  scalar_t* output_head_ptr = output + dst_head_offset;

   float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
   float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
@@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
             reinterpret_cast<float*>(prefix_lse.data_ptr()),               \
             reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),         \
             reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,   \
-            num_heads, head_size);                                         \
+            num_heads, head_size, prefix_head_stride, output_head_stride); \
   }

 /*@brief Merges the attention states from prefix and suffix
@@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
   const uint num_tokens = output.size(0);
   const uint num_heads = output.size(1);
   const uint head_size = output.size(2);
+  const uint prefix_head_stride = prefix_output.stride(1);
+  const uint output_head_stride = output.stride(1);
   const uint pack_size = 16 / sizeof(scalar_t);
   TORCH_CHECK(head_size % pack_size == 0,
               "headsize must be multiple of pack_size:", pack_size);
-  TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
-              "output heads must be contiguous in memory");
-  TORCH_CHECK(
-      prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
-      "prefix_output heads must be contiguous in memory");
-  TORCH_CHECK(
-      suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
-      "suffix_output heads must be contiguous in memory");
   float* output_lse_ptr = nullptr;
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
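The two new `stride(1)` reads are what let the launcher drop the removed contiguity checks: `prefix_output` and `output` may come from backend buffers whose head dimension is padded, so the per-head stride can exceed `head_size`. A minimal PyTorch sketch of that situation (hypothetical shapes, not code from this commit):

```python
import torch

# Hypothetical sizes: a backend pads the head dim (e.g. 192 -> 256) in its buffer.
num_tokens, num_heads, head_size, padded_head_size = 4, 16, 192, 256

backend_buf = torch.empty(num_tokens, num_heads, padded_head_size)
prefix_output = backend_buf[..., :head_size]  # logical [T, H, 192] view

# Heads are no longer contiguous: the per-head stride is the padded width.
assert prefix_output.stride(1) == padded_head_size  # 256, not head_size (192)
assert prefix_output.stride(-2) != head_size        # the old TORCH_CHECK would fail
```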
@@ -52,14 +52,13 @@ void paged_attention_v2(
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);

-#ifndef USE_ROCM
 void merge_attn_states(torch::Tensor& output,
                        std::optional<torch::Tensor> output_lse,
                        const torch::Tensor& prefix_output,
                        const torch::Tensor& prefix_lse,
                        const torch::Tensor& suffix_output,
                        const torch::Tensor& suffix_lse);
-
+#ifndef USE_ROCM
 void convert_vertical_slash_indexes(
     torch::Tensor& block_count,   // [BATCH, N_HEADS, NUM_ROWS]
     torch::Tensor& block_offset,  // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
@@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

-#ifndef USE_ROCM
   // Merge attn states
   // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
   // can be used to combine partial attention results (in the split-KV case)
@@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor suffix_output,"
       " Tensor suffix_lse) -> ()");
   ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
-
+#ifndef USE_ROCM
   ops.def(
       "convert_vertical_slash_indexes("
       " Tensor! block_count, Tensor! block_offset, "
@@ -20,7 +20,11 @@ def merge_attn_states(
     num_query_heads = output.shape[1]
     head_size = output.shape[2]
     padded_head_size = triton.next_power_of_2(head_size)
-
+    # We assume the output stride on num_head is not always the same as that of
+    # `suffix_output` and `prefix_output`, as they might be padded by the
+    # attention backend.
+    prefix_head_stride = prefix_output.stride(1)
+    output_head_stride = output.stride(1)
     # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
     merge_attn_states_kernel[(num_tokens, num_query_heads)](
         output,
@@ -29,6 +33,8 @@ def merge_attn_states(
         prefix_lse,
         suffix_output,
         suffix_lse,
+        prefix_head_stride,
+        output_head_stride,
         head_size,
         padded_head_size,
         output_lse is not None,
@@ -43,6 +49,8 @@ def merge_attn_states_kernel(
     prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
     suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
     suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
+    prefix_head_stride,
+    output_head_stride,
     HEAD_SIZE: tl.constexpr,
     PADDED_HEAD_SIZE: tl.constexpr,
     OUTPUT_LSE: tl.constexpr,
@@ -79,15 +87,15 @@ def merge_attn_states_kernel(
     head_mask = head_arange < HEAD_SIZE
     p_out = tl.load(
         prefix_output
-        + token_idx * num_heads * HEAD_SIZE
-        + head_idx * HEAD_SIZE
+        + token_idx * num_heads * prefix_head_stride
+        + head_idx * prefix_head_stride
         + head_arange,
         mask=head_mask,
     )
     s_out = tl.load(
         suffix_output
-        + token_idx * num_heads * HEAD_SIZE
-        + head_idx * HEAD_SIZE
+        + token_idx * num_heads * prefix_head_stride
+        + head_idx * prefix_head_stride
         + head_arange,
         mask=head_mask,
     )
@@ -99,7 +107,10 @@ def merge_attn_states_kernel(
     s_scale = s_se / out_se
     out = p_out * p_scale + s_out * s_scale
     tl.store(
-        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        output
+        + token_idx * num_heads * output_head_stride
+        + head_idx * output_head_stride
+        + head_arange,
         out,
         mask=head_mask,
     )
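For reference, the merge the Triton kernel performs (section 2.2 of the paper cited in the bindings) weights each partial result by its rescaled softmax sum. A dense PyTorch sketch of the same math, assuming the [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] output and [NUM_HEADS, NUM_TOKENS] LSE layouts from the kernel's comments; this is an illustration, not code from the commit:

```python
import torch

def merge_attn_states_ref(prefix_output, prefix_lse, suffix_output, suffix_lse):
    # prefix_output/suffix_output: [T, H, D]; prefix_lse/suffix_lse: [H, T]
    p_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)  # [T, H, 1]
    s_lse = suffix_lse.transpose(0, 1).unsqueeze(-1)
    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)  # rescaled softmax sum of the prefix part
    s_se = torch.exp(s_lse - max_lse)  # rescaled softmax sum of the suffix part
    out_se = p_se + s_se
    return prefix_output * (p_se / out_se) + suffix_output * (s_se / out_se)
```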
@@ -1238,15 +1238,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
     def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+
         if self.is_aiter_triton_fp8_bmm_enabled:
+            out = out.view(-1, self.num_heads, self.v_head_dim)
             # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
             x = rocm_aiter_ops.triton_fp8_bmm(
-                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
+                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
             )
-            # Convert from (B, N, V) to (B, N * V)
-            x = x.reshape(-1, self.num_heads * self.v_head_dim)
-            # Copy result
-            out.copy_(x)
         else:
             # Convert from (B, N * V) to (N, B, V)
             out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
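The `_v_up_proj` hunk removes the pattern of computing into a temporary and then `copy_()`-ing it into `out`, by letting the AITER bmm write into the caller's buffer (`YQ=out`). A generic sketch of that pattern, using plain `torch.matmul` as a stand-in for the fp8 bmm and hypothetical shapes:

```python
import torch

def v_up_proj_with_copy(x, w, out):
    tmp = torch.matmul(x, w)  # materializes a temporary on the device
    out.copy_(tmp)            # redundant device-to-device copy

def v_up_proj_in_place(x, w, out):
    torch.matmul(x, w, out=out)  # the kernel writes the caller's buffer directly

x = torch.randn(8, 4, 16)    # (N, B, L), hypothetical sizes
w = torch.randn(8, 16, 32)   # (N, L, V)
out = torch.empty(8, 4, 32)  # preallocated (N, B, V) output
v_up_proj_in_place(x, w, out)
```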
@@ -1824,7 +1822,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
-    ) -> torch.Tensor:
+        output: torch.Tensor,
+    ) -> None:
         # TODO (zyongye): Prefill function here
         assert attn_metadata.prefill is not None
         assert self.dcp_world_size is not None
@@ -1837,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):

         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

-        output = self._run_prefill_new_tokens(
+        output_prefill = self._run_prefill_new_tokens(
             prefill=attn_metadata.prefill,
             q=q,
             k=k,
@@ -1846,7 +1845,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )

         if has_context:
-            suffix_output, suffix_lse = output
+            suffix_output, suffix_lse = output_prefill
             if self.dcp_world_size > 1:
                 context_output, context_lse = (
                     self._context_parallel_compute_prefill_context(
@@ -1862,7 +1861,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                         q, kv_c_and_k_pe_cache, attn_metadata, k_scale
                     )

-            output = torch.empty_like(suffix_output)
+            # unpad if necessary
+            if self._pad_v:
+                context_output = context_output[..., : v.shape[-1]]
+                suffix_output = suffix_output[..., : v.shape[-1]]
+
+            output = output.view(-1, self.num_heads, self.v_head_dim)
             merge_attn_states(
                 output=output,
                 prefix_output=context_output,
@@ -1870,12 +1874,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 suffix_output=suffix_output,
                 suffix_lse=suffix_lse,
             )
-
-        # unpad if necessary
-        if self._pad_v:
-            output = output[..., : v.shape[-1]]
-
-        return output.flatten(start_dim=-2)
+        else:
+            output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
+            output.copy_(output_prefill)

     @abstractmethod
     def _forward_decode(
@@ -1970,13 +1971,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             kv_cache = kv_cache.view(current_platform.fp8_dtype())

         if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
+            self._forward_prefill(
                 prefill_q,
                 prefill_k_c_normed,
                 prefill_k_pe,
                 kv_cache,
                 attn_metadata,
                 layer._k_scale,
+                output=output[num_decode_tokens:],
             )

         if has_decode:
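The caller-side change works the same way: `_forward_prefill` now receives `output[num_decode_tokens:]`, a view of the preallocated output, and fills it in place instead of returning a tensor for the caller to assign. A small sketch (hypothetical helper and shapes) showing that writing through such a view lands directly in the caller's buffer:

```python
import torch

def forward_prefill_sketch(out: torch.Tensor) -> None:
    # `out` is a view of the caller's buffer; an in-place write is all that is
    # needed, so nothing is returned and no extra copy happens in the caller.
    out.fill_(1.0)

num_decode_tokens, num_prefill_tokens, hidden = 4, 12, 8
output = torch.zeros(num_decode_tokens + num_prefill_tokens, hidden)

forward_prefill_sketch(output[num_decode_tokens:])  # slicing dim 0 gives a view
assert output[num_decode_tokens:].eq(1.0).all()
assert output[:num_decode_tokens].eq(0.0).all()
```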