diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index cb092aa74e7f1..1a5c171430bc6 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -48,6 +48,16 @@ FP4_DTYPE = torch.uint8
 
 logger = init_logger(__name__)
 
+trtllm_gen_workspace_buffer = None
+
+
+def _get_trtllm_gen_workspace_buffer():
+    global trtllm_gen_workspace_buffer
+    if trtllm_gen_workspace_buffer is None:
+        trtllm_gen_workspace_buffer = torch.zeros(
+            FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda')
+    return trtllm_gen_workspace_buffer
+
 
 @triton.jit
 def _trtllm_prefill_attn_kvfp8_dequant(
@@ -862,7 +872,7 @@ class FlashInferImpl(AttentionImpl):
             else:
                 # prefill_query may be non-contiguous
                 prefill_query = prefill_query.contiguous()
-                workspace_buffer = prefill_wrapper._float_workspace_buffer
+                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                 block_tables_prefill = attn_metadata.block_table_tensor[
                     num_decode_tokens:]
                 seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
@@ -943,7 +953,7 @@ class FlashInferImpl(AttentionImpl):
             else:
                 # decode_query may be non-contiguous
                 decode_query = decode_query.contiguous()
-                workspace_buffer = decode_wrapper._float_workspace_buffer
+                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                 block_tables_decode = attn_metadata.\
                     block_table_tensor[:num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
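For context, a minimal, self-contained sketch of the lazy-initialization pattern this patch adds: one module-level workspace buffer, allocated on first use and shared by every caller (here, the trtllm-gen prefill and decode paths), instead of each path reusing its wrapper's private `_float_workspace_buffer`. The size constant, device handling, and function names below are illustrative placeholders, not the vLLM values:

```python
import torch

# Illustrative size only; vLLM sizes this via FLASHINFER_WORKSPACE_BUFFER_SIZE.
_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

_workspace_buffer = None


def _get_workspace_buffer(device: str) -> torch.Tensor:
    """Allocate the shared workspace buffer once, then reuse it."""
    global _workspace_buffer
    if _workspace_buffer is None:
        _workspace_buffer = torch.zeros(
            _WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device)
    return _workspace_buffer


if __name__ == "__main__":
    # Both call sites (e.g. the prefill and decode paths) receive the same tensor.
    a = _get_workspace_buffer("cpu")
    b = _get_workspace_buffer("cpu")
    assert a is b
```

Sharing a single module-level buffer avoids depending on the wrappers' internal `_float_workspace_buffer` attribute and allocates the large scratch tensor only once, regardless of how many wrappers or call sites need it.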