diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py
index 754ef20dbeb2..dc9c69bf58b9 100644
--- a/tests/basic_correctness/test_cumem.py
+++ b/tests/basic_correctness/test_cumem.py
@@ -11,7 +11,7 @@
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.platforms import current_platform
 from vllm.utils.mem_constants import GiB_bytes
-from ..utils import create_new_process_for_each_test
+from ..utils import create_new_process_for_each_test, requires_fp8
 
 
 @create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
@@ -243,3 +243,34 @@ def test_deep_sleep_async():
         assert output.outputs[0].text == output2.outputs[0].text
 
     asyncio.run(test())
+
+
+@requires_fp8
+def test_deep_sleep_fp8_kvcache():
+    GiB_bytes = 1 << 30
+    model = "Qwen/Qwen2-0.5B"
+    used_bytes_baseline = current_platform.get_current_memory_usage()
+
+    llm = LLM(model, enable_sleep_mode=True, kv_cache_dtype="fp8")
+    prompt = "How are you?"
+    sampling_params = SamplingParams(temperature=0, max_tokens=10)
+    output = llm.generate(prompt, sampling_params)
+
+    # Put the engine into deep sleep
+    llm.sleep(level=2)
+
+    used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
+    assert used_bytes < 3 * GiB_bytes
+
+    llm.wake_up(tags=["weights"])
+    llm.collective_rpc("reload_weights")
+
+    used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
+    assert used_bytes < 4 * GiB_bytes
+
+    # Now allocate KV cache and CUDA graph memory
+    llm.wake_up(tags=["kv_cache"])
+    output2 = llm.generate(prompt, sampling_params)
+
+    # Compare the outputs
+    assert output[0].outputs[0].text == output2[0].outputs[0].text
diff --git a/tests/utils.py b/tests/utils.py
index c31a2aeeb9c8..9565b0ff06e3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1075,6 +1075,13 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
     )
 
 
+requires_fp8 = pytest.mark.skipif(
+    not current_platform.supports_fp8(),
+    reason="FP8 is not supported on this GPU (requires Hopper or "
+    "Ada architecture, compute capability 8.9+)",
+)
+
+
 def large_gpu_test(*, min_gb: int):
     """
     Decorate a test to be skipped if no GPU is available or it does not have
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9b0fb07297ac..eeae82568c32 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -25,7 +25,7 @@ from vllm.attention.backends.abstract import (
     AttentionType,
     MultipleOf,
 )
-from vllm.attention.layer import Attention
+from vllm.attention.layer import Attention, MLAAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -602,6 +602,50 @@ class GPUModelRunner(
         if self.mm_budget:
             self.mm_budget.reset_cache()
 
+    @torch.inference_mode()
+    def init_fp8_kv_scales(self) -> None:
+        """
+        Re-initialize the KV cache and FP8 scales after waking from sleep.
+        1. Zero out the KV cache tensors to remove garbage data left by re-allocation.
+        2. Reset attention-layer scaling factors (_k_scale, _v_scale) to 1.0.
+           If these are left at 0.0 (the default after wake_up), all KV cache values
+           become effectively zero, causing gibberish output.
+        """
+        if not self.cache_config.cache_dtype.startswith("fp8"):
+            return
+
+        kv_caches = getattr(self, "kv_caches", [])
+        for cache_tensor in kv_caches:
+            if cache_tensor is not None:
+                cache_tensor.zero_()
+
+        k_attr_names = ("_k_scale", "k_scale")
+        v_attr_names = ("_v_scale", "v_scale")
+
+        attn_layers = self.compilation_config.static_forward_context
+        for module in attn_layers.values():
+            if isinstance(module, (Attention, MLAAttention)):
+                # TODO: Generally, the scale is 1.0 when the user relies on
+                # on-the-fly fp8 KV cache quantization. However, to get better
+                # accuracy, compression frameworks like llm-compressor let
+                # users tune the scale. We may need to restore the specific
+                # calibrated scales here in the future.
+                k_scale_val, v_scale_val = 1.0, 1.0
+
+                # Reset the K scale
+                for attr in k_attr_names:
+                    if hasattr(module, attr):
+                        param = getattr(module, attr)
+                        if isinstance(param, torch.Tensor):
+                            param.fill_(k_scale_val)
+
+                # Reset the V scale
+                for attr in v_attr_names:
+                    if hasattr(module, attr):
+                        param = getattr(module, attr)
+                        if isinstance(param, torch.Tensor):
+                            param.fill_(v_scale_val)
+
     def _get_positions(self, num_tokens: Any):
         if isinstance(num_tokens, int):
             if self.uses_mrope:
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index d0c6091ce2a6..ed6fb32bcb2f 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -141,6 +141,16 @@ class Worker(WorkerBase):
                 buffer.data.copy_(self._sleep_saved_buffers[name].data)
             self._sleep_saved_buffers = {}
 
+        # If the KV cache has just been woken up, the model runner's
+        # KV cache state must be reset, in particular the FP8
+        # scaling factors.
+        if (
+            (tags is None or "kv_cache" in tags)
+            and self.cache_config.cache_dtype.startswith("fp8")
+            and hasattr(self.model_runner, "init_fp8_kv_scales")
+        ):
+            self.model_runner.init_fp8_kv_scales()
+
     def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
         if self.vllm_config.model_config.enable_sleep_mode:
             from vllm.device_allocator.cumem import CuMemAllocator