minor: zero workspace buffer init for flashinfer trtllm-gen attn (#22603)

Author: eigen
Date: 2025-08-15 17:38:10 -04:00 (committed by GitHub)
Parent: 00d6cba0cf
Commit: 1723ef1aae
3 changed files with 4 additions and 4 deletions
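
The change below swaps torch.empty for torch.zeros when allocating the FlashInfer workspace buffer: torch.empty leaves the allocation's contents undefined, while torch.zeros guarantees every byte is 0. A minimal standalone sketch of the difference (not part of the diff; the 128 MiB size mirrors the tests below):

import torch

# Size matches the workspace allocation in the tests below: 128 MiB.
WORKSPACE_BYTES = 128 * 1024 * 1024

# torch.empty returns whatever bytes happen to be in the allocation,
# so the contents are undefined; torch.zeros writes 0 to every byte.
uninitialized = torch.empty(WORKSPACE_BYTES, dtype=torch.int8)
zeroed = torch.zeros(WORKSPACE_BYTES, dtype=torch.int8)

assert int(zeroed.abs().sum()) == 0  # always holds; not guaranteed for torch.empty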


@@ -113,7 +113,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
@@ -247,7 +247,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         workspace_buffer, kv_layout)
     wrapper.plan(q_indptr,


@@ -203,7 +203,7 @@ class FlashInferState(AttentionState):

     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            self._workspace_buffer = torch.empty(
+            self._workspace_buffer = torch.zeros(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
                 dtype=torch.uint8,
                 device=self.runner.device)


@@ -252,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            self._workspace_buffer = torch.empty(
+            self._workspace_buffer = torch.zeros(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
                 dtype=torch.uint8,
                 device=self.device)
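
Both backend hunks above use the same lazily cached workspace pattern. A self-contained sketch of that pattern, using a hypothetical holder class and buffer-size constant (the real FLASHINFER_WORKSPACE_BUFFER_SIZE lives in the backend modules and is not reproduced here):

import torch

# Hypothetical stand-in for FLASHINFER_WORKSPACE_BUFFER_SIZE (128 MiB here).
WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024


class WorkspaceHolder:
    # Sketch of the lazy, cached, zero-initialized workspace buffer above.

    def __init__(self, device: torch.device):
        self.device = device
        self._workspace_buffer = None

    def _get_workspace_buffer(self) -> torch.Tensor:
        # Allocate once, zero-initialized, and reuse the same tensor afterwards.
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.zeros(
                WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
        return self._workspace_buffer


# Usage: repeated calls return the same zeroed buffer.
holder = WorkspaceHolder(torch.device("cpu"))
assert holder._get_workspace_buffer() is holder._get_workspace_buffer()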