[Bugfix] Fix the tensor non-contiguous issue for Flashinfer TRT-LLM backend attention kernel (#21133)

elvischenv 2025-07-18 08:35:58 +08:00 committed by GitHub
parent 8a8fc94639
commit 8dfb45ca33


@@ -353,8 +353,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 attn_metadata.decode_wrapper = self._get_decode_wrapper()
 if not FlashInferBackend.use_trtllm_decode_attention(
 num_decodes, attn_metadata.max_seq_len,
-attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
-attn_metadata.num_kv_heads, attn_metadata.head_dim):
+self.cache_config.cache_dtype,
+attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
+attn_metadata.head_dim):
 attn_metadata.decode_wrapper.plan(
 attn_metadata.paged_kv_indptr[:num_decodes + 1],
 attn_metadata.paged_kv_indices,
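Note: the gate now receives self.cache_config.cache_dtype (vLLM's user-facing cache dtype string, e.g. "auto" or "fp8") instead of the resolved attn_metadata.kv_data_type. The sketch below is hypothetical: _can_use_trtllm_decode and its constraints are made up for orientation and are not the real FlashInferBackend.use_trtllm_decode_attention.

    # Hypothetical sketch of a capability gate keyed on the configured cache
    # dtype string; constraints are illustrative, not taken from vLLM source.
    def _can_use_trtllm_decode(num_decodes: int, max_seq_len: int,
                               cache_dtype: str, num_qo_heads: int,
                               num_kv_heads: int, head_dim: int) -> bool:
        if num_decodes == 0:
            return False
        if cache_dtype not in ("auto", "fp8"):
            return False
        return num_qo_heads % num_kv_heads == 0 and head_dim in (64, 128, 256)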
@@ -539,10 +540,10 @@ class FlashInferImpl(AttentionImpl):
 query: shape = [num_tokens, num_heads, head_size]
 key: shape = [num_tokens, num_kv_heads, head_size]
 value: shape = [num_tokens, num_kv_heads, head_size]
-kv_cache: shape -
+kv_cache: shape -
 # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
 # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
 attn_metadata: Metadata for attention.
 Returns:
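For reference, a minimal sketch of the two KV-cache layouts named in the docstring; dimension sizes are illustrative only, the shapes themselves come from the comments above.

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64

    # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
    kv_nhd = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

    # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
    # Swapping dims 2 and 3 maps one layout onto the other.
    kv_hnd = kv_nhd.permute(0, 1, 3, 2, 4)
    assert kv_hnd.shape == (num_blocks, 2, num_kv_heads, block_size, head_size)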
@@ -614,6 +615,7 @@ class FlashInferImpl(AttentionImpl):
 num_prefill_tokens = attn_metadata.num_prefill_tokens
 stride_order = FlashInferBackend.get_kv_cache_stride_order()
+kv_cache_permute = kv_cache.permute(*stride_order)
 # Regular attention (common case).
 # Decodes are at the front and prefills are at the back,
 # according to reorder_batch()
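The comments describe the batch ordering the code below relies on: decode tokens sit at the front of the flattened batch and prefill tokens at the back, so slicing at num_decode_tokens separates the two paths. A small sketch of that split, with made-up shapes:

    import torch

    num_decode_tokens, num_heads, head_size = 3, 8, 64
    query = torch.randn(10, num_heads, head_size)  # [num_tokens, num_heads, head_size]

    decode_query = query[:num_decode_tokens]    # front slice -> decode path
    prefill_query = query[num_decode_tokens:]   # back slice  -> prefill path
    assert decode_query.shape[0] + prefill_query.shape[0] == query.shape[0]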
@@ -628,7 +630,7 @@ class FlashInferImpl(AttentionImpl):
 assert prefill_wrapper._sm_scale == self.scale
 prefill_wrapper.run(
 prefill_query,
-kv_cache.permute(*stride_order),
+kv_cache_permute,
 k_scale=layer._k_scale_float,
 v_scale=layer._v_scale_float,
 out=output[num_decode_tokens:],
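The hoisted kv_cache_permute is reused here and at the call sites below because permute() only rearranges strides and returns a view sharing the original storage, so computing it once is enough. A minimal sketch, assuming a (0, 1, 3, 2, 4) stride order; the real value comes from FlashInferBackend.get_kv_cache_stride_order().

    import torch

    kv_cache = torch.zeros(4, 2, 16, 8, 64)
    stride_order = (0, 1, 3, 2, 4)  # assumed order for illustration
    kv_cache_permute = kv_cache.permute(*stride_order)

    # permute() allocates no new storage, so the hoisted view can be passed
    # to every downstream kernel call instead of re-permuting each time.
    assert kv_cache_permute.data_ptr() == kv_cache.data_ptr()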
@@ -647,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
 assert decode_wrapper._sm_scale == self.scale
 decode_wrapper.run(
 decode_query,
-kv_cache.permute(*stride_order),
+kv_cache_permute,
 k_scale=layer._k_scale_float,
 v_scale=layer._v_scale_float,
 out=output[:num_decode_tokens],
@@ -655,19 +657,29 @@ class FlashInferImpl(AttentionImpl):
 else:
 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
 if num_decode_tokens > 0:
+# decode_query may be non-contiguous
+decode_query = decode_query.contiguous()
+block_tables_decode = attn_metadata.block_table_tensor[:
+num_decode_tokens]
+seq_lens_decode = attn_metadata.seq_lens[:
+num_decode_tokens]
+assert get_kv_cache_layout() == "HND"
+assert decode_query.is_contiguous()
+assert kv_cache_permute.is_contiguous()
+assert block_tables_decode.is_contiguous()
+assert seq_lens_decode.is_contiguous()
 output[:num_decode_tokens] = (
 trtllm_batch_decode_with_kv_cache(
 query=decode_query,
-kv_cache=kv_cache.permute(*stride_order),
+kv_cache=kv_cache_permute,
 workspace_buffer=attn_metadata.workspace_buffer,
 num_heads=self.num_heads,
 num_kv_heads=self.num_kv_heads,
 scale=self.scale,
-block_tables=attn_metadata.
-block_table_tensor[:num_decode_tokens],
-seq_lens=attn_metadata.
-seq_lens[:num_decode_tokens],
+block_tables=block_tables_decode,
+seq_lens=seq_lens_decode,
 block_size=attn_metadata.page_size,
 max_seq_len=attn_metadata.max_seq_len,
 kv_cache_dtype=self.kv_cache_dtype,
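This hunk is the core of the fix: the TRT-LLM decode kernel appears to expect contiguous inputs, so decode_query is made contiguous and the block-table and seq-len slices are bound to locals and asserted contiguous before the call. A minimal sketch of the underlying issue; the transpose is only an illustrative way to obtain a non-contiguous view, the commit itself just notes that decode_query "may be non-contiguous".

    import torch

    # A transpose (or similar upstream reshaping) yields a view whose memory
    # is no longer dense in the new dimension order.
    q = torch.randn(8, 64, 16).transpose(1, 2)
    assert not q.is_contiguous()

    # .contiguous() copies into dense memory only when needed; if the tensor
    # is already contiguous it returns the same tensor, so the call is cheap
    # on the common path.
    q = q.contiguous()
    assert q.is_contiguous()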