[BugFix][torch.compile] KV scale calculation issues with FP8 quantization (#25513)
Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 9555929e13
commit c692506e10
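In brief, as read from the diff below: the inline KV-scale calculation in Attention.forward is moved behind a torch.ops.vllm.maybe_calc_kv_scales custom op so torch.compile can trace the layer, the GPU model runner falls back to CUDAGraphMode.NONE for any batch that still needs KV-scale calculation, and a new test covers FP8 KV-cache scales at both the NO_COMPILATION and PIECEWISE compilation levels.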
@@ -139,6 +139,21 @@ def test_custom_compile_config(
     run_model(compilation_config, model, model_kwargs)


+@pytest.mark.parametrize(
+    "optimization_level",
+    [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
+)
+def test_fp8_kv_scale_compile(optimization_level: int):
+    model = "Qwen/Qwen2-0.5B"
+    model_kwargs = {
+        "quantization": "fp8",
+        "kv_cache_dtype": "fp8_e4m3",
+        "calculate_kv_scales": True,
+        "max_model_len": 512,
+    }
+    run_model(optimization_level, model, model_kwargs)
+
+
 def test_inductor_graph_partition_attn_fusion(caplog_vllm):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available "
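The new test feeds the same FP8 setup through run_model at both compilation levels. As a rough standalone equivalent (a sketch only: run_model's internals are not part of this diff, and the kwargs below are assumed to map onto vllm's engine arguments):

    from vllm import LLM, SamplingParams

    # Sketch of what test_fp8_kv_scale_compile presumably exercises end to end.
    llm = LLM(
        model="Qwen/Qwen2-0.5B",
        quantization="fp8",
        kv_cache_dtype="fp8_e4m3",
        calculate_kv_scales=True,  # compute KV scales on the fly
        max_model_len=512,
        compilation_config=3,      # assumed analogue of CompilationLevel.PIECEWISE
    )
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    print(out[0].outputs[0].text)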
@@ -277,9 +277,8 @@ class Attention(nn.Module, AttentionLayerBase):
         `vllm.forward_context.get_forward_context().attn_metadata`.
         """
         if self.calculate_kv_scales:
-            attn_metadata = get_forward_context().attn_metadata
-            if attn_metadata.enable_kv_scales_calculation:
-                self.calc_kv_scales(query, key, value)
+            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
+                                                self.layer_name)

         output_dtype = query.dtype
         if self.query_quant is not None:
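Previously the layer read attn_metadata and conditionally mutated its own scales inline in forward; that data-dependent branch breaks torch.compile tracing. Routing the call through a registered custom op (added in the next hunk) makes it opaque to the compiler while preserving the runtime behavior. A minimal standalone sketch of the same pattern, using illustrative names rather than vllm's helper (assumes PyTorch 2.4+ for torch.library.custom_op):

    import torch

    @torch.library.custom_op("demo::maybe_side_effect", mutates_args=())
    def maybe_side_effect(x: torch.Tensor) -> None:
        # Data-dependent branch: hidden from the compiler inside the op.
        if x.abs().max() > 1:
            print("large activation seen")

    @maybe_side_effect.register_fake
    def _(x: torch.Tensor) -> None:
        return None  # fake impl: a no-op suffices, the real op returns nothing

    @torch.compile
    def f(x: torch.Tensor) -> torch.Tensor:
        maybe_side_effect(x)  # compiles to one opaque call; the branch still runs
        return x * 2

    f(torch.randn(4))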
@@ -554,6 +553,44 @@ def maybe_save_kv_layer_to_connector(
                                            attn_metadata[layer_name])


+def maybe_calc_kv_scales(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    layer_name: str,
+) -> None:
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+
+    if attn_metadata is None or not getattr(
+            attn_metadata, 'enable_kv_scales_calculation', False):
+        return
+
+    self = forward_context.no_compile_layers[layer_name]
+    self.calc_kv_scales(query, key, value)
+
+
+def maybe_calc_kv_scales_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="maybe_calc_kv_scales",
+    op_func=maybe_calc_kv_scales,
+    mutates_args=["query", "key", "value"],
+    fake_impl=maybe_calc_kv_scales_fake,
+)
+
+
 def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
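In the registration above, mutates_args=["query", "key", "value"] declares the inputs as potentially mutated, which stops the compiler from reordering or dead-code-eliminating an op whose only effect is a side effect, and the no-op fake_impl gives the tracer a shape-correct stand-in. Once registered, the op is callable as torch.ops.vllm.maybe_calc_kv_scales(query, key, value, layer_name), which is exactly what the Attention.forward change above invokes; the real implementation recovers the layer object via forward_context.no_compile_layers rather than capturing self, so the op's signature stays tensors-and-string only.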
@@ -2351,6 +2351,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.cudagraph_dispatcher.dispatch(batch_descriptor,
                                            use_cascade_attn)

+        # Set cudagraph mode to none if calc_kv_scales is true.
+        if attn_metadata is not None:
+            metadata_list = (attn_metadata.values() if isinstance(
+                attn_metadata, dict) else [attn_metadata])
+            if any(
+                    getattr(m, 'enable_kv_scales_calculation', False)
+                    for m in metadata_list):
+                cudagraph_runtime_mode = CUDAGraphMode.NONE
+
         # This is currently to get around the assert in the DPMetadata
         # where it wants `num_tokens_across_dp` to align with `num_tokens`
         if ubatch_slices is not None:
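This runner-side guard disables CUDA graph replay for any batch where some layer's metadata still has enable_kv_scales_calculation set. The likely rationale: KV-scale calculation is a transient, data-dependent step, and a captured graph would either replay it on every step or bake in its absence, so such batches stay in eager mode until the scales have been computed.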