mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-20 21:59:18 +08:00
[Bugfix] Fix the tensor non-contiguous issue for Flashinfer TRT-LLM backend attention kernel (#21133)
This commit is contained in:
parent
8a8fc94639
commit
8dfb45ca33
@ -353,8 +353,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
attn_metadata.decode_wrapper = self._get_decode_wrapper()
|
||||
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||
num_decodes, attn_metadata.max_seq_len,
|
||||
attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads, attn_metadata.head_dim):
|
||||
self.cache_config.cache_dtype,
|
||||
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim):
|
||||
attn_metadata.decode_wrapper.plan(
|
||||
attn_metadata.paged_kv_indptr[:num_decodes + 1],
|
||||
attn_metadata.paged_kv_indices,
|
||||
@ -539,10 +540,10 @@ class FlashInferImpl(AttentionImpl):
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape -
|
||||
kv_cache: shape -
|
||||
# NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||
# HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
|
||||
|
||||
|
||||
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
@ -614,6 +615,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
|
||||
stride_order = FlashInferBackend.get_kv_cache_stride_order()
|
||||
kv_cache_permute = kv_cache.permute(*stride_order)
|
||||
# Regular attention (common case).
|
||||
# Decodes are at the front and prefills are at the back,
|
||||
# according to reorder_batch()
|
||||
@ -628,7 +630,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
assert prefill_wrapper._sm_scale == self.scale
|
||||
prefill_wrapper.run(
|
||||
prefill_query,
|
||||
kv_cache.permute(*stride_order),
|
||||
kv_cache_permute,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
out=output[num_decode_tokens:],
|
||||
@ -647,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
assert decode_wrapper._sm_scale == self.scale
|
||||
decode_wrapper.run(
|
||||
decode_query,
|
||||
kv_cache.permute(*stride_order),
|
||||
kv_cache_permute,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
out=output[:num_decode_tokens],
|
||||
@ -655,19 +657,29 @@ class FlashInferImpl(AttentionImpl):
|
||||
else:
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
if num_decode_tokens > 0:
|
||||
# decode_query may be non-contiguous
|
||||
decode_query = decode_query.contiguous()
|
||||
block_tables_decode = attn_metadata.block_table_tensor[:
|
||||
num_decode_tokens]
|
||||
seq_lens_decode = attn_metadata.seq_lens[:
|
||||
num_decode_tokens]
|
||||
|
||||
assert get_kv_cache_layout() == "HND"
|
||||
assert decode_query.is_contiguous()
|
||||
assert kv_cache_permute.is_contiguous()
|
||||
assert block_tables_decode.is_contiguous()
|
||||
assert seq_lens_decode.is_contiguous()
|
||||
|
||||
output[:num_decode_tokens] = (
|
||||
trtllm_batch_decode_with_kv_cache(
|
||||
query=decode_query,
|
||||
kv_cache=kv_cache.permute(*stride_order),
|
||||
kv_cache=kv_cache_permute,
|
||||
workspace_buffer=attn_metadata.workspace_buffer,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.
|
||||
block_table_tensor[:num_decode_tokens],
|
||||
seq_lens=attn_metadata.
|
||||
seq_lens[:num_decode_tokens],
|
||||
block_tables=block_tables_decode,
|
||||
seq_lens=seq_lens_decode,
|
||||
block_size=attn_metadata.page_size,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user