From 65b1f121c885f169da210946eddb0d52524677f1 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 25 Jul 2024 12:46:15 -0400
Subject: [PATCH] [Bugfix] Fix `kv_cache_dtype=fp8` without scales for FP8
 checkpoints (#6761)

---
 tests/quantization/test_fp8.py                      | 12 ++++++++++--
 vllm/model_executor/layers/quantization/kv_cache.py |  6 ++----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index 0602fedf0b8e3..ad92f1f189f65 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -60,12 +60,20 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
 
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")
-def test_load_fp16_model(vllm_runner) -> None:
-    with vllm_runner("facebook/opt-125m", quantization="fp8") as llm:
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
+def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
+    with vllm_runner("facebook/opt-125m",
+                     quantization="fp8",
+                     kv_cache_dtype=kv_cache_dtype) as llm:
 
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         fc1 = model.model.decoder.layers[0].fc1
         assert isinstance(fc1.quant_method, Fp8LinearMethod)
+        if kv_cache_dtype == "fp8":
+            attn = model.model.decoder.layers[0].self_attn.attn
+            assert isinstance(attn.quant_method, Fp8KVCacheMethod)
+            assert attn._k_scale == 1.0
+            assert attn._v_scale == 1.0
 
         capability = torch.cuda.get_device_capability()
         capability = capability[0] * 10 + capability[1]
diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index c1495711447fa..d79536d196b92 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -46,10 +46,8 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
             # If no scales were loaded (both scales are invalid negative
             # values), use the default value of 1.0
-            k_scale = torch.nn.Parameter(torch.tensor(1.0),
-                                         requires_grad=False)
-            v_scale = torch.nn.Parameter(torch.tensor(1.0),
-                                         requires_grad=False)
+            k_scale = 1.0
+            v_scale = 1.0
         else:
             # If we find a single kv_scale in the checkpoint, we remap
             # kv_scale to k_scale during weight loading, and duplicate
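The `kv_cache.py` change is small but easy to misread, so here is a standalone sketch of the defaulting logic it patches. This is not vLLM source: `DummyAttnLayer` and `resolve_kv_scales` are hypothetical names invented for illustration, and only the negative-sentinel check and the float fallback mirror the diff above.

```python
import torch


class DummyAttnLayer:
    """Stand-in for an attention layer whose KV scales were absent from the checkpoint."""

    def __init__(self):
        # Negative values are the "no scale loaded" sentinel seen in the patched branch.
        self.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
        self.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)


def resolve_kv_scales(layer):
    """Sketch of the patched branch: default missing scales to plain floats."""
    if layer.k_scale < 0.0 and layer.v_scale < 0.0:
        # After the fix: plain Python floats, not nn.Parameter wrappers.
        return 1.0, 1.0
    # Scales were loaded from the checkpoint; unwrap the tensors to floats.
    return layer.k_scale.item(), layer.v_scale.item()


layer = DummyAttnLayer()
k_scale, v_scale = resolve_kv_scales(layer)
assert k_scale == 1.0 and v_scale == 1.0  # matches the assertions in test_fp8.py
```

The design point the diff makes is that the fallback scales end up as plain `1.0` floats, which is what the new test asserts via `attn._k_scale == 1.0` and `attn._v_scale == 1.0`.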
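For context, a minimal repro sketch of the configuration this patch fixes, assuming the standard `vllm.LLM` entry point with its `quantization` and `kv_cache_dtype` arguments (the model choice mirrors the test above; this snippet is illustrative, not part of the patch):

```python
# Load an FP16 checkpoint with on-the-fly FP8 weight quantization and an FP8
# KV cache. Per the patch, when no KV scales exist in the checkpoint, vLLM
# now falls back to plain-float scales of 1.0 instead of nn.Parameter wrappers.
from vllm import LLM

llm = LLM(model="facebook/opt-125m",
          quantization="fp8",
          kv_cache_dtype="fp8")
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```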