From 1723ef1aae749929c1cbddd964ab3ffd96452a70 Mon Sep 17 00:00:00 2001
From: eigen <52445717+yyihuang@users.noreply.github.com>
Date: Fri, 15 Aug 2025 17:38:10 -0400
Subject: [PATCH] minor: zero workspace buffer init for flashinfer trtllm-gen
 attn (#22603)

---
 tests/kernels/attention/test_flashinfer_trtllm_attention.py | 4 ++--
 vllm/attention/backends/flashinfer.py                       | 2 +-
 vllm/v1/attention/backends/flashinfer.py                    | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
index 53e225ea3ea6c..4b84e6a00eceb 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -113,7 +113,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
@@ -247,7 +247,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         workspace_buffer, kv_layout)
     wrapper.plan(q_indptr,
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 208cacec38eb5..a85ec24632834 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -203,7 +203,7 @@ class FlashInferState(AttentionState):
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            self._workspace_buffer = torch.empty(
+            self._workspace_buffer = torch.zeros(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
                 dtype=torch.uint8,
                 device=self.runner.device)
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 02decb171fc05..eac3f33e15096 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -252,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            self._workspace_buffer = torch.empty(
+            self._workspace_buffer = torch.zeros(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
                 dtype=torch.uint8,
                 device=self.device)
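
For reference, a minimal standalone sketch of the lazy, zero-initialized workspace allocation pattern this patch converges on; it is not the vLLM code itself. The class name WorkspaceHolder and the __main__ check are illustrative, and FLASHINFER_WORKSPACE_BUFFER_SIZE here is a stand-in for the 128 MiB constant used above. The substance of the change is simply torch.zeros instead of torch.empty, so the buffer's contents are deterministic the first time a kernel reads them.

from typing import Optional

import torch

# Illustrative stand-in for vLLM's FLASHINFER_WORKSPACE_BUFFER_SIZE (128 MiB).
FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024


class WorkspaceHolder:
    """Hypothetical holder mirroring the _get_workspace_buffer pattern above."""

    def __init__(self, device: torch.device) -> None:
        self.device = device
        self._workspace_buffer: Optional[torch.Tensor] = None

    def _get_workspace_buffer(self) -> torch.Tensor:
        # Zero-init (torch.zeros) rather than torch.empty: anything a kernel
        # reads from the workspace starts from a known state.
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.zeros(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
        return self._workspace_buffer


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    buf = WorkspaceHolder(device)._get_workspace_buffer()
    print(buf.numel(), buf.dtype, int(buf.sum()))  # sum is 0: zero-initialized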