From cb825af948b614e23d08ce4853feae96e517bc21 Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Tue, 23 Sep 2025 20:19:56 -0400
Subject: [PATCH] [Bugfix] Use a separate FlashInfer workspace buffer for
 trtllm-gen (#25520)

Signed-off-by: yewentao256
---
 vllm/v1/attention/backends/flashinfer.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index cb092aa74e7f1..1a5c171430bc6 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -48,6 +48,16 @@
 FP4_DTYPE = torch.uint8
 
 logger = init_logger(__name__)
 
+trtllm_gen_workspace_buffer = None
+
+
+def _get_trtllm_gen_workspace_buffer():
+    global trtllm_gen_workspace_buffer
+    if trtllm_gen_workspace_buffer is None:
+        trtllm_gen_workspace_buffer = torch.zeros(
+            FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda')
+    return trtllm_gen_workspace_buffer
+
 @triton.jit
 def _trtllm_prefill_attn_kvfp8_dequant(
@@ -862,7 +872,7 @@ class FlashInferImpl(AttentionImpl):
             else:
                 # prefill_query may be non-contiguous
                 prefill_query = prefill_query.contiguous()
-                workspace_buffer = prefill_wrapper._float_workspace_buffer
+                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                 block_tables_prefill = attn_metadata.block_table_tensor[
                     num_decode_tokens:]
                 seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
@@ -943,7 +953,7 @@ class FlashInferImpl(AttentionImpl):
             else:
                 # decode_query may be non-contiguous
                 decode_query = decode_query.contiguous()
-                workspace_buffer = decode_wrapper._float_workspace_buffer
+                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                 block_tables_decode = attn_metadata.\
                     block_table_tensor[:num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
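
For reference, the `_get_trtllm_gen_workspace_buffer` helper added above follows a
lazy module-level singleton pattern: the workspace tensor is allocated on first
use and the same tensor is returned on every later call, so the trtllm-gen code
paths stop aliasing the wrapper's internal `_float_workspace_buffer`. Below is a
minimal standalone sketch of that pattern, not the vllm code itself: the 256 MiB
size, the `get_workspace_buffer` name, and the CPU fallback are illustrative
assumptions so the sketch runs anywhere (the patch always allocates on 'cuda'
with vllm's own FLASHINFER_WORKSPACE_BUFFER_SIZE).

    import torch

    # Illustrative size only; vllm defines its own constant for this.
    _WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024

    _workspace_buffer = None  # module-level cache, allocated lazily


    def get_workspace_buffer() -> torch.Tensor:
        """Return a shared uint8 workspace tensor, allocating it on first call."""
        global _workspace_buffer
        if _workspace_buffer is None:
            # CPU fallback is an assumption for portability of this sketch;
            # the actual patch hardcodes device='cuda'.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            _workspace_buffer = torch.zeros(
                _WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device)
        return _workspace_buffer


    if __name__ == "__main__":
        a = get_workspace_buffer()
        b = get_workspace_buffer()
        assert a is b  # the same tensor is reused across calls

A per-process singleton like this pays the large allocation once rather than on
every forward pass, while keeping the trtllm-gen kernels' scratch space disjoint
from the buffer the FlashInfer prefill/decode wrappers manage internally.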