diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 1eb27d57acf06..2abfb457b84ae 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -353,8 +353,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             attn_metadata.decode_wrapper = self._get_decode_wrapper()
             if not FlashInferBackend.use_trtllm_decode_attention(
                     num_decodes, attn_metadata.max_seq_len,
-                    attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
+                    self.cache_config.cache_dtype,
+                    attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
+                    attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
                     attn_metadata.paged_kv_indptr[:num_decodes + 1],
                     attn_metadata.paged_kv_indices,
@@ -539,10 +540,10 @@ class FlashInferImpl(AttentionImpl):
             query: shape = [num_tokens, num_heads, head_size]
             key: shape = [num_tokens, num_kv_heads, head_size]
             value: shape = [num_tokens, num_kv_heads, head_size]
-            kv_cache: shape -
+            kv_cache: shape -
             # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
             # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
-
+
             attn_metadata: Metadata for attention.
 
         Returns:
@@ -614,6 +615,7 @@ class FlashInferImpl(AttentionImpl):
         num_prefill_tokens = attn_metadata.num_prefill_tokens
 
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
+        kv_cache_permute = kv_cache.permute(*stride_order)
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
@@ -628,7 +630,7 @@ class FlashInferImpl(AttentionImpl):
             assert prefill_wrapper._sm_scale == self.scale
             prefill_wrapper.run(
                 prefill_query,
-                kv_cache.permute(*stride_order),
+                kv_cache_permute,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -647,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
                 assert decode_wrapper._sm_scale == self.scale
                 decode_wrapper.run(
                     decode_query,
-                    kv_cache.permute(*stride_order),
+                    kv_cache_permute,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output[:num_decode_tokens],
@@ -655,19 +657,29 @@ class FlashInferImpl(AttentionImpl):
             else:
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 if num_decode_tokens > 0:
+                    # decode_query may be non-contiguous
+                    decode_query = decode_query.contiguous()
+                    block_tables_decode = attn_metadata.block_table_tensor[:
+                        num_decode_tokens]
+                    seq_lens_decode = attn_metadata.seq_lens[:
+                        num_decode_tokens]
+                    assert get_kv_cache_layout() == "HND"
+                    assert decode_query.is_contiguous()
+                    assert kv_cache_permute.is_contiguous()
+                    assert block_tables_decode.is_contiguous()
+                    assert seq_lens_decode.is_contiguous()
+
                     output[:num_decode_tokens] = (
                         trtllm_batch_decode_with_kv_cache(
                             query=decode_query,
-                            kv_cache=kv_cache.permute(*stride_order),
+                            kv_cache=kv_cache_permute,
                             workspace_buffer=attn_metadata.workspace_buffer,
                             num_heads=self.num_heads,
                             num_kv_heads=self.num_kv_heads,
                             scale=self.scale,
-                            block_tables=attn_metadata.
-                            block_table_tensor[:num_decode_tokens],
-                            seq_lens=attn_metadata.
-                            seq_lens[:num_decode_tokens],
+                            block_tables=block_tables_decode,
+                            seq_lens=seq_lens_decode,
                             block_size=attn_metadata.page_size,
                             max_seq_len=attn_metadata.max_seq_len,
                             kv_cache_dtype=self.kv_cache_dtype,
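
For context, here is a minimal standalone PyTorch sketch (not part of the patch; tensor sizes and variable names are made up for illustration) of the layout mechanics behind the hoisted `kv_cache_permute` and the new contiguity assertions: `permute` only rewrites strides, so reordering an NHD cache view to HND yields a non-contiguous view, and a kernel that expects a dense HND buffer needs either a guarantee (the added asserts) or an explicit copy (`.contiguous()`, as the patch does for `decode_query`).

```python
import torch

# Illustrative sizes, not taken from vLLM.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64

# KV cache stored in NHD layout: [num_blocks, 2, block_size, num_kv_heads, head_size]
kv_cache_nhd = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

# Reordering to HND ([num_blocks, 2, num_kv_heads, block_size, head_size]) via
# permute only changes strides, not memory, so the view is non-contiguous.
kv_cache_hnd = kv_cache_nhd.permute(0, 1, 3, 2, 4)
print(kv_cache_hnd.shape)            # torch.Size([4, 2, 2, 16, 64])
print(kv_cache_hnd.is_contiguous())  # False

# A kernel that assumes a dense HND buffer needs an explicit copy (or a cache
# that was allocated in HND order to begin with, which the asserts check for).
kv_cache_hnd = kv_cache_hnd.contiguous()
assert kv_cache_hnd.is_contiguous()
```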