[Bugfix][sleepmode][fp8 kv cache]: Fix FP8 KV cache + sleep(level=2) gibberish output (#28783)

Signed-off-by: vensen <vensenmu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
Vensen 2025-11-30 14:24:25 +08:00 committed by GitHub
parent 82c795d6f2
commit 66b5840287
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 94 additions and 2 deletions

View File

@ -11,7 +11,7 @@ from vllm.device_allocator.cumem import CuMemAllocator
from vllm.platforms import current_platform
from vllm.utils.mem_constants import GiB_bytes
from ..utils import create_new_process_for_each_test
from ..utils import create_new_process_for_each_test, requires_fp8
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
@ -243,3 +243,34 @@ def test_deep_sleep_async():
assert output.outputs[0].text == output2.outputs[0].text
asyncio.run(test())
@requires_fp8
def test_deep_sleep_fp8_kvcache():
GiB_bytes = 1 << 30
model = "Qwen/Qwen2-0.5B"
used_bytes_baseline = current_platform.get_current_memory_usage()
llm = LLM(model, enable_sleep_mode=True, kv_cache_dtype="fp8")
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate(prompt, sampling_params)
# Put the engine to deep sleep
llm.sleep(level=2)
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
assert used_bytes < 3 * GiB_bytes
llm.wake_up(tags=["weights"])
llm.collective_rpc("reload_weights")
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
assert used_bytes < 4 * GiB_bytes
# now allocate kv cache and cuda graph memory
llm.wake_up(tags=["kv_cache"])
output2 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text

View File

@ -1075,6 +1075,13 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
)
requires_fp8 = pytest.mark.skipif(
not current_platform.supports_fp8(),
reason="FP8 is not supported on this GPU (requires Hopper or "
"Ada architecture, compute capability 8.9+)",
)
def large_gpu_test(*, min_gb: int):
"""
Decorate a test to be skipped if no GPU is available or it does not have

View File

@ -25,7 +25,7 @@ from vllm.attention.backends.abstract import (
AttentionType,
MultipleOf,
)
from vllm.attention.layer import Attention
from vllm.attention.layer import Attention, MLAAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@ -602,6 +602,50 @@ class GPUModelRunner(
if self.mm_budget:
self.mm_budget.reset_cache()
@torch.inference_mode()
def init_fp8_kv_scales(self) -> None:
"""
Re-initialize the KV cache and FP8 scales after waking from sleep.
1. Zero out the KV cache tensors to remove garbage data from re-allocation.
2. Reset Attention layer scaling factors (_k_scale, _v_scale) to 1.0.
If these are left at 0.0 (default after wake_up), all KV cache values
become effectively zero, causing gibberish output.
"""
if not self.cache_config.cache_dtype.startswith("fp8"):
return
kv_caches = getattr(self, "kv_caches", [])
for cache_tensor in kv_caches:
if cache_tensor is not None:
cache_tensor.zero_()
k_attr_names = ("_k_scale", "k_scale")
v_attr_names = ("_v_scale", "v_scale")
attn_layers = self.compilation_config.static_forward_context
for name, module in attn_layers.items():
if isinstance(module, (Attention, MLAAttention)):
# TODO: Generally, scale is 1.0 if user uses on-the-fly fp8
# kvcache quant. However, to get better accuracy, compression
# frameworks like llm-compressors allow users to tune the
# scale. We may need to restore the specific calibrated scales
# here in the future.
k_scale_val, v_scale_val = 1.0, 1.0
# Processing K Scale
for attr in k_attr_names:
if hasattr(module, attr):
param = getattr(module, attr)
if isinstance(param, torch.Tensor):
param.fill_(k_scale_val)
# Processing V Scale
for attr in v_attr_names:
if hasattr(module, attr):
param = getattr(module, attr)
if isinstance(param, torch.Tensor):
param.fill_(v_scale_val)
def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int):
if self.uses_mrope:

View File

@ -141,6 +141,16 @@ class Worker(WorkerBase):
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
# If the KV cache has just been woken up,
# the internal state of cache_engine must be reset,
# especially the FP8 scaling factor.
if (
(tags is None or "kv_cache" in tags)
and self.cache_config.cache_dtype.startswith("fp8")
and hasattr(self.model_runner, "init_fp8_kv_scales")
):
self.model_runner.init_fp8_kv_scales()
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator