[Bugfix] Use a separate FlashInfer workspace buffer for trtllm-gen (#25520)

This commit is contained in:
Benjamin Chislett 2025-09-23 20:19:56 -04:00 committed by GitHub
parent d06b5a95cb
commit 1983609239
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -48,6 +48,16 @@ FP4_DTYPE = torch.uint8
logger = init_logger(__name__)
# Lazily-created workspace buffer reserved for trtllm-gen attention kernels,
# kept separate from the wrappers' float workspace buffers.
trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
    """Return the dedicated trtllm-gen workspace buffer, allocating it on first use.

    The buffer is a zero-initialized uint8 CUDA tensor of
    FLASHINFER_WORKSPACE_BUFFER_SIZE bytes, cached in a module-level global so
    every call after the first returns the same tensor.
    """
    global trtllm_gen_workspace_buffer
    # Fast path: already allocated.
    if trtllm_gen_workspace_buffer is not None:
        return trtllm_gen_workspace_buffer
    trtllm_gen_workspace_buffer = torch.zeros(
        FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda')
    return trtllm_gen_workspace_buffer
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
@ -862,7 +872,7 @@ class FlashInferImpl(AttentionImpl):
else: else:
# prefill_query may be non-contiguous # prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous() prefill_query = prefill_query.contiguous()
workspace_buffer = prefill_wrapper._float_workspace_buffer workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.block_table_tensor[ block_tables_prefill = attn_metadata.block_table_tensor[
num_decode_tokens:] num_decode_tokens:]
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:] seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
@ -943,7 +953,7 @@ class FlashInferImpl(AttentionImpl):
else: else:
# decode_query may be non-contiguous # decode_query may be non-contiguous
decode_query = decode_query.contiguous() decode_query = decode_query.contiguous()
workspace_buffer = decode_wrapper._float_workspace_buffer workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_decode = attn_metadata.\ block_tables_decode = attn_metadata.\
block_table_tensor[:num_decode_tokens] block_table_tensor[:num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]