[Bugfix][sleepmode][fp8 kv cache]: Fix FP8 KV cache + sleep(level=2) gibberish output (#28783)

Signed-off-by: vensen <vensenmu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>

commit 66b5840287 (parent 82c795d6f2)
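For context, the failure fixed here appears when an FP8 KV cache is combined with deep sleep: after llm.sleep(level=2) and a subsequent wake-up, the KV-cache scaling factors are left at 0.0, so every cached value dequantizes to zero and generation degrades to gibberish. A minimal reproduction, sketched from the test added in this commit (the model name and sampling settings simply mirror that test):

from vllm import LLM, SamplingParams

# Sketch of the repro path exercised by the new test below.
llm = LLM("Qwen/Qwen2-0.5B", enable_sleep_mode=True, kv_cache_dtype="fp8")
params = SamplingParams(temperature=0, max_tokens=10)

before = llm.generate("How are you?", params)

llm.sleep(level=2)                    # deep sleep: discard weights and KV cache memory
llm.wake_up(tags=["weights"])         # restore weight buffers
llm.collective_rpc("reload_weights")  # reload weight contents
llm.wake_up(tags=["kv_cache"])        # re-allocate the KV cache

after = llm.generate("How are you?", params)
# Before this fix the second generation could be gibberish; with it, outputs match.
assert before[0].outputs[0].text == after[0].outputs[0].text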
@@ -11,7 +11,7 @@ from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.platforms import current_platform
 from vllm.utils.mem_constants import GiB_bytes

-from ..utils import create_new_process_for_each_test
+from ..utils import create_new_process_for_each_test, requires_fp8


 @create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
@@ -243,3 +243,34 @@ def test_deep_sleep_async():
         assert output.outputs[0].text == output2.outputs[0].text

     asyncio.run(test())
+
+
+@requires_fp8
+def test_deep_sleep_fp8_kvcache():
+    GiB_bytes = 1 << 30
+    model = "Qwen/Qwen2-0.5B"
+    used_bytes_baseline = current_platform.get_current_memory_usage()
+
+    llm = LLM(model, enable_sleep_mode=True, kv_cache_dtype="fp8")
+    prompt = "How are you?"
+    sampling_params = SamplingParams(temperature=0, max_tokens=10)
+    output = llm.generate(prompt, sampling_params)
+
+    # Put the engine to deep sleep
+    llm.sleep(level=2)
+
+    used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
+    assert used_bytes < 3 * GiB_bytes
+
+    llm.wake_up(tags=["weights"])
+    llm.collective_rpc("reload_weights")
+
+    used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
+    assert used_bytes < 4 * GiB_bytes
+
+    # now allocate kv cache and cuda graph memory
+    llm.wake_up(tags=["kv_cache"])
+    output2 = llm.generate(prompt, sampling_params)
+
+    # cmp output
+    assert output[0].outputs[0].text == output2[0].outputs[0].text
@@ -1075,6 +1075,13 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
     )


+requires_fp8 = pytest.mark.skipif(
+    not current_platform.supports_fp8(),
+    reason="FP8 is not supported on this GPU (requires Hopper or "
+    "Ada architecture, compute capability 8.9+)",
+)
+
+
 def large_gpu_test(*, min_gb: int):
     """
     Decorate a test to be skipped if no GPU is available or it does not have
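The new marker is consumed by the test above through its package-relative import (from ..utils import requires_fp8). A short usage sketch; the absolute import path and test name here are illustrative assumptions:

# Illustrative use of the marker (import path assumed to be tests/utils.py,
# matching the relative import in the test file above).
from tests.utils import requires_fp8

@requires_fp8
def test_fp8_only_behavior():
    # Runs only on GPUs where current_platform.supports_fp8() is True;
    # skipped elsewhere with the reason string defined above.
    ...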
@@ -25,7 +25,7 @@ from vllm.attention.backends.abstract import (
     AttentionType,
     MultipleOf,
 )
-from vllm.attention.layer import Attention
+from vllm.attention.layer import Attention, MLAAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -602,6 +602,50 @@ class GPUModelRunner(
         if self.mm_budget:
             self.mm_budget.reset_cache()

+    @torch.inference_mode()
+    def init_fp8_kv_scales(self) -> None:
+        """
+        Re-initialize the KV cache and FP8 scales after waking from sleep.
+        1. Zero out the KV cache tensors to remove garbage data from re-allocation.
+        2. Reset Attention layer scaling factors (_k_scale, _v_scale) to 1.0.
+           If these are left at 0.0 (default after wake_up), all KV cache values
+           become effectively zero, causing gibberish output.
+        """
+        if not self.cache_config.cache_dtype.startswith("fp8"):
+            return
+
+        kv_caches = getattr(self, "kv_caches", [])
+        for cache_tensor in kv_caches:
+            if cache_tensor is not None:
+                cache_tensor.zero_()
+
+        k_attr_names = ("_k_scale", "k_scale")
+        v_attr_names = ("_v_scale", "v_scale")
+
+        attn_layers = self.compilation_config.static_forward_context
+        for name, module in attn_layers.items():
+            if isinstance(module, (Attention, MLAAttention)):
+                # TODO: Generally, scale is 1.0 if user uses on-the-fly fp8
+                # kvcache quant. However, to get better accuracy, compression
+                # frameworks like llm-compressors allow users to tune the
+                # scale. We may need to restore the specific calibrated scales
+                # here in the future.
+                k_scale_val, v_scale_val = 1.0, 1.0
+
+                # Processing K Scale
+                for attr in k_attr_names:
+                    if hasattr(module, attr):
+                        param = getattr(module, attr)
+                        if isinstance(param, torch.Tensor):
+                            param.fill_(k_scale_val)
+
+                # Processing V Scale
+                for attr in v_attr_names:
+                    if hasattr(module, attr):
+                        param = getattr(module, attr)
+                        if isinstance(param, torch.Tensor):
+                            param.fill_(v_scale_val)
+
     def _get_positions(self, num_tokens: Any):
         if isinstance(num_tokens, int):
             if self.uses_mrope:
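The docstring's point about 0.0 scales can be seen in isolation: FP8 KV-cache entries are rescaled when read back, so a zero scale wipes out every cached value. A standalone sketch of that arithmetic (illustrative PyTorch only, not vLLM's actual dequantization kernel):

import torch

k_fp8 = torch.randn(4).to(torch.float8_e4m3fn)  # a few cached key values in FP8

restored_scale = torch.tensor(1.0)  # what init_fp8_kv_scales() writes back
stale_scale = torch.tensor(0.0)     # default left behind by wake_up before this fix

print(k_fp8.to(torch.float32) * restored_scale)  # roughly the original values
print(k_fp8.to(torch.float32) * stale_scale)     # all zeros -> attention reads garbage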
@@ -141,6 +141,16 @@ class Worker(WorkerBase):
                     buffer.data.copy_(self._sleep_saved_buffers[name].data)
             self._sleep_saved_buffers = {}

+        # If the KV cache has just been woken up,
+        # the internal state of cache_engine must be reset,
+        # especially the FP8 scaling factor.
+        if (
+            (tags is None or "kv_cache" in tags)
+            and self.cache_config.cache_dtype.startswith("fp8")
+            and hasattr(self.model_runner, "init_fp8_kv_scales")
+        ):
+            self.model_runner.init_fp8_kv_scales()
+
     def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
         if self.vllm_config.model_config.enable_sleep_mode:
             from vllm.device_allocator.cumem import CuMemAllocator
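The guard only fires when the KV cache is actually being woken and the cache dtype is FP8. A standalone restatement of that predicate (hypothetical helper and stub runner, not vLLM code; the dtype strings are illustrative):

def needs_fp8_scale_reset(tags, cache_dtype, model_runner) -> bool:
    # Same condition as the new block in Worker.wake_up, pulled out for clarity.
    return (
        (tags is None or "kv_cache" in tags)
        and cache_dtype.startswith("fp8")
        and hasattr(model_runner, "init_fp8_kv_scales")
    )

class _Runner:
    def init_fp8_kv_scales(self) -> None: ...

assert needs_fp8_scale_reset(None, "fp8", _Runner())             # full wake_up
assert needs_fp8_scale_reset(["kv_cache"], "fp8", _Runner())     # KV-cache-only wake_up
assert not needs_fp8_scale_reset(["weights"], "fp8", _Runner())  # weights-only wake_up
assert not needs_fp8_scale_reset(None, "auto", _Runner())        # non-FP8 KV cache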