[Attention] Add missing kv cache scale setup (#27490)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Parent: 4c5f632165
Commit: a99564ac5b
@@ -123,6 +123,69 @@ def maybe_get_vit_flash_attn_backend(
     return attn_backend, flash_attn_varlen_func


+def _init_kv_cache_quant(
+    layer: nn.Module,
+    quant_config: QuantizationConfig | None,
+    prefix: str,
+    kv_cache_dtype: str,
+    calculate_kv_scales: bool,
+) -> None:
+    """Initializes KV cache scaling factors and quantization method.
+
+    This helper function sets up the KV cache quantization attributes that are
+    shared between Attention and MLAAttention layers. It initializes scale
+    tensors for query, key, value, and probability, and configures the
+    quantization method if applicable.
+
+    Args:
+        layer: The attention layer instance to initialize.
+        quant_config: Optional quantization configuration.
+        prefix: Layer name prefix for quantization method lookup.
+        kv_cache_dtype: The KV cache data type string.
+        calculate_kv_scales: Whether to calculate KV scales dynamically.
+    """
+    # The default k/v_scale is set to 1.0. This is ignored
+    # when kv-cache is not fp8, and should be used with
+    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+    # expect the pre-quantized k/v_scale to be loaded along
+    # with the model weights.
+    layer.kv_cache_dtype = kv_cache_dtype
+    layer.calculate_kv_scales = calculate_kv_scales
+    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
+
+    # We also keep q/k/v_scale on host (cpu) memory for attention
+    # backends that require the scales to be on host instead of on device.
+    # e.g. Flashinfer
+    layer._q_scale_float = 1.0
+    layer._k_scale_float = 1.0
+    layer._v_scale_float = 1.0
+
+    # The output scale on host memory. This should be the input scale of
+    # the quant op after this attention layer.
+    layer._o_scale_float = None
+
+    quant_method = (
+        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
+    )
+    if quant_method is not None and not isinstance(
+        quant_method, UnquantizedLinearMethod
+    ):
+        assert isinstance(quant_method, BaseKVCacheMethod)
+        # TODO (mgoin): kv cache dtype should be specified in the FP8
+        # checkpoint config and become the "auto" behavior
+        if kv_cache_dtype == "fp8_e5m2":
+            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
+        # If quantization is enabled, we make "k_scale" and "v_scale"
+        # parameters so that it can be loaded from the model checkpoint.
+        # The k/v_scale will then be converted back to native float32
+        # values after weight loading.
+        layer.quant_method = quant_method
+        layer.quant_method.create_weights(layer)
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.

@@ -184,30 +247,10 @@ class Attention(nn.Module, AttentionLayerBase):
                 f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
             )

-        # The default k/v_scale is set to 1.0. This is ignored
-        # when kv-cache is not fp8, and should be used with
-        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized k/v_scale to be loaded along
-        # with the model weights.
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        # FlashAttn doesn't support quantizing the kv-cache only
-        # but requires q to be quantized as well.
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # We also keep q/k/v_scale on host (cpu) memory for attention
-        # backends that require the scales to be on host instead of on device.
-        # e.g. Flashinfer
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-        # The output scale on host memory. This should be the input scale of
-        # the quant op after this attention layer.
-        self._o_scale_float: float | None = None
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )

         self.num_heads = num_heads
         self.head_size = head_size
@@ -215,26 +258,6 @@ class Attention(nn.Module, AttentionLayerBase):
         self.sliding_window = sliding_window
         self.has_sink = extra_impl_args.get("sinks") is not None

-        quant_method = (
-            quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
-        )
-        if quant_method is not None and not isinstance(
-            quant_method, UnquantizedLinearMethod
-        ):
-            assert isinstance(quant_method, BaseKVCacheMethod)
-            # TODO (mgoin): kv cache dtype should be specified in the FP8
-            # checkpoint config and become the "auto" behavior
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError(
-                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
-                )
-            # If quantization is enabled, we make "k_scale" and "v_scale"
-            # parameters so that it can be loaded from the model checkpoint.
-            # The k/v_scale will then be converted back to native float32
-            # values after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
-
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
@@ -636,7 +659,11 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
-        self.kv_cache_dtype = kv_cache_dtype
+
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )

         dtype = torch.get_default_dtype()
         self.attn_backend = get_attn_backend(
@@ -685,20 +712,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             )
         ]

-        # Align with Attention's scale attributes for MLA backends.
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # Host-side mirrors used by some attention backends
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-        self._o_scale_float: float | None = None
-
         self.use_sparse = use_sparse

         # Initialize q/k/v range constants.
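
For reference, here is a minimal usage sketch of the pattern this commit consolidates: a layer passes itself to _init_kv_cache_quant and receives the default unit scales and host-side mirrors. This sketch is not part of the commit; the import path (the changed file is not named in this diff), the DummyAttention class, the prefix string, and the "auto" kv-cache dtype are assumptions made purely for illustration.

# Usage sketch under the assumptions stated above.
import torch
import torch.nn as nn

from vllm.attention.layer import _init_kv_cache_quant  # assumed module path


class DummyAttention(nn.Module):
    """Hypothetical stand-in layer showing the attributes the helper attaches."""

    def __init__(self, prefix: str = "model.layers.0.self_attn.attn"):
        super().__init__()
        # With quant_config=None the helper only installs the default unit
        # scale tensors and their host-side float mirrors; no quant_method
        # is attached to the layer.
        _init_kv_cache_quant(
            layer=self,
            quant_config=None,
            prefix=prefix,
            kv_cache_dtype="auto",
            calculate_kv_scales=False,
        )


layer = DummyAttention()
assert layer._k_scale_float == layer._v_scale_float == 1.0
assert layer._o_scale_float is None  # later set to the following quant op's input scale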