Merge 6b3bfba12ce75c5eedf50a47a75cd0b76c6dfc76 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
ゆり 2025-12-25 00:07:07 +00:00 committed by GitHub
commit 39b54f7fdf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 0 deletions

View File

@ -718,6 +718,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
return self._cascade_wrapper
@override
def reset_after_sleep(self) -> None:
    """Reset FlashInfer wrappers after sleep mode.

    FlashInfer wrappers cache internal buffers that may be freed during
    sleep mode.  Clearing them to ``None`` means they are lazily
    recreated with fresh buffers on next use.
    """
    # Drop every cached wrapper/buffer; each is rebuilt on demand.
    for attr_name in (
        "_workspace_buffer",
        "_prefill_wrapper",
        "_decode_wrapper",
        "_cascade_wrapper",
    ):
        setattr(self, attr_name, None)
    # Also empty the CUDA-graph decode wrapper cache when it is in use.
    if self.enable_cuda_graph:
        self._decode_wrappers_cudagraph.clear()
def _compute_flashinfer_kv_metadata(
self,
num_blocks_np: np.ndarray,

View File

@ -454,6 +454,20 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
) -> bool:
return False
def reset_after_sleep(self) -> None:
    """Reset internal state after sleep mode.

    Some attention backends cache internal buffers (e.g., FlashInfer
    wrappers) that may be freed during sleep mode.  Call this after
    ``wake_up`` to invalidate those caches so they are recreated with
    fresh buffers on next use.

    The base implementation is intentionally a no-op; backends that
    cache stateful objects should override this method.
    """
@functools.lru_cache
def get_kv_cache_layout():

View File

@ -161,6 +161,25 @@ class Worker(WorkerBase):
):
self.model_runner.init_fp8_kv_scales()
# Reset attention backend internal state after sleep.
# Some backends (e.g., FlashInfer) cache internal buffers that may be
# freed during sleep mode and need to be recreated on wake_up.
if tags is None or "kv_cache" in tags:
self._reset_attention_backends_after_sleep()
def _reset_attention_backends_after_sleep(self) -> None:
"""Reset attention backends after sleep mode.
Iterates through all attention groups and calls reset_after_sleep()
on each metadata builder to invalidate cached internal state.
"""
if not hasattr(self.model_runner, "attn_groups"):
return
for kv_cache_group in self.model_runner.attn_groups:
for attn_group in kv_cache_group:
for builder in attn_group.metadata_builders:
builder.reset_after_sleep()
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator