From 6b3bfba12ce75c5eedf50a47a75cd0b76c6dfc76 Mon Sep 17 00:00:00 2001 From: yurekami Date: Wed, 24 Dec 2025 23:38:42 +0900 Subject: [PATCH] fix: reset FlashInfer wrappers after sleep mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FlashInfer's attention wrappers cache internal buffers that may be freed during sleep mode. After wake_up(), these cached wrappers still existed but pointed to freed memory, causing incorrect outputs. This fix: 1. Adds a `reset_after_sleep()` method to AttentionMetadataBuilder base class (no-op by default) 2. Overrides it in FlashInferMetadataBuilder to reset workspace buffer and all wrapper objects (prefill, decode, cascade, cudagraph) 3. Calls the reset method from gpu_worker's wake_up() for all attention groups The wrappers will be lazily recreated with fresh buffers on next use. Fixes #31016 Signed-off-by: yurekami <69337011+yurekami@users.noreply.github.com> 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: yurekami --- vllm/v1/attention/backends/flashinfer.py | 16 ++++++++++++++++ vllm/v1/attention/backends/utils.py | 14 ++++++++++++++ vllm/v1/worker/gpu_worker.py | 19 +++++++++++++++++++ 3 files changed, 49 insertions(+) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 623ae892ecdaf..527da03594a7c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -718,6 +718,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) return self._cascade_wrapper + @override + def reset_after_sleep(self) -> None: + """ + Reset FlashInfer wrappers after sleep mode. + + FlashInfer wrappers cache internal buffers that may be freed during + sleep mode. By resetting them to None, they will be lazily recreated + with fresh buffers on next use.
+ """ + self._workspace_buffer = None + self._prefill_wrapper = None + self._decode_wrapper = None + self._cascade_wrapper = None + if self.enable_cuda_graph: + self._decode_wrappers_cudagraph.clear() + def _compute_flashinfer_kv_metadata( self, num_blocks_np: np.ndarray, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 6b94f786a26b2..bc56c84577920 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -454,6 +454,20 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ) -> bool: return False + def reset_after_sleep(self) -> None: + """ + Reset internal state after sleep mode. + + Some attention backends cache internal buffers (e.g., FlashInfer + wrappers) that may be freed during sleep mode. This method should + be called after wake_up to invalidate those caches so they get + recreated with fresh buffers on next use. + + By default, this is a no-op. Backends that cache stateful objects + should override this method. + """ + pass + @functools.lru_cache def get_kv_cache_layout(): diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 68fe0853370f7..0b6d214e5d28d 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -161,6 +161,25 @@ class Worker(WorkerBase): ): self.model_runner.init_fp8_kv_scales() + # Reset attention backend internal state after sleep. + # Some backends (e.g., FlashInfer) cache internal buffers that may be + # freed during sleep mode and need to be recreated on wake_up. + if tags is None or "kv_cache" in tags: + self._reset_attention_backends_after_sleep() + + def _reset_attention_backends_after_sleep(self) -> None: + """Reset attention backends after sleep mode. + + Iterates through all attention groups and calls reset_after_sleep() + on each metadata builder to invalidate cached internal state.
+ """ + if not hasattr(self.model_runner, "attn_groups"): + return + for kv_cache_group in self.model_runner.attn_groups: + for attn_group in kv_cache_group: + for builder in attn_group.metadata_builders: + builder.reset_after_sleep() + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: if self.vllm_config.model_config.enable_sleep_mode: from vllm.device_allocator.cumem import CuMemAllocator