Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 02:35:24 +08:00)
[Attention] Add missing kv cache scale setup (#27490)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent 4c5f632165
commit a99564ac5b
@@ -123,6 +123,69 @@ def maybe_get_vit_flash_attn_backend(
     return attn_backend, flash_attn_varlen_func
 
 
+def _init_kv_cache_quant(
+    layer: nn.Module,
+    quant_config: QuantizationConfig | None,
+    prefix: str,
+    kv_cache_dtype: str,
+    calculate_kv_scales: bool,
+) -> None:
+    """Initializes KV cache scaling factors and quantization method.
+
+    This helper function sets up the KV cache quantization attributes that are
+    shared between Attention and MLAAttention layers. It initializes scale
+    tensors for query, key, value, and probability, and configures the
+    quantization method if applicable.
+
+    Args:
+        layer: The attention layer instance to initialize.
+        quant_config: Optional quantization configuration.
+        prefix: Layer name prefix for quantization method lookup.
+        kv_cache_dtype: The KV cache data type string.
+        calculate_kv_scales: Whether to calculate KV scales dynamically.
+    """
+    # The default k/v_scale is set to 1.0. This is ignored
+    # when kv-cache is not fp8, and should be used with
+    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+    # expect the pre-quantized k/v_scale to be loaded along
+    # with the model weights.
+    layer.kv_cache_dtype = kv_cache_dtype
+    layer.calculate_kv_scales = calculate_kv_scales
+    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
+
+    # We also keep q/k/v_scale on host (cpu) memory for attention
+    # backends that require the scales to be on host instead of on device.
+    # e.g. Flashinfer
+    layer._q_scale_float = 1.0
+    layer._k_scale_float = 1.0
+    layer._v_scale_float = 1.0
+
+    # The output scale on host memory. This should be the input scale of
+    # the quant op after this attention layer.
+    layer._o_scale_float = None
+
+    quant_method = (
+        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
+    )
+    if quant_method is not None and not isinstance(
+        quant_method, UnquantizedLinearMethod
+    ):
+        assert isinstance(quant_method, BaseKVCacheMethod)
+        # TODO (mgoin): kv cache dtype should be specified in the FP8
+        # checkpoint config and become the "auto" behavior
+        if kv_cache_dtype == "fp8_e5m2":
+            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
+        # If quantization is enabled, we make "k_scale" and "v_scale"
+        # parameters so that it can be loaded from the model checkpoint.
+        # The k/v_scale will then be converted back to native float32
+        # values after weight loading.
+        layer.quant_method = quant_method
+        layer.quant_method.create_weights(layer)
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
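For orientation, a minimal usage sketch of the new helper on the unquantized path. DummyLayer and the argument values are illustrative only and not part of the patch; the sketch assumes _init_kv_cache_quant, torch, and nn are importable from this module.

# Illustrative sketch: DummyLayer and the argument values are hypothetical,
# chosen only to show which attributes the helper leaves behind.
import torch
from torch import nn


class DummyLayer(nn.Module):
    """Stand-in for an attention layer handed to _init_kv_cache_quant."""


layer = DummyLayer()
_init_kv_cache_quant(
    layer,
    quant_config=None,  # no quant config -> the quant_method branch is skipped
    prefix="model.layers.0.self_attn.attn",
    kv_cache_dtype="auto",
    calculate_kv_scales=False,
)

# Device-side scale tensors default to 1.0 ...
assert layer._k_scale.item() == layer._v_scale.item() == 1.0
# ... host-side float mirrors exist for backends such as FlashInfer, and the
# output scale stays unset until a downstream quant op provides one.
assert layer._k_scale_float == 1.0 and layer._o_scale_float is None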
@@ -184,30 +247,10 @@ class Attention(nn.Module, AttentionLayerBase):
                 f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
             )
 
-        # The default k/v_scale is set to 1.0. This is ignored
-        # when kv-cache is not fp8, and should be used with
-        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized k/v_scale to be loaded along
-        # with the model weights.
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        # FlashAttn doesn't support quantizing the kv-cache only
-        # but requires q to be quantized as well.
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # We also keep q/k/v_scale on host (cpu) memory for attention
-        # backends that require the scales to be on host instead of on device.
-        # e.g. Flashinfer
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-        # The output scale on host memory. This should be the input scale of
-        # the quant op after this attention layer.
-        self._o_scale_float: float | None = None
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )
 
         self.num_heads = num_heads
         self.head_size = head_size
@@ -215,26 +258,6 @@ class Attention(nn.Module, AttentionLayerBase):
         self.sliding_window = sliding_window
         self.has_sink = extra_impl_args.get("sinks") is not None
 
-        quant_method = (
-            quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
-        )
-        if quant_method is not None and not isinstance(
-            quant_method, UnquantizedLinearMethod
-        ):
-            assert isinstance(quant_method, BaseKVCacheMethod)
-            # TODO (mgoin): kv cache dtype should be specified in the FP8
-            # checkpoint config and become the "auto" behavior
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError(
-                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
-                )
-            # If quantization is enabled, we make "k_scale" and "v_scale"
-            # parameters so that it can be loaded from the model checkpoint.
-            # The k/v_scale will then be converted back to native float32
-            # values after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
-
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
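The block deleted in this hunk is the interesting path and now lives inside _init_kv_cache_quant: with an fp8 checkpoint, the quant method exposes k_scale/v_scale as loadable parameters, and after weight loading they are folded back into the plain float32 tensors the kernels read. Below is a toy stand-in sketching that contract; it is hypothetical, based only on the comments above, and is not vLLM's actual BaseKVCacheMethod.

import torch
from torch import nn


class ToyKVCacheMethod:
    """Hypothetical stand-in for a BaseKVCacheMethod-style quant method,
    illustrating the contract described in the removed comments."""

    def create_weights(self, layer: nn.Module) -> None:
        # Expose k/v scales as (frozen) parameters so the checkpoint loader
        # can overwrite them with the pre-quantized values.
        layer.k_scale = nn.Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.v_scale = nn.Parameter(torch.tensor(-1.0), requires_grad=False)

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        # Fold the loaded parameters back into the plain float32 tensors and
        # host-side mirrors that the attention kernels actually read.
        if layer.k_scale > 0:
            layer._k_scale.fill_(float(layer.k_scale))
            layer._k_scale_float = float(layer.k_scale)
        if layer.v_scale > 0:
            layer._v_scale.fill_(float(layer.v_scale))
            layer._v_scale_float = float(layer.v_scale)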
@@ -636,7 +659,11 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
-        self.kv_cache_dtype = kv_cache_dtype
+
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )
 
         dtype = torch.get_default_dtype()
         self.attn_backend = get_attn_backend(
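This hunk is the substance of the fix. Judging from the block removed in the next hunk, MLAAttention previously set only the default 1.0 scale tensors and never resolved a KV-cache quant method, so pre-quantized k/v scales in fp8 checkpoints were never wired up for MLA layers; routing the constructor through _init_kv_cache_quant gives MLA the same setup as Attention. A hedged end-to-end example of the configuration this affects; the model name is illustrative, and whether scales are actually loaded depends on the checkpoint itself.

from vllm import LLM

# kv_cache_dtype="fp8" requests an fp8 (e4m3) KV cache. With this fix, an
# MLA model's k/v scales can be picked up from the checkpoint instead of
# always falling back to the 1.0 defaults.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative MLA-based model
    kv_cache_dtype="fp8",
)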
@@ -685,20 +712,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             )
         ]
 
-        # Align with Attention's scale attributes for MLA backends.
-
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # Host-side mirrors used by some attention backends
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-        self._o_scale_float: float | None = None
-
         self.use_sparse = use_sparse
 
         # Initialize q/k/v range constants.
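Since both Attention and MLAAttention now run through the same helper, the scale attributes can no longer drift apart. A small hypothetical sanity check over the attribute names the helper sets; nothing like this is added by the patch itself.

# Hypothetical check, not part of the patch; attribute names are taken from
# _init_kv_cache_quant above.
EXPECTED_SCALE_ATTRS = (
    "_k_scale", "_v_scale", "_q_scale", "_prob_scale",
    "_q_scale_float", "_k_scale_float", "_v_scale_float", "_o_scale_float",
)


def has_kv_cache_scale_attrs(layer) -> bool:
    """Return True if the layer exposes the full KV-cache scale surface."""
    return all(hasattr(layer, name) for name in EXPECTED_SCALE_ATTRS)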