mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-17 00:02:26 +08:00
[Bugfix] Use a separate FlashInfer workspace buffer for trtllm-gen (#25520)
This commit is contained in:
parent
d06b5a95cb
commit
1983609239
@ -48,6 +48,16 @@ FP4_DTYPE = torch.uint8
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Module-level cache for the workspace buffer returned by
# _get_trtllm_gen_workspace_buffer(); stays None until that helper
# allocates it on its first call.
trtllm_gen_workspace_buffer = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_trtllm_gen_workspace_buffer():
    """Return the shared trtllm-gen workspace buffer, creating it on demand.

    The buffer is stored in the module-level ``trtllm_gen_workspace_buffer``
    variable. On the first call it is allocated as a zero-filled
    ``torch.uint8`` tensor with ``FLASHINFER_WORKSPACE_BUFFER_SIZE`` elements
    on the ``'cuda'`` device; every later call returns that same tensor.
    """
    global trtllm_gen_workspace_buffer

    # Fast path: the buffer was already allocated by an earlier call.
    if trtllm_gen_workspace_buffer is not None:
        return trtllm_gen_workspace_buffer

    trtllm_gen_workspace_buffer = torch.zeros(
        FLASHINFER_WORKSPACE_BUFFER_SIZE,
        dtype=torch.uint8,
        device='cuda',
    )
    return trtllm_gen_workspace_buffer
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _trtllm_prefill_attn_kvfp8_dequant(
|
def _trtllm_prefill_attn_kvfp8_dequant(
|
||||||
@ -862,7 +872,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
else:
|
else:
|
||||||
# prefill_query may be non-contiguous
|
# prefill_query may be non-contiguous
|
||||||
prefill_query = prefill_query.contiguous()
|
prefill_query = prefill_query.contiguous()
|
||||||
workspace_buffer = prefill_wrapper._float_workspace_buffer
|
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||||
block_tables_prefill = attn_metadata.block_table_tensor[
|
block_tables_prefill = attn_metadata.block_table_tensor[
|
||||||
num_decode_tokens:]
|
num_decode_tokens:]
|
||||||
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
|
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
|
||||||
@ -943,7 +953,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
else:
|
else:
|
||||||
# decode_query may be non-contiguous
|
# decode_query may be non-contiguous
|
||||||
decode_query = decode_query.contiguous()
|
decode_query = decode_query.contiguous()
|
||||||
workspace_buffer = decode_wrapper._float_workspace_buffer
|
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||||
block_tables_decode = attn_metadata.\
|
block_tables_decode = attn_metadata.\
|
||||||
block_table_tensor[:num_decode_tokens]
|
block_table_tensor[:num_decode_tokens]
|
||||||
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
|
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user