mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 13:06:01 +08:00
[Bugfix] Use a separate FlashInfer workspace buffer for trtllm-gen (#25520)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
342d17fb7f
commit
cb825af948
@ -48,6 +48,16 @@ FP4_DTYPE = torch.uint8
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Module-level cache for the workspace buffer shared by trtllm-gen kernels;
# allocated lazily on first use.
trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
    """Return the process-wide trtllm-gen workspace buffer.

    Allocates a zero-initialized uint8 CUDA tensor of
    ``FLASHINFER_WORKSPACE_BUFFER_SIZE`` bytes on the first call and caches
    it at module level, so every subsequent call reuses the same allocation.
    """
    global trtllm_gen_workspace_buffer
    if trtllm_gen_workspace_buffer is not None:
        return trtllm_gen_workspace_buffer
    trtllm_gen_workspace_buffer = torch.zeros(
        FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda')
    return trtllm_gen_workspace_buffer
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _trtllm_prefill_attn_kvfp8_dequant(
|
||||
@ -862,7 +872,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
else:
|
||||
# prefill_query may be non-contiguous
|
||||
prefill_query = prefill_query.contiguous()
|
||||
workspace_buffer = prefill_wrapper._float_workspace_buffer
|
||||
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||
block_tables_prefill = attn_metadata.block_table_tensor[
|
||||
num_decode_tokens:]
|
||||
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
|
||||
@ -943,7 +953,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
else:
|
||||
# decode_query may be non-contiguous
|
||||
decode_query = decode_query.contiguous()
|
||||
workspace_buffer = decode_wrapper._float_workspace_buffer
|
||||
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||
block_tables_decode = attn_metadata.\
|
||||
block_table_tensor[:num_decode_tokens]
|
||||
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user