From a5a790eea6035760c71eae1861c1e5f369bc6d08 Mon Sep 17 00:00:00 2001
From: Adrian Abeyta
Date: Mon, 10 Nov 2025 17:42:37 -0600
Subject: [PATCH] [Bugfix] Ensure calculated KV scales are applied in attention. (#27232)

Signed-off-by: adabeyta
---
 .buildkite/test-pipeline.yaml      |  7 +++++--
 tests/compile/test_full_graph.py   | 10 ++++++++--
 vllm/attention/layer.py            | 29 +++++++----------------------
 vllm/v1/worker/gpu_model_runner.py | 19 +++++++++----------
 4 files changed, 29 insertions(+), 36 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 3152cd6488f36..a0d2076199b14 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -471,8 +471,8 @@ steps:
   - vllm/
   - tests/compile
   commands:
-  - pytest -v -s compile/test_full_graph.py
-  # Limit to no custom ops to reduce running time
+  - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
+  # Limit to no custom ops to reduce running time
   # Wrap with quotes to escape yaml and avoid starting -k string with a -
   - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
 
@@ -951,10 +951,13 @@ steps:
   - vllm/model_executor/layers/activation.py
  - vllm/model_executor/layers/quantization/input_quant_fp8.py
   - tests/compile/test_fusions_e2e.py
+  - tests/compile/test_full_graph.py
   commands:
   - nvidia-smi
   # Run all e2e fusion tests
   - pytest -v -s tests/compile/test_fusions_e2e.py
+  # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
+  - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
 
 - label: Blackwell GPT-OSS Eval
   timeout_in_minutes: 60
diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index 0ad8c17d86686..71f90f6d8d3ee 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -183,8 +183,14 @@ def test_custom_compile_config(
     "compilation_mode",
     [CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
 )
-def test_fp8_kv_scale_compile(compilation_mode: int):
-    model = "Qwen/Qwen2-0.5B"
+@pytest.mark.parametrize(
+    "model",
+    [
+        "Qwen/Qwen2-0.5B",  # Standard attention model
+        "deepseek-ai/DeepSeek-V2-Lite",  # MLA (Multi-head Latent Attention) model
+    ],
+)
+def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
     model_kwargs = {
         "quantization": "fp8",
         "kv_cache_dtype": "fp8_e4m3",
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 17e025155a431..96272981692c0 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -745,6 +745,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         k_pe: torch.Tensor,
         output_shape: torch.Size | None = None,
     ) -> torch.Tensor:
+        if self.calculate_kv_scales:
+            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
+
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
             attn_metadata = forward_context.attn_metadata
@@ -752,12 +755,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                 attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
 
-            # Mirror Attention.forward scale calculation path
-            if self.calculate_kv_scales and getattr(
-                attn_metadata, "enable_kv_scales_calculation", False
-            ):
-                self.calc_kv_scales(q, kv_c_normed, k_pe)
-
             if self.attn_backend.accept_output_buffer:
                 output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 self.impl.forward(
@@ -786,14 +783,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             )
             return output
         else:
-            # We can still access forward context to check calculation flag
-            if self.calculate_kv_scales:
-                forward_context = get_forward_context()
-                attn_metadata = forward_context.attn_metadata
-                if isinstance(attn_metadata, dict):
-                    attn_metadata = attn_metadata[self.layer_name]
-                if getattr(attn_metadata, "enable_kv_scales_calculation", False):
-                    self.calc_kv_scales(q, kv_c_normed, k_pe)
             return torch.ops.vllm.unified_mla_attention(
                 q,
                 kv_c_normed,
@@ -881,17 +870,13 @@ def maybe_calc_kv_scales(
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
+    self = forward_context.no_compile_layers[layer_name]
 
-    if isinstance(attn_metadata, dict):
-        attn_metadata = attn_metadata[layer_name]
-
-    if attn_metadata is None or not getattr(
-        attn_metadata, "enable_kv_scales_calculation", False
-    ):
+    # Only calculate if the layer's calculate_kv_scales flag is True
+    # This flag gets set to False after the first forward pass
+    if not self.calculate_kv_scales:
         return
 
-    self = forward_context.no_compile_layers[layer_name]
     self.calc_kv_scales(query, key, value)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9403b5756e052..6fccf2ea2f47c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -279,6 +279,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # This will be overridden in load_model()
         self.is_multimodal_pruning_enabled = False
         self.max_model_len = model_config.max_model_len
+
+        # Always set to false after the first forward pass
+        self.calculate_kv_scales = self.cache_config.calculate_kv_scales
         self.dcp_world_size = self.parallel_config.decode_context_parallel_size
         self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
@@ -2625,16 +2628,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
 
         # Set cudagraph mode to none if calc_kv_scales is true.
-        if attn_metadata is not None:
-            metadata_list = (
-                attn_metadata.values()
-                if isinstance(attn_metadata, dict)
-                else [attn_metadata]
-            )
-            if any(
-                getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list
-            ):
-                cudagraph_runtime_mode = CUDAGraphMode.NONE
+        # KV scales calculation involves dynamic operations that are incompatible
+        # with CUDA graph capture.
+        if self.calculate_kv_scales:
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
+            # Mark KV scales as calculated after the first forward pass
+            self.calculate_kv_scales = False
 
         # Run the model.
         # Use persistent buffers for CUDA graphs.
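Illustrative sketch (not part of the patch, not vLLM code): the snippet below shows the "calculate KV scales on the first forward pass, then freeze them and re-enable CUDA graph capture" pattern that this change moves into the model runner and the maybe_calc_kv_scales custom op. ToyLayer, ToyRunner, and the 448.0 fp8 bound are hypothetical stand-ins chosen for the example, not real vLLM APIs.

import torch


class ToyLayer:
    """Hypothetical stand-in for an attention layer with fp8 KV-cache scales."""

    def __init__(self, calculate_kv_scales: bool) -> None:
        self.calculate_kv_scales = calculate_kv_scales
        self.k_scale = torch.tensor(1.0)
        self.v_scale = torch.tensor(1.0)

    def calc_kv_scales(self, key: torch.Tensor, value: torch.Tensor) -> None:
        # Derive per-tensor scales from the observed dynamic range; 448.0 is
        # roughly the fp8 e4m3 max and stands in for the real quantization bound.
        self.k_scale = key.abs().amax() / 448.0
        self.v_scale = value.abs().amax() / 448.0
        # Freeze the scales after the first calculation.
        self.calculate_kv_scales = False


class ToyRunner:
    """Hypothetical stand-in for the model runner's first-pass handling."""

    def __init__(self, layers: list[ToyLayer], calculate_kv_scales: bool) -> None:
        self.layers = layers
        self.calculate_kv_scales = calculate_kv_scales

    def execute_model(self, key: torch.Tensor, value: torch.Tensor) -> bool:
        # Skip CUDA-graph capture while scales are still being computed;
        # the calculation is dynamic and not capture-safe.
        use_cudagraph = not self.calculate_kv_scales
        for layer in self.layers:
            if layer.calculate_kv_scales:
                layer.calc_kv_scales(key, value)
        # Mark the scales as calculated after the first forward pass.
        self.calculate_kv_scales = False
        return use_cudagraph


runner = ToyRunner([ToyLayer(calculate_kv_scales=True)], calculate_kv_scales=True)
k, v = torch.randn(4, 8), torch.randn(4, 8)
assert runner.execute_model(k, v) is False  # first pass: scales computed eagerly
assert runner.execute_model(k, v) is True   # later passes: scales reused, capture allowed

Gating on a plain per-layer flag (rather than on attention metadata) is what lets the same check work under torch.compile and for both standard and MLA attention, since the flag flips once after warm-up and the op becomes a no-op thereafter.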