mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:44:56 +08:00
[Bugfix][sleepmode][fp8 kv cache]: Fix FP8 KV cache + sleep(level=2) gibberish output (#28783)
Signed-off-by: vensen <vensenmu@gmail.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
82c795d6f2
commit
66b5840287
@ -11,7 +11,7 @@ from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
from ..utils import create_new_process_for_each_test, requires_fp8
|
||||
|
||||
|
||||
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
|
||||
@ -243,3 +243,34 @@ def test_deep_sleep_async():
|
||||
assert output.outputs[0].text == output2.outputs[0].text
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
|
||||
@requires_fp8
|
||||
def test_deep_sleep_fp8_kvcache():
|
||||
GiB_bytes = 1 << 30
|
||||
model = "Qwen/Qwen2-0.5B"
|
||||
used_bytes_baseline = current_platform.get_current_memory_usage()
|
||||
|
||||
llm = LLM(model, enable_sleep_mode=True, kv_cache_dtype="fp8")
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
output = llm.generate(prompt, sampling_params)
|
||||
|
||||
# Put the engine to deep sleep
|
||||
llm.sleep(level=2)
|
||||
|
||||
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
|
||||
assert used_bytes < 3 * GiB_bytes
|
||||
|
||||
llm.wake_up(tags=["weights"])
|
||||
llm.collective_rpc("reload_weights")
|
||||
|
||||
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
|
||||
assert used_bytes < 4 * GiB_bytes
|
||||
|
||||
# now allocate kv cache and cuda graph memory
|
||||
llm.wake_up(tags=["kv_cache"])
|
||||
output2 = llm.generate(prompt, sampling_params)
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output2[0].outputs[0].text
|
||||
|
||||
@ -1075,6 +1075,13 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
|
||||
)
|
||||
|
||||
|
||||
requires_fp8 = pytest.mark.skipif(
|
||||
not current_platform.supports_fp8(),
|
||||
reason="FP8 is not supported on this GPU (requires Hopper or "
|
||||
"Ada architecture, compute capability 8.9+)",
|
||||
)
|
||||
|
||||
|
||||
def large_gpu_test(*, min_gb: int):
|
||||
"""
|
||||
Decorate a test to be skipped if no GPU is available or it does not have
|
||||
|
||||
@ -25,7 +25,7 @@ from vllm.attention.backends.abstract import (
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.layer import Attention, MLAAttention
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||
@ -602,6 +602,50 @@ class GPUModelRunner(
|
||||
if self.mm_budget:
|
||||
self.mm_budget.reset_cache()
|
||||
|
||||
@torch.inference_mode()
|
||||
def init_fp8_kv_scales(self) -> None:
|
||||
"""
|
||||
Re-initialize the KV cache and FP8 scales after waking from sleep.
|
||||
1. Zero out the KV cache tensors to remove garbage data from re-allocation.
|
||||
2. Reset Attention layer scaling factors (_k_scale, _v_scale) to 1.0.
|
||||
If these are left at 0.0 (default after wake_up), all KV cache values
|
||||
become effectively zero, causing gibberish output.
|
||||
"""
|
||||
if not self.cache_config.cache_dtype.startswith("fp8"):
|
||||
return
|
||||
|
||||
kv_caches = getattr(self, "kv_caches", [])
|
||||
for cache_tensor in kv_caches:
|
||||
if cache_tensor is not None:
|
||||
cache_tensor.zero_()
|
||||
|
||||
k_attr_names = ("_k_scale", "k_scale")
|
||||
v_attr_names = ("_v_scale", "v_scale")
|
||||
|
||||
attn_layers = self.compilation_config.static_forward_context
|
||||
for name, module in attn_layers.items():
|
||||
if isinstance(module, (Attention, MLAAttention)):
|
||||
# TODO: Generally, scale is 1.0 if user uses on-the-fly fp8
|
||||
# kvcache quant. However, to get better accuracy, compression
|
||||
# frameworks like llm-compressors allow users to tune the
|
||||
# scale. We may need to restore the specific calibrated scales
|
||||
# here in the future.
|
||||
k_scale_val, v_scale_val = 1.0, 1.0
|
||||
|
||||
# Processing K Scale
|
||||
for attr in k_attr_names:
|
||||
if hasattr(module, attr):
|
||||
param = getattr(module, attr)
|
||||
if isinstance(param, torch.Tensor):
|
||||
param.fill_(k_scale_val)
|
||||
|
||||
# Processing V Scale
|
||||
for attr in v_attr_names:
|
||||
if hasattr(module, attr):
|
||||
param = getattr(module, attr)
|
||||
if isinstance(param, torch.Tensor):
|
||||
param.fill_(v_scale_val)
|
||||
|
||||
def _get_positions(self, num_tokens: Any):
|
||||
if isinstance(num_tokens, int):
|
||||
if self.uses_mrope:
|
||||
|
||||
@ -141,6 +141,16 @@ class Worker(WorkerBase):
|
||||
buffer.data.copy_(self._sleep_saved_buffers[name].data)
|
||||
self._sleep_saved_buffers = {}
|
||||
|
||||
# If the KV cache has just been woken up,
|
||||
# the internal state of cache_engine must be reset,
|
||||
# especially the FP8 scaling factor.
|
||||
if (
|
||||
(tags is None or "kv_cache" in tags)
|
||||
and self.cache_config.cache_dtype.startswith("fp8")
|
||||
and hasattr(self.model_runner, "init_fp8_kv_scales")
|
||||
):
|
||||
self.model_runner.init_fp8_kv_scales()
|
||||
|
||||
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user