[Bugfix] Use a separate FlashInfer workspace buffer for trtllm-gen (#25520)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Benjamin Chislett 2025-09-23 20:19:56 -04:00 committed by yewentao256
parent 342d17fb7f
commit cb825af948

View File

@ -48,6 +48,16 @@ FP4_DTYPE = torch.uint8
logger = init_logger(__name__)
trtllm_gen_workspace_buffer = None
def _get_trtllm_gen_workspace_buffer():
global trtllm_gen_workspace_buffer
if trtllm_gen_workspace_buffer is None:
trtllm_gen_workspace_buffer = torch.zeros(
FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda')
return trtllm_gen_workspace_buffer
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
@ -862,7 +872,7 @@ class FlashInferImpl(AttentionImpl):
else:
# prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous()
workspace_buffer = prefill_wrapper._float_workspace_buffer
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.block_table_tensor[
num_decode_tokens:]
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
@ -943,7 +953,7 @@ class FlashInferImpl(AttentionImpl):
else:
# decode_query may be non-contiguous
decode_query = decode_query.contiguous()
workspace_buffer = decode_wrapper._float_workspace_buffer
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_decode = attn_metadata.\
block_table_tensor[:num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]