[Bugfix] Fix workspace buffer None issue for Flashinfer TRTLLM Backend (#21525)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
elvischenv 2025-07-29 22:34:00 +08:00 committed by GitHub
parent ad341c5194
commit 58b11b24a6
4 changed files with 60 additions and 41 deletions
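All four files follow the same pattern, sketched below: pre-allocate the attention output and hand it to FlashInfer through the `out=` keyword, take the workspace buffer from the decode wrapper's `_float_workspace_buffer` instead of carrying it in the attention metadata, and fold the k/v scales into `bmm1_scale`/`bmm2_scale`. The snippet is a minimal illustration of that call pattern, not code from the diff; `kv_cache`, `block_tables`, `seq_lens`, `max_seq_len`, `k_scale`, `v_scale`, `sm_scale`, and `decode_wrapper` are assumed to be set up as in the benchmark below.

    import torch
    import flashinfer

    # Reuse the wrapper's float workspace rather than a separate metadata field
    # (assumes decode_wrapper is an already-planned FlashInfer decode wrapper).
    workspace_buffer = decode_wrapper._float_workspace_buffer

    # Pre-allocate the output so the kernel writes into a caller-owned buffer
    # instead of returning a freshly allocated tensor.
    output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        q,                              # [num_seqs, num_qo_heads, head_dim]
        kv_cache,                       # paged KV cache in HND layout
        workspace_buffer,
        block_tables,
        seq_lens,
        max_seq_len,
        bmm1_scale=k_scale * sm_scale,  # k-scale folded into the softmax scale
        bmm2_scale=v_scale,
        out=output,
    )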

View File

@@ -71,22 +71,20 @@ def benchmark_decode(
     if kv_cache_dtype.startswith("fp8"):
         kv_cache, _ = to_float8(kv_cache)
 
+    output_trtllm = torch.empty(q.shape, dtype=dtype)
+
     # Benchmark TRT decode
     def trt_decode():
         return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
             q,
             kv_cache,
             workspace_buffer,
-            num_qo_heads,
-            num_kv_heads,
-            sm_scale,
             block_tables,
             kv_lens_tensor,
-            page_size,
             max_kv_len,
-            kv_cache_dtype,
-            k_scale,
-            v_scale,
+            bmm1_scale=k_scale * sm_scale,
+            bmm2_scale=v_scale,
+            out=output_trtllm,
         )
 
     def time_fn(fn, warmup=10, trials=20):
@@ -125,6 +123,8 @@ def benchmark_decode(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
+    output_baseline = torch.empty(q.shape, dtype=dtype)
+
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
@@ -145,7 +145,7 @@ def benchmark_decode(
     )
 
     def baseline_decode():
-        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
+        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
 
     baseline_mean, baseline_std = time_fn(baseline_decode)
@ -214,25 +214,39 @@ if __name__ == "__main__":
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = [] all_results = []
print("Running benchmark for kv_cache_dtype: bfloat16")
print( print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent" "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
) )
for max_seq_len in max_seq_lens: for max_seq_len in max_seq_lens:
for bs in num_seqs: for bs in num_seqs:
result = benchmark_decode( result = benchmark_decode(
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto" bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
) )
all_results.append(result) all_results.append(result)
print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
print( print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent" "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
) )
for max_seq_len in max_seq_lens: for max_seq_len in max_seq_lens:
for bs in num_seqs: for bs in num_seqs:
result = benchmark_decode( result = benchmark_decode(
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8" bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="fp8",
) )
all_results.append(result) all_results.append(result)

View File

@@ -113,27 +113,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
                  kv_data_type=dtype,
                  logits_soft_cap=soft_cap)
 
-    output = wrapper.run(query, key_value_cache, scale)
+    output = torch.empty(query.shape, dtype=dtype)
+    wrapper.run(query, key_value_cache, scale, out=output)
 
     # TRTLLM Decode
     max_kv_len = max(kv_lens)
     kv_lens_tensor = torch.tensor(kv_lens,
                                   dtype=torch.int,
                                   device=query.device)
-    output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
         query.contiguous(),
         key_value_cache,
         workspace_buffer,
-        num_query_heads,
-        num_kv_heads,
-        scale,
         block_tables,
         kv_lens_tensor,
-        block_size,
         max_kv_len,
-        "auto",
-        k_scale,
-        v_scale,
+        bmm1_scale=k_scale * scale,
+        bmm2_scale=v_scale,
+        out=output_trtllm,
     )
 
     torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \

View File

@@ -1104,7 +1104,12 @@ class FlashInferImpl(AttentionImpl):
         window_left = window_size[0] if window_size is not None else -1
 
         prefill_output: Optional[torch.Tensor] = None
-        decode_output: Optional[torch.Tensor] = None
+        if num_decode_tokens > 0:
+            decode_output = torch.empty(decode_query.shape,
+                                        dtype=decode_query.dtype,
+                                        device=decode_query.device)
+        else:
+            decode_output = None
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
         if prefill_meta := attn_metadata.prefill_metadata:
             # We will use flash attention for prefill
@@ -1155,17 +1160,18 @@ class FlashInferImpl(AttentionImpl):
                     num_decode_tokens, attn_metadata.max_decode_seq_len,
                     kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
-                decode_output = decode_meta.decode_wrapper.run(
+                decode_meta.decode_wrapper.run(
                     decode_query,
                     kv_cache.permute(*stride_order),
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
+                    out=decode_output,
                 )
             else:
                 workspace_buffer = (
-                    decode_meta.decode_wrapper._int_workspace_buffer)
+                    decode_meta.decode_wrapper._float_workspace_buffer)
                 assert FlashInferState.get_kv_cache_layout() == "HND"
-                decode_output = trtllm_batch_decode_with_kv_cache(
+                trtllm_batch_decode_with_kv_cache(
                     query=decode_query,
                     kv_cache=kv_cache.permute(*stride_order),
                     workspace_buffer=workspace_buffer,
@@ -1174,6 +1180,7 @@ class FlashInferImpl(AttentionImpl):
                     max_seq_len=attn_metadata.max_decode_seq_len,
                     bmm1_scale=layer._k_scale_float * softmax_scale,
                     bmm2_scale=layer._v_scale_float,
+                    out=decode_output,
                 )
 
         if prefill_output is None and decode_output is not None:

View File

@@ -194,7 +194,6 @@ class FlashInferMetadata:
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
-    workspace_buffer: torch.Tensor
 
     # For handling prefill decode split
     num_decodes: int
@@ -473,7 +472,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            workspace_buffer=self._get_workspace_buffer(),
         )
 
         self._plan(num_prefills, num_decodes, attn_metadata)
@@ -641,11 +639,11 @@ class FlashInferImpl(AttentionImpl):
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
-            assert decode_wrapper is not None
             if not FlashInferBackend.use_trtllm_decode_attention(
                     attn_metadata.num_decodes, attn_metadata.max_seq_len,
                     self.kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
+                assert decode_wrapper is not None
                 assert decode_wrapper._window_left == window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
@@ -666,22 +664,24 @@
                                                           num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:
                                                          num_decode_tokens]
+                workspace_buffer = decode_wrapper._float_workspace_buffer
 
                 assert get_kv_cache_layout() == "HND"
                 assert decode_query.is_contiguous()
                 assert kv_cache_permute.is_contiguous()
                 assert block_tables_decode.is_contiguous()
                 assert seq_lens_decode.is_contiguous()
+                assert workspace_buffer.is_contiguous()
 
-                output[:num_decode_tokens] = (
-                    trtllm_batch_decode_with_kv_cache(
-                        query=decode_query,
-                        kv_cache=kv_cache_permute,
-                        workspace_buffer=attn_metadata.workspace_buffer,
-                        block_tables=block_tables_decode,
-                        seq_lens=seq_lens_decode,
-                        max_seq_len=attn_metadata.max_seq_len,
-                        bmm1_scale=layer._k_scale_float * self.scale,
-                        bmm2_scale=layer._v_scale_float,
-                    ))
+                trtllm_batch_decode_with_kv_cache(
+                    query=decode_query,
+                    kv_cache=kv_cache_permute,
+                    workspace_buffer=workspace_buffer,
+                    block_tables=block_tables_decode,
+                    seq_lens=seq_lens_decode,
+                    max_seq_len=attn_metadata.max_seq_len,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
+                    out=output[:num_decode_tokens],
+                )
 
         return output_padded