mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 10:24:28 +08:00
[Bugfix] Ensure calculated KV scales are applied in attention. (#27232)
Signed-off-by: adabeyta <aabeyta@redhat.com>
This commit is contained in:
parent
b30372cbd0
commit
a5a790eea6
@ -471,8 +471,8 @@ steps:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
commands:
|
||||
- pytest -v -s compile/test_full_graph.py
|
||||
# Limit to no custom ops to reduce running time
|
||||
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
|
||||
# Limit to no custom ops to reduce running time
|
||||
# Wrap with quotes to escape yaml and avoid starting -k string with a -
|
||||
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
|
||||
|
||||
@ -951,10 +951,13 @@ steps:
|
||||
- vllm/model_executor/layers/activation.py
|
||||
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
||||
- tests/compile/test_fusions_e2e.py
|
||||
- tests/compile/test_full_graph.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
# Run all e2e fusion tests
|
||||
- pytest -v -s tests/compile/test_fusions_e2e.py
|
||||
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
|
||||
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
|
||||
|
||||
- label: Blackwell GPT-OSS Eval
|
||||
timeout_in_minutes: 60
|
||||
|
||||
@ -183,8 +183,14 @@ def test_custom_compile_config(
|
||||
"compilation_mode",
|
||||
[CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
|
||||
)
|
||||
def test_fp8_kv_scale_compile(compilation_mode: int):
|
||||
model = "Qwen/Qwen2-0.5B"
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"Qwen/Qwen2-0.5B", # Standard attention model
|
||||
"deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model
|
||||
],
|
||||
)
|
||||
def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
|
||||
model_kwargs = {
|
||||
"quantization": "fp8",
|
||||
"kv_cache_dtype": "fp8_e4m3",
|
||||
|
||||
@ -745,6 +745,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
k_pe: torch.Tensor,
|
||||
output_shape: torch.Size | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.calculate_kv_scales:
|
||||
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
|
||||
|
||||
if self.use_direct_call:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
@ -752,12 +755,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
|
||||
# Mirror Attention.forward scale calculation path
|
||||
if self.calculate_kv_scales and getattr(
|
||||
attn_metadata, "enable_kv_scales_calculation", False
|
||||
):
|
||||
self.calc_kv_scales(q, kv_c_normed, k_pe)
|
||||
|
||||
if self.attn_backend.accept_output_buffer:
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
self.impl.forward(
|
||||
@ -786,14 +783,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
)
|
||||
return output
|
||||
else:
|
||||
# We can still access forward context to check calculation flag
|
||||
if self.calculate_kv_scales:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
if getattr(attn_metadata, "enable_kv_scales_calculation", False):
|
||||
self.calc_kv_scales(q, kv_c_normed, k_pe)
|
||||
return torch.ops.vllm.unified_mla_attention(
|
||||
q,
|
||||
kv_c_normed,
|
||||
@ -881,17 +870,13 @@ def maybe_calc_kv_scales(
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
|
||||
if attn_metadata is None or not getattr(
|
||||
attn_metadata, "enable_kv_scales_calculation", False
|
||||
):
|
||||
# Only calculate if the layer's calculate_kv_scales flag is True
|
||||
# This flag gets set to False after the first forward pass
|
||||
if not self.calculate_kv_scales:
|
||||
return
|
||||
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.calc_kv_scales(query, key, value)
|
||||
|
||||
|
||||
|
||||
@ -279,6 +279,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# This will be overridden in load_model()
|
||||
self.is_multimodal_pruning_enabled = False
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
# Always set to false after the first forward pass
|
||||
self.calculate_kv_scales = self.cache_config.calculate_kv_scales
|
||||
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
|
||||
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
@ -2625,16 +2628,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
|
||||
# Set cudagraph mode to none if calc_kv_scales is true.
|
||||
if attn_metadata is not None:
|
||||
metadata_list = (
|
||||
attn_metadata.values()
|
||||
if isinstance(attn_metadata, dict)
|
||||
else [attn_metadata]
|
||||
)
|
||||
if any(
|
||||
getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list
|
||||
):
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
# KV scales calculation involves dynamic operations that are incompatible
|
||||
# with CUDA graph capture.
|
||||
if self.calculate_kv_scales:
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
# Mark KV scales as calculated after the first forward pass
|
||||
self.calculate_kv_scales = False
|
||||
|
||||
# Run the model.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user