[Bugfix] Ensure calculated KV scales are applied in attention. (#27232)

Signed-off-by: adabeyta <aabeyta@redhat.com>
Adrian Abeyta 2025-11-10 17:42:37 -06:00 committed by GitHub
parent b30372cbd0
commit a5a790eea6
4 changed files with 29 additions and 36 deletions


@@ -471,8 +471,8 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
# Limit to no custom ops to reduce running time
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a -
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
@@ -951,10 +951,13 @@ steps:
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
- tests/compile/test_fusions_e2e.py
- tests/compile/test_full_graph.py
commands:
- nvidia-smi
# Run all e2e fusion tests
- pytest -v -s tests/compile/test_fusions_e2e.py
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60


@@ -183,8 +183,14 @@ def test_custom_compile_config(
    "compilation_mode",
    [CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
)
def test_fp8_kv_scale_compile(compilation_mode: int):
    model = "Qwen/Qwen2-0.5B"
@pytest.mark.parametrize(
    "model",
    [
        "Qwen/Qwen2-0.5B",  # Standard attention model
        "deepseek-ai/DeepSeek-V2-Lite",  # MLA (Multi-head Latent Attention) model
    ],
)
def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
    model_kwargs = {
        "quantization": "fp8",
        "kv_cache_dtype": "fp8_e4m3",


@@ -745,6 +745,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
@@ -752,12 +755,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            # Mirror Attention.forward scale calculation path
            if self.calculate_kv_scales and getattr(
                attn_metadata, "enable_kv_scales_calculation", False
            ):
                self.calc_kv_scales(q, kv_c_normed, k_pe)
            if self.attn_backend.accept_output_buffer:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                self.impl.forward(
@@ -786,14 +783,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
            )
            return output
        else:
            # We can still access forward context to check calculation flag
            if self.calculate_kv_scales:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                if getattr(attn_metadata, "enable_kv_scales_calculation", False):
                    self.calc_kv_scales(q, kv_c_normed, k_pe)
            return torch.ops.vllm.unified_mla_attention(
                q,
                kv_c_normed,
@@ -881,17 +870,13 @@ def maybe_calc_kv_scales(
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    self = forward_context.no_compile_layers[layer_name]
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    if attn_metadata is None or not getattr(
        attn_metadata, "enable_kv_scales_calculation", False
    ):
    # Only calculate if the layer's calculate_kv_scales flag is True
    # This flag gets set to False after the first forward pass
    if not self.calculate_kv_scales:
        return
    self = forward_context.no_compile_layers[layer_name]
    self.calc_kv_scales(query, key, value)
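
Taken together, the layer-side changes move the gating for KV-scale calculation off per-step attention metadata (enable_kv_scales_calculation) and onto the layer's own calculate_kv_scales flag, and MLAAttention now reaches maybe_calc_kv_scales through the same torch.ops.vllm custom op as regular attention. A minimal standalone sketch of that gating pattern, with illustrative names rather than the real vLLM internals:

import torch

FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

class KVScaleGatedLayer:
    """Illustrative only: compute KV scales once, gated by a per-layer flag."""

    def __init__(self, calculate_kv_scales: bool) -> None:
        self.calculate_kv_scales = calculate_kv_scales  # cleared after first use
        self.k_scale = torch.tensor(1.0)
        self.v_scale = torch.tensor(1.0)

    def maybe_calc_kv_scales(self, key: torch.Tensor, value: torch.Tensor) -> None:
        # Gate on the layer's flag, not on attention metadata that may be
        # absent during warm-up or dummy runs.
        if not self.calculate_kv_scales:
            return
        self.k_scale = key.abs().amax().float() / FP8_E4M3_MAX
        self.v_scale = value.abs().amax().float() / FP8_E4M3_MAX
        self.calculate_kv_scales = False

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        self.maybe_calc_kv_scales(k, v)
        # ... quantize k/v with self.k_scale / self.v_scale, then run attention ...
        return q

In the real code the scales live on the attention layer and feed the fp8 KV-cache quantization; the point here is only the flag-based gating that the diff introduces.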


@@ -279,6 +279,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # This will be overridden in load_model()
        self.is_multimodal_pruning_enabled = False
        self.max_model_len = model_config.max_model_len
        # Always set to false after the first forward pass
        self.calculate_kv_scales = self.cache_config.calculate_kv_scales
        self.dcp_world_size = self.parallel_config.decode_context_parallel_size
        self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
@@ -2625,16 +2628,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        )
        # Set cudagraph mode to none if calc_kv_scales is true.
        if attn_metadata is not None:
            metadata_list = (
                attn_metadata.values()
                if isinstance(attn_metadata, dict)
                else [attn_metadata]
            )
            if any(
                getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list
            ):
                cudagraph_runtime_mode = CUDAGraphMode.NONE
        # KV scales calculation involves dynamic operations that are incompatible
        # with CUDA graph capture.
        if self.calculate_kv_scales:
            cudagraph_runtime_mode = CUDAGraphMode.NONE
            # Mark KV scales as calculated after the first forward pass
            self.calculate_kv_scales = False
        # Run the model.
        # Use persistent buffers for CUDA graphs.
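
On the runner side, the same idea replaces the metadata scan: a single calculate_kv_scales flag forces eager execution for the first forward pass (scale calculation is data dependent, so it cannot be captured into a CUDA graph) and is cleared immediately afterwards. A compact sketch of that control flow, using illustrative names rather than the actual GPUModelRunner interface:

from enum import Enum

class CUDAGraphMode(Enum):
    # Local stand-in for the vLLM enum of the same name; illustrative only.
    NONE = 0
    FULL = 1

class RunnerSketch:
    """Illustrative only: first-pass gating of CUDA graph capture."""

    def __init__(self, calculate_kv_scales: bool) -> None:
        # Mirrors cache_config.calculate_kv_scales; True only until the
        # first forward pass has produced the scales.
        self.calculate_kv_scales = calculate_kv_scales

    def resolve_cudagraph_mode(self, requested: CUDAGraphMode) -> CUDAGraphMode:
        if self.calculate_kv_scales:
            # Scale calculation involves data-dependent ops, so run this step
            # eagerly and mark the scales as computed for all later steps.
            self.calculate_kv_scales = False
            return CUDAGraphMode.NONE
        return requested

runner = RunnerSketch(calculate_kv_scales=True)
assert runner.resolve_cudagraph_mode(CUDAGraphMode.FULL) is CUDAGraphMode.NONE  # first pass
assert runner.resolve_cudagraph_mode(CUDAGraphMode.FULL) is CUDAGraphMode.FULL  # later passes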