[Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization (#15734)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
parent 6aae216b4e
commit a41351f363
@@ -237,6 +237,7 @@ class AttentionLayer(Protocol):
     _v_scale: torch.Tensor
     _k_scale_float: float
     _v_scale_float: float
+    _prob_scale: torch.Tensor
 
     def forward(
         self,

@@ -766,6 +766,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         query.dtype,
                         seq_lens,
                         make_attn_mask=causal_mask)  # type: ignore
+                use_fp8_scales = (layer._q_scale and layer._k_scale
+                                  and layer._v_scale and layer._prob_scale
+                                  and self.kv_cache_dtype == "fp8")
+                full_scales = (
+                    layer._q_scale, layer._k_scale, layer._v_scale,
+                    layer._prob_scale) if use_fp8_scales else None
                 self.triton_attn_func(
                     query,
                     key,

@@ -779,6 +785,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     self.scale,
                     attn_masks[0][None]
                     if attn_masks is not None else None,
+                    full_scales,
                 )
             elif self.use_naive_attn:
                 if self.num_kv_heads != self.num_heads:

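For readers skimming the two ROCm hunks above: the new Q/K/V/prob scales are only handed to the Triton kernel when every scale is set and the KV cache runs in fp8; otherwise `None` is passed and the kernel behaves as before. A minimal standalone sketch of that gating, with a hypothetical helper name and toy scale values (not vLLM code):

```python
from typing import Optional, Tuple

import torch


def build_full_scales(
        q_scale: torch.Tensor, k_scale: torch.Tensor, v_scale: torch.Tensor,
        prob_scale: torch.Tensor,
        kv_cache_dtype: str) -> Optional[Tuple[torch.Tensor, ...]]:
    """Mirrors the gating in the hunk above (hypothetical helper)."""
    use_fp8_scales = (q_scale and k_scale and v_scale and prob_scale
                      and kv_cache_dtype == "fp8")
    return (q_scale, k_scale, v_scale, prob_scale) if use_fp8_scales else None


# The layer scales default to 1.0 tensors, so truthiness here means
# "present and non-zero"; any zero scale disables the fp8 path.
scales = [torch.tensor(s) for s in (0.02, 0.01, 0.01, 0.5)]
print(build_full_scales(*scales, kv_cache_dtype="fp8"))   # 4-tuple of tensors
print(build_full_scales(*scales, kv_cache_dtype="auto"))  # None
```
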
@@ -90,6 +90,7 @@ class Attention(nn.Module):
         # FlashAttn doesn't support quantizing the kv-cache only
         # but requires q to be quantized as well.
         self._q_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
 
         # We also keep the float32 versions of k/v_scale for attention
         # backends that don't support tensors (Flashinfer)

@@ -3767,6 +3767,17 @@ class VllmConfig:
             return quant_config
         return None
 
+    @staticmethod
+    def get_quantization_config(
+            model_config: ModelConfig,
+            load_config: LoadConfig) -> Optional[QuantizationConfig]:
+        import copy
+
+        # For some reason, the _ version of this modifies the model_config
+        # object, so using deepcopy to avoid this problem.
+        return VllmConfig._get_quantization_config(copy.deepcopy(model_config),
+                                                   load_config)
+
     def with_hf_config(
         self,
         hf_config: PretrainedConfig,

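The new `get_quantization_config` wrapper deep-copies the model config because the underscored `_get_quantization_config` mutates its argument as a side effect. A self-contained sketch of that defensive pattern, with made-up config and function names:

```python
import copy
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeModelConfig:
    quantization: Optional[str] = None
    hf_overrides: dict = field(default_factory=dict)


def _resolve_quant_inplace(cfg: FakeModelConfig) -> Optional[str]:
    # Stand-in for the underscored helper: it resolves the quantization
    # method but rewrites the config object as a side effect.
    cfg.quantization = cfg.quantization or "fp8"
    return cfg.quantization


def resolve_quant(cfg: FakeModelConfig) -> Optional[str]:
    # Public wrapper: deep-copy first so the caller's object stays untouched,
    # mirroring VllmConfig.get_quantization_config above.
    return _resolve_quant_inplace(copy.deepcopy(cfg))


cfg = FakeModelConfig()
assert resolve_quant(cfg) == "fp8"
assert cfg.quantization is None  # the caller's config was not modified
```
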
@@ -1368,6 +1368,23 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
+        if current_platform.is_rocm():
+            from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+            load_config = self.create_load_config()
+            quantization_config = VllmConfig.get_quantization_config(
+                model_config, load_config)
+            if isinstance(quantization_config, Fp8Config):
+                _raise_or_fallback(feature_name="fp8 for ROCm",
+                                   recommend_to_remove=False)
+                return False
+            from vllm.model_executor.layers.quantization.quark.quark import (
+                QuarkConfig)
+
+            if isinstance(quantization_config, QuarkConfig
+                          ) and quantization_config.has_fp8_layer_weights():
+                _raise_or_fallback(feature_name="Quark fp8 for ROCm",
+                                   recommend_to_remove=False)
+
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
             fp8_attention = self.kv_cache_dtype.startswith("fp8")

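The net effect of the hunk above is a platform gate: on ROCm, models that resolve to `Fp8Config`, or to a `QuarkConfig` carrying fp8 layer weights, fall back from the V1 engine via `_raise_or_fallback`. A rough boolean sketch of that decision, with illustrative names rather than the real `EngineArgs` machinery:

```python
from typing import Optional


def rocm_quant_allows_v1(is_rocm: bool,
                         quant_method: Optional[str],
                         quark_has_fp8_weights: bool = False) -> bool:
    # Boiled-down version of the gate above: at this point in the series,
    # FP8 checkpoints (and Quark configs with fp8 layer weights) on ROCm
    # stay on the V0 engine. All names here are illustrative.
    if not is_rocm:
        return True
    if quant_method == "fp8":
        return False
    if quant_method == "quark" and quark_has_fp8_weights:
        return False
    return True


print(rocm_quant_allows_v1(True, "fp8"))          # False -> fall back to V0
print(rocm_quant_allows_v1(True, "quark", True))  # False -> fall back to V0
print(rocm_quant_allows_v1(False, "fp8"))         # True (non-ROCm unaffected)
```
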
@@ -140,6 +140,11 @@ class Fp8Config(QuantizationConfig):
             return name.replace(".k_proj.output_scale", ".attn.k_scale")
         if name.endswith(".output_scale") and ".v_proj" in name:
             return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
+        # If no matches, return None
         return None
 
 

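The added branches extend the checkpoint-name remapping so that per-projection `output_scale` entries (and the softmax-probability scale) land on the attention layer's `q_scale`/`prob_scale` parameters. The same string logic, shown standalone with made-up layer names:

```python
from typing import Optional


def remap_scale_name(name: str) -> Optional[str]:
    # Same string logic as the Fp8Config hunk above; standalone for illustration.
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None


assert (remap_scale_name("model.layers.0.self_attn.q_proj.output_scale")
        == "model.layers.0.self_attn.attn.q_scale")
assert (remap_scale_name("model.layers.0.self_attn.prob_output_scale")
        == "model.layers.0.self_attn.attn.prob_scale")
assert remap_scale_name("model.layers.0.mlp.gate_proj.weight") is None
```
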
@@ -38,6 +38,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                                           requires_grad=False)
         layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                            requires_grad=False)
+        # Initialize P = softmax(QK^T) scales
+        layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
+                                              requires_grad=False)
 
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(

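For context, `prob_scale` covers P = softmax(QK^T / sqrt(d)), which fp8 attention kernels quantize before the P*V matmul. A rough per-tensor fp8 quantization sketch under that assumption; the helper, the e4m3fn dtype choice, and the scale value are illustrative, not the Triton kernel's actual code:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for e4m3fn


def quantize_per_tensor_fp8(x: torch.Tensor, scale: torch.Tensor):
    # y_fp8 ~= x / scale, clamped to the representable fp8 range;
    # the kernel later multiplies by `scale` again to dequantize.
    y = (x / scale).clamp(-FP8_MAX, FP8_MAX)
    return y.to(torch.float8_e4m3fn)


q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
probs = torch.softmax(q @ k.transpose(-1, -2) / 64**0.5, dim=-1)
# Softmax outputs lie in [0, 1]; a checkpoint-provided prob_scale (a made-up
# value here) keeps them well inside the fp8 dynamic range.
prob_scale = torch.tensor(1.0 / FP8_MAX)
p_fp8 = quantize_per_tensor_fp8(probs, prob_scale)
print(p_fp8.dtype)  # torch.float8_e4m3fn
```
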
@@ -97,5 +100,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 "may cause accuracy issues. Please make sure k/v_scale "
                 "scaling factors are available in the fp8 checkpoint.")
 
+        if layer.q_scale > 0.0:
+            q_scale = layer.q_scale
+            if current_platform.is_fp8_fnuz():
+                q_scale *= 2
+            layer.calculate_kv_scales = False
+        else:
+            q_scale = 1.0
+        if layer.prob_scale > 0.0:
+            prob_scale = layer.prob_scale
+            if current_platform.is_fp8_fnuz():
+                prob_scale *= 2
+        else:
+            prob_scale = 1.0
+
+        is_singleton_float = lambda x: isinstance(x, float) or isinstance(
+            x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
+        if not is_singleton_float(q_scale) or not is_singleton_float(
+                prob_scale):
+            raise ValueError("Only support per-tensor scaling factor"
+                             "for fp8-quantized Q/prob")
+
+        # These are used in the final Attention.forward()
+        layer._q_scale.copy_(q_scale)
+        layer._prob_scale.copy_(prob_scale)
+        if q_scale == 1.0 or prob_scale == 1.0:
+            logger.warning_once(
+                f"Using Q scale {q_scale} and prob scale {prob_scale} "
+                "with fp8 attention. This may cause accuracy issues. "
+                "Please make sure Q/prob scaling factors are "
+                "available in the fp8 checkpoint.")
+
         del layer.k_scale
         del layer.v_scale
+        del layer.q_scale
+        del layer.prob_scale

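A note on the `*= 2` above: checkpoint scales are computed for OCP fp8 (e4m3fn), while fnuz hardware decodes the same bit pattern to half the value (exponent bias 8 vs 7), so dequantization scales are doubled when `current_platform.is_fp8_fnuz()`. A minimal sketch of just that adjustment (helper name is hypothetical):

```python
import torch


def adjust_scale_for_platform(scale: torch.Tensor,
                              is_fp8_fnuz: bool) -> torch.Tensor:
    # Checkpoint scales target OCP e4m3fn. On fnuz hardware the same bit
    # pattern decodes to half the value, so the dequantization scale must
    # be twice as large to compensate.
    return scale * 2 if is_fp8_fnuz else scale


q_scale = torch.tensor(0.0123)
print(adjust_scale_for_platform(q_scale, is_fp8_fnuz=True))   # tensor(0.0246)
print(adjust_scale_for_platform(q_scale, is_fp8_fnuz=False))  # tensor(0.0123)
```
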
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import fnmatch
-import re
 from typing import Any, Dict, List, Optional, cast
 
 import torch

@@ -125,6 +124,13 @@ class QuarkConfig(QuantizationConfig):
             for q_config in q_configs:
                 q_config["output_tensors"] = None
 
+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            if q_proj_q_config is not None:
+                q_proj_q_config["output_tensors"] = None
+
         return cls(quant_config=config,
                    kv_cache_group=kv_cache_group,
                    kv_cache_config=kv_cache_config,

@@ -289,29 +295,30 @@ class QuarkConfig(QuantizationConfig):
         :param name: param name
         :return: matching param name for KV cache scale in vLLM
         """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
 
         # If no matches, return None
         return None
 
+    def has_fp8_layer_weights(self):
+        layer_quant_config = self.quant_config.get("layer_quant_config")
+        to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
+        return any([
+            'fp8' in cast(
+                str,
+                to_dict(
+                    to_dict(to_dict(layer_quant_config).get(layer_name)).get(
+                        "weight")).get("dtype"))
+            for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
+        ])
+
 
 class QuarkLinearMethod(LinearMethodBase):
 
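`has_fp8_layer_weights` simply probes the nested Quark `layer_quant_config` for an fp8 weight dtype on any of the q/k/v projections, treating missing dict levels as empty dicts. The same lookup run against a toy config (the config literal below is made up and merely shaped like Quark's; note the lookup assumes each projection entry carries a weight `dtype` string):

```python
from typing import Any, Dict, cast


def has_fp8_layer_weights(quant_config: Dict[str, Any]) -> bool:
    # Same nested lookup as QuarkConfig.has_fp8_layer_weights above;
    # missing dict levels collapse to {} before each .get() call.
    layer_quant_config = quant_config.get("layer_quant_config")
    to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
    return any([
        'fp8' in cast(
            str,
            to_dict(
                to_dict(to_dict(layer_quant_config).get(layer_name)).get(
                    "weight")).get("dtype"))
        for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
    ])


toy_cfg = {
    "layer_quant_config": {
        "*q_proj": {"weight": {"dtype": "fp8_e4m3"}},
        "*k_proj": {"weight": {"dtype": "fp8_e4m3"}},
        "*v_proj": {"weight": {"dtype": "fp8_e4m3"}},
    }
}
toy_cfg_int8 = {
    "layer_quant_config": {
        "*q_proj": {"weight": {"dtype": "int8"}},
        "*k_proj": {"weight": {"dtype": "int8"}},
        "*v_proj": {"weight": {"dtype": "int8"}},
    }
}
print(has_fp8_layer_weights(toy_cfg))       # True
print(has_fp8_layer_weights(toy_cfg_int8))  # False
```
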