[Bugfix] Fix workspace buffer None issue for Flashinfer TRTLLM Backend (#21525)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
parent ad341c5194
commit 58b11b24a6
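
Note: the change follows one pattern throughout the diff below: instead of relying on the tensor returned by the FlashInfer decode call, the caller now pre-allocates the output tensor and passes it via out=, and the TRTLLM decode path takes its workspace buffer from the decode wrapper rather than from the attention metadata (where it could end up as None). A minimal, generic sketch of the out-parameter pattern in plain PyTorch; the tensor names here are illustrative and not taken from the diff:

    import torch

    # Pre-allocate the destination once; the op writes into it in place
    # instead of allocating and returning a fresh tensor.
    a = torch.randn(4, 8)
    b = torch.randn(8, 16)
    out = torch.empty(4, 16)
    torch.matmul(a, b, out=out)      # same result as `out = a @ b`, no new allocation
    assert torch.allclose(out, a @ b)
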
@@ -71,22 +71,20 @@ def benchmark_decode(
     if kv_cache_dtype.startswith("fp8"):
         kv_cache, _ = to_float8(kv_cache)
 
+    output_trtllm = torch.empty(q.shape, dtype=dtype)
+
     # Benchmark TRT decode
     def trt_decode():
         return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
             q,
             kv_cache,
             workspace_buffer,
-            num_qo_heads,
-            num_kv_heads,
-            sm_scale,
             block_tables,
             kv_lens_tensor,
-            page_size,
             max_kv_len,
-            kv_cache_dtype,
-            k_scale,
-            v_scale,
+            bmm1_scale=k_scale * sm_scale,
+            bmm2_scale=v_scale,
+            out=output_trtllm,
         )
 
     def time_fn(fn, warmup=10, trials=20):
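
The hunk above ends at the signature of time_fn, whose body sits outside the hunk. For context, a plausible minimal implementation of such a timing helper using CUDA events is sketched below; this is an assumption about its shape, not the benchmark's actual code:

    import torch

    def time_fn(fn, warmup=10, trials=20):
        # Warm up so one-time costs (compilation, autotuning, allocator) are excluded.
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # milliseconds
        times = torch.tensor(times)
        return times.mean().item(), times.std().item()
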
@@ -125,6 +123,8 @@ def benchmark_decode(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
+    output_baseline = torch.empty(q.shape, dtype=dtype)
+
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
@@ -145,7 +145,7 @@ def benchmark_decode(
     )
 
     def baseline_decode():
-        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
+        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
 
     baseline_mean, baseline_std = time_fn(baseline_decode)
 
@@ -214,25 +214,39 @@ if __name__ == "__main__":
     max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
     all_results = []
 
-    print("Running benchmark for kv_cache_dtype: bfloat16")
     print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
+        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
+        "output_dtype: bfloat16"
+    )
+    print(
+        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
+        "baseline_std\tspeedup_percent"
     )
     for max_seq_len in max_seq_lens:
         for bs in num_seqs:
             result = benchmark_decode(
-                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
+                bs,
+                max_seq_len,
+                dtype=torch.bfloat16,
+                kv_cache_dtype="auto",
             )
             all_results.append(result)
 
-    print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
     print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
+        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
+        "output_dtype: bfloat16"
+    )
+    print(
+        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
+        "baseline_std\tspeedup_percent"
    )
     for max_seq_len in max_seq_lens:
         for bs in num_seqs:
             result = benchmark_decode(
-                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
+                bs,
+                max_seq_len,
+                dtype=torch.bfloat16,
+                kv_cache_dtype="fp8",
             )
             all_results.append(result)
 
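
The table header printed above ends with a speedup_percent column. Its exact formula lives outside this hunk; a plausible reading, stated only as an assumption, is the relative time saved by the TRTLLM kernel over the baseline wrapper:

    # Hypothetical: how a speedup_percent column could be derived from the two means.
    trt_mean = 0.42        # ms, example value
    baseline_mean = 0.56   # ms, example value
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean
    print(f"{speedup_percent:.2%}")  # -> 25.00%
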
@@ -113,27 +113,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
         kv_data_type=dtype,
         logits_soft_cap=soft_cap)
 
-    output = wrapper.run(query, key_value_cache, scale)
+    output = torch.empty(query.shape, dtype=dtype)
+    wrapper.run(query, key_value_cache, scale, out=output)
 
     # TRTLLM Decode
     max_kv_len = max(kv_lens)
     kv_lens_tensor = torch.tensor(kv_lens,
                                   dtype=torch.int,
                                   device=query.device)
-    output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
         query.contiguous(),
         key_value_cache,
         workspace_buffer,
-        num_query_heads,
-        num_kv_heads,
-        scale,
         block_tables,
         kv_lens_tensor,
-        block_size,
         max_kv_len,
-        "auto",
-        k_scale,
-        v_scale,
+        bmm1_scale=k_scale * scale,
+        bmm2_scale=v_scale,
+        out=output_trtllm,
     )
 
     torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
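
The test now compares two pre-allocated buffers that the kernels fill in place. A self-contained illustration of that comparison pattern with torch.testing.assert_close, using generic tensors rather than the test's actual inputs:

    import torch

    ref = torch.empty(2, 8, 64, dtype=torch.bfloat16)
    out = torch.empty(2, 8, 64, dtype=torch.bfloat16)

    # Stand-ins for wrapper.run(..., out=ref) and
    # trtllm_batch_decode_with_kv_cache(..., out=out): both write in place.
    src = torch.randn(2, 8, 64, dtype=torch.bfloat16)
    ref.copy_(src)
    out.copy_(src + 1e-3 * torch.randn_like(src))

    # Loose tolerances, as in the test, absorb bf16/fp8 rounding differences.
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
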
@@ -1104,7 +1104,12 @@ class FlashInferImpl(AttentionImpl):
         window_left = window_size[0] if window_size is not None else -1
 
         prefill_output: Optional[torch.Tensor] = None
-        decode_output: Optional[torch.Tensor] = None
+        if num_decode_tokens > 0:
+            decode_output = torch.empty(decode_query.shape,
+                                        dtype=decode_query.dtype,
+                                        device=decode_query.device)
+        else:
+            decode_output = None
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
         if prefill_meta := attn_metadata.prefill_metadata:
             # We will use flash attention for prefill
@@ -1155,17 +1160,18 @@ class FlashInferImpl(AttentionImpl):
                     num_decode_tokens, attn_metadata.max_decode_seq_len,
                     kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
-                decode_output = decode_meta.decode_wrapper.run(
+                decode_meta.decode_wrapper.run(
                     decode_query,
                     kv_cache.permute(*stride_order),
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
+                    out=decode_output,
                 )
             else:
                 workspace_buffer = (
-                    decode_meta.decode_wrapper._int_workspace_buffer)
+                    decode_meta.decode_wrapper._float_workspace_buffer)
                 assert FlashInferState.get_kv_cache_layout() == "HND"
-                decode_output = trtllm_batch_decode_with_kv_cache(
+                trtllm_batch_decode_with_kv_cache(
                     query=decode_query,
                     kv_cache=kv_cache.permute(*stride_order),
                     workspace_buffer=workspace_buffer,
@@ -1174,6 +1180,7 @@ class FlashInferImpl(AttentionImpl):
                     max_seq_len=attn_metadata.max_decode_seq_len,
                     bmm1_scale=layer._k_scale_float * softmax_scale,
                     bmm2_scale=layer._v_scale_float,
+                    out=decode_output,
                 )
 
         if prefill_output is None and decode_output is not None:
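
In the three hunks above, decode_output is now pre-allocated only when there are decode tokens, and both the FlashInfer wrapper and the TRTLLM kernel fill it via out=. A small sketch of why that shape of code works, with the subsequent merge of prefill and decode outputs shown in generic PyTorch; the merge itself sits outside the hunk, so its form here is an assumption:

    import torch

    num_prefill_tokens, num_decode_tokens, num_heads, head_dim = 5, 3, 8, 64

    prefill_output = torch.randn(num_prefill_tokens, num_heads, head_dim)
    decode_output = (torch.empty(num_decode_tokens, num_heads, head_dim)
                     if num_decode_tokens > 0 else None)

    if decode_output is not None:
        # Stand-in for decode_wrapper.run(..., out=decode_output) or
        # trtllm_batch_decode_with_kv_cache(..., out=decode_output).
        decode_output.copy_(torch.randn_like(decode_output))

    # Combine the two halves; either side may be absent.
    if prefill_output is None and decode_output is not None:
        output = decode_output
    elif decode_output is None and prefill_output is not None:
        output = prefill_output
    else:
        output = torch.cat([prefill_output, decode_output], dim=0)
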
@@ -194,7 +194,6 @@ class FlashInferMetadata:
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
-    workspace_buffer: torch.Tensor
 
     # For handling prefill decode split
     num_decodes: int
@@ -473,7 +472,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            workspace_buffer=self._get_workspace_buffer(),
         )
 
         self._plan(num_prefills, num_decodes, attn_metadata)
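
These two hunks drop the workspace_buffer field from FlashInferMetadata and from the builder, so the buffer is no longer carried through per-batch metadata at all; the decode path below reads it from the wrapper at call time instead. Reduced to a toy example, with hypothetical names that are not vLLM's API:

    import torch
    from dataclasses import dataclass

    class Wrapper:
        """Toy stand-in: the wrapper owns the workspace it was constructed with."""
        def __init__(self, workspace: torch.Tensor):
            self._float_workspace_buffer = workspace

    @dataclass
    class Metadata:
        # No workspace_buffer field: metadata can be rebuilt or left partially
        # filled without risking a stale or missing buffer reference.
        max_seq_len: int

    wrapper = Wrapper(torch.empty(128 * 1024 * 1024, dtype=torch.uint8))
    meta = Metadata(max_seq_len=4096)
    workspace = wrapper._float_workspace_buffer  # fetched when needed, never None
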
@@ -641,11 +639,11 @@ class FlashInferImpl(AttentionImpl):
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
+            assert decode_wrapper is not None
             if not FlashInferBackend.use_trtllm_decode_attention(
                     attn_metadata.num_decodes, attn_metadata.max_seq_len,
                     self.kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
-                assert decode_wrapper is not None
                 assert decode_wrapper._window_left == window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
@@ -666,22 +664,24 @@ class FlashInferImpl(AttentionImpl):
                                                                        num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:
                                                          num_decode_tokens]
+                workspace_buffer = decode_wrapper._float_workspace_buffer
 
                 assert get_kv_cache_layout() == "HND"
                 assert decode_query.is_contiguous()
                 assert kv_cache_permute.is_contiguous()
                 assert block_tables_decode.is_contiguous()
                 assert seq_lens_decode.is_contiguous()
+                assert workspace_buffer.is_contiguous()
 
-                output[:num_decode_tokens] = (
-                    trtllm_batch_decode_with_kv_cache(
-                        query=decode_query,
-                        kv_cache=kv_cache_permute,
-                        workspace_buffer=attn_metadata.workspace_buffer,
-                        block_tables=block_tables_decode,
-                        seq_lens=seq_lens_decode,
-                        max_seq_len=attn_metadata.max_seq_len,
-                        bmm1_scale=layer._k_scale_float * self.scale,
-                        bmm2_scale=layer._v_scale_float,
-                    ))
+                trtllm_batch_decode_with_kv_cache(
+                    query=decode_query,
+                    kv_cache=kv_cache_permute,
+                    workspace_buffer=workspace_buffer,
+                    block_tables=block_tables_decode,
+                    seq_lens=seq_lens_decode,
+                    max_seq_len=attn_metadata.max_seq_len,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
+                    out=output[:num_decode_tokens],
+                )
         return output_padded
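
The new call writes directly into output[:num_decode_tokens]. Because a leading slice of a tensor is a view, passing it as out= lets the kernel fill the first num_decode_tokens rows of output in place, which is what the old output[:num_decode_tokens] = (...) assignment achieved with an extra intermediate tensor. A runnable demonstration with a plain PyTorch op standing in for the kernel:

    import torch

    num_decode_tokens = 3
    output = torch.zeros(8, 4)

    a = torch.randn(num_decode_tokens, 4)
    b = torch.randn(num_decode_tokens, 4)

    # `output[:num_decode_tokens]` is a view of `output`,
    # so the op writes straight into the parent tensor.
    torch.add(a, b, out=output[:num_decode_tokens])

    assert torch.allclose(output[:num_decode_tokens], a + b)
    assert torch.equal(output[num_decode_tokens:], torch.zeros(5, 4))
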