fix: reset FlashInfer wrappers after sleep mode

FlashInfer's attention wrappers cache internal buffers that may be freed
during sleep mode. After wake_up(), the cached wrappers still exist but
point to freed memory, causing incorrect outputs.

This fix:
1. Adds a `reset_after_sleep()` method to AttentionMetadataBuilder base
   class (no-op by default)
2. Overrides it in FlashInferMetadataBuilder to reset workspace buffer
   and all wrapper objects (prefill, decode, cascade, cudagraph)
3. Calls the reset method from gpu_worker's wake_up() for all attention
   groups

The wrappers will be lazily recreated with fresh buffers on next use.

Fixes #31016

Signed-off-by: yurekami <69337011+yurekami@users.noreply.github.com>

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: yurekami <yurekami@users.noreply.github.com>
This commit is contained in:
yurekami 2025-12-24 23:38:42 +09:00
parent 7cd288a4b3
commit 6b3bfba12c
3 changed files with 49 additions and 0 deletions

View File

@ -718,6 +718,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
return self._cascade_wrapper
@override
def reset_after_sleep(self) -> None:
"""
Reset FlashInfer wrappers after sleep mode.
FlashInfer wrappers cache internal buffers that may be freed during
sleep mode. By resetting them to None, they will be lazily recreated
with fresh buffers on next use.
"""
self._workspace_buffer = None
self._prefill_wrapper = None
self._decode_wrapper = None
self._cascade_wrapper = None
if self.enable_cuda_graph:
self._decode_wrappers_cudagraph.clear()
def _compute_flashinfer_kv_metadata(
self,
num_blocks_np: np.ndarray,

View File

@ -454,6 +454,20 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
) -> bool:
return False
def reset_after_sleep(self) -> None:
    """Reset internal state after sleep mode.

    Some attention backends cache internal buffers (e.g., FlashInfer
    wrappers) that may be freed during sleep mode. Call this after
    wake_up so stale caches are invalidated and rebuilt with fresh
    buffers on next use.

    The base implementation is deliberately a no-op; backends that
    cache stateful objects override it.
    """
    return None
@functools.lru_cache
def get_kv_cache_layout():

View File

@ -161,6 +161,25 @@ class Worker(WorkerBase):
):
self.model_runner.init_fp8_kv_scales()
# Reset attention backend internal state after sleep.
# Some backends (e.g., FlashInfer) cache internal buffers that may be
# freed during sleep mode and need to be recreated on wake_up.
if tags is None or "kv_cache" in tags:
self._reset_attention_backends_after_sleep()
def _reset_attention_backends_after_sleep(self) -> None:
"""Reset attention backends after sleep mode.
Iterates through all attention groups and calls reset_after_sleep()
on each metadata builder to invalidate cached internal state.
"""
if not hasattr(self.model_runner, "attn_groups"):
return
for kv_cache_group in self.model_runner.attn_groups:
for attn_group in kv_cache_group:
for builder in attn_group.metadata_builders:
builder.reset_after_sleep()
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator