minor: zero workspace buffer init for flashinfer trtllm-gen attn (#22603)
parent 00d6cba0cf
commit 1723ef1aae
@@ -113,7 +113,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
@@ -247,7 +247,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         workspace_buffer, kv_layout)
     wrapper.plan(q_indptr,
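Both test hunks make the same swap: torch.empty hands back uninitialized memory, whereas torch.zeros memsets the workspace, so the trtllm-gen kernels never read stale bytes on first use. A minimal standalone sketch of the zeroed allocation plus wrapper setup (the explicit "cuda" device and "NHD" layout are assumptions here, not values taken from the test):

import torch
import flashinfer

# 128 MiB workspace, zero-filled so the kernel sees deterministic contents.
workspace_buffer = torch.zeros(128 * 1024 * 1024,
                               dtype=torch.int8,
                               device="cuda")  # assumed device placement

# Decode wrapper construction as in the test above; "NHD" is an assumed layout.
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")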
@@ -203,7 +203,7 @@ class FlashInferState(AttentionState):
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            self._workspace_buffer = torch.empty(
+            self._workspace_buffer = torch.zeros(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
                 dtype=torch.uint8,
                 device=self.runner.device)
@@ -252,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            self._workspace_buffer = torch.empty(
+            self._workspace_buffer = torch.zeros(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
                 dtype=torch.uint8,
                 device=self.device)
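For reference, both backend classes follow the same lazily-allocated, cached workspace pattern; after this change the one-time allocation is zero-filled. A minimal sketch of that pattern (the class name and the 256 MiB size are assumptions, not the exact vLLM constants):

import torch

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024  # assumed size

class WorkspaceHolder:
    """Sketch of the cached-workspace pattern from the diff above."""

    def __init__(self, device: torch.device) -> None:
        self.device = device
        self._workspace_buffer = None

    def _get_workspace_buffer(self) -> torch.Tensor:
        # Allocate once and reuse on later calls; torch.zeros (rather than
        # torch.empty) guarantees no stale bytes on the first kernel use.
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.zeros(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
        return self._workspace_buffer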