diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index 870aa553ca628..f9f146810924e 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -139,6 +139,21 @@ def test_custom_compile_config(
     run_model(compilation_config, model, model_kwargs)
 
 
+@pytest.mark.parametrize(
+    "optimization_level",
+    [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
+)
+def test_fp8_kv_scale_compile(optimization_level: int):
+    model = "Qwen/Qwen2-0.5B"
+    model_kwargs = {
+        "quantization": "fp8",
+        "kv_cache_dtype": "fp8_e4m3",
+        "calculate_kv_scales": True,
+        "max_model_len": 512,
+    }
+    run_model(optimization_level, model, model_kwargs)
+
+
 def test_inductor_graph_partition_attn_fusion(caplog_vllm):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available "
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 326fe6dd048a9..d97c87d96e999 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -277,9 +277,8 @@ class Attention(nn.Module, AttentionLayerBase):
         `vllm.forward_context.get_forward_context().attn_metadata`.
         """
         if self.calculate_kv_scales:
-            attn_metadata = get_forward_context().attn_metadata
-            if attn_metadata.enable_kv_scales_calculation:
-                self.calc_kv_scales(query, key, value)
+            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
+                                                self.layer_name)
 
         output_dtype = query.dtype
         if self.query_quant is not None:
@@ -554,6 +553,44 @@ def maybe_save_kv_layer_to_connector(
                                    attn_metadata[layer_name])
 
 
+def maybe_calc_kv_scales(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    layer_name: str,
+) -> None:
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+
+    if attn_metadata is None or not getattr(
+            attn_metadata, 'enable_kv_scales_calculation', False):
+        return
+
+    self = forward_context.no_compile_layers[layer_name]
+    self.calc_kv_scales(query, key, value)
+
+
+def maybe_calc_kv_scales_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="maybe_calc_kv_scales",
+    op_func=maybe_calc_kv_scales,
+    mutates_args=["query", "key", "value"],
+    fake_impl=maybe_calc_kv_scales_fake,
+)
+
+
 def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f8b0b9cba1bc1..9e7d6eb0387bc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2351,6 +2351,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.cudagraph_dispatcher.dispatch(batch_descriptor,
                                                use_cascade_attn)
 
+        # Set cudagraph mode to none if calc_kv_scales is true.
+        if attn_metadata is not None:
+            metadata_list = (attn_metadata.values() if isinstance(
+                attn_metadata, dict) else [attn_metadata])
+            if any(
+                    getattr(m, 'enable_kv_scales_calculation', False)
+                    for m in metadata_list):
+                cudagraph_runtime_mode = CUDAGraphMode.NONE
+
         # This is currently to get around the assert in the DPMetadata
         # where it wants `num_tokens_across_dp` to align with `num_tokens`
         if ubatch_slices is not None:
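
The diff above routes the KV-scale calculation through a registered custom op so that `torch.compile` treats it as a single opaque node instead of tracing into its data-dependent control flow. Below is a minimal, self-contained sketch of that pattern using PyTorch's public `torch.library.custom_op` API (PyTorch 2.4+) rather than vLLM's internal `direct_register_custom_op` helper; the op name, the `scale` buffer, and the fp8-e4m3 divisor are illustrative assumptions, not vLLM's actual implementation.

```python
import torch
from torch.library import custom_op


# Hypothetical op mirroring maybe_calc_kv_scales: it mutates `scale`
# in place, so the mutation is declared up front, just as the diff
# does with mutates_args=["query", "key", "value"].
@custom_op("demo::maybe_calc_scales", mutates_args=("scale",))
def maybe_calc_scales(q: torch.Tensor, scale: torch.Tensor) -> None:
    # Data-dependent update that would otherwise specialize or break
    # the compiled graph (448.0 is the fp8 e4m3 max representable value).
    scale.copy_(q.abs().amax() / 448.0)


@maybe_calc_scales.register_fake
def _(q: torch.Tensor, scale: torch.Tensor) -> None:
    # Fake (meta) implementation: a no-op suffices for tracing,
    # matching the diff's maybe_calc_kv_scales_fake.
    return


@torch.compile
def forward(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    maybe_calc_scales(q, scale)  # appears as one opaque node in the graph
    return q * scale


q = torch.randn(4, 8)
scale = torch.ones(())
print(forward(q, scale))
```

Because the op body runs eagerly at call time, it can look up per-layer state from the forward context (as the real `maybe_calc_kv_scales` does via `forward_context.no_compile_layers[layer_name]`) without that lookup being baked into the graph. The companion change in `gpu_model_runner.py` exists for the same reason: a captured CUDA graph would replay a stale scale computation, so the runtime mode is forced to `CUDAGraphMode.NONE` whenever any layer's metadata has `enable_kv_scales_calculation` set.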