# vllm/vllm/attention/layer.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from collections.abc import Callable
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
direct_register_custom_op,
kv_cache_dtype_str_to_dtype,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowSpec,
)
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
def check_xformers_availability():
global USE_XFORMERS_OPS
if USE_XFORMERS_OPS is not None:
return USE_XFORMERS_OPS
if current_platform.is_cuda() and current_platform.has_device_capability(100):
# Xformers FA is not compatible with B200
USE_XFORMERS_OPS = False
else:
try:
from importlib.util import find_spec
find_spec("xformers.ops")
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
# the warning only needs to be shown once
if not USE_XFORMERS_OPS:
logger.warning("Xformers is not available, falling back.")
return USE_XFORMERS_OPS
def check_upstream_fa_availability(dtype: torch.dtype):
if (
dtype in (torch.float16, torch.bfloat16)
and current_platform.is_cuda()
and current_platform.has_device_capability(80)
):
from transformers.utils import is_flash_attn_2_available
return is_flash_attn_2_available()
if current_platform.is_rocm():
from importlib.util import find_spec
return find_spec("flash_attn") is not None
return False
def maybe_get_vit_flash_attn_backend(
attn_backend: _Backend,
use_upstream_fa: bool,
attn_backend_override: _Backend | None = None,
) -> tuple[_Backend, Callable]:
if (
attn_backend != _Backend.FLASH_ATTN
and attn_backend != _Backend.ROCM_AITER_FA
and check_upstream_fa_availability(torch.get_default_dtype())
and attn_backend_override is None
):
attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
use_upstream_fa = True
if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
if attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
return attn_backend, flash_attn_varlen_func
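

# Illustrative sketch (not part of the module API): a ViT layer that starts from
# TORCH_SDPA can be upgraded to a flash-attention varlen kernel when one is
# available; the returned callable is None for non-flash backends.
#
#   backend, fa_func = maybe_get_vit_flash_attn_backend(
#       _Backend.TORCH_SDPA, use_upstream_fa=False
#   )
#   if fa_func is not None:
#       ...  # call fa_func(q, k, v, cu_seqlens_q=..., cu_seqlens_k=..., ...)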
def _init_kv_cache_quant(
layer: nn.Module,
quant_config: QuantizationConfig | None,
prefix: str,
kv_cache_dtype: str,
calculate_kv_scales: bool,
) -> None:
"""Initializes KV cache scaling factors and quantization method.
This helper function sets up the KV cache quantization attributes that are
shared between Attention and MLAAttention layers. It initializes scale
tensors for query, key, value, and probability, and configures the
quantization method if applicable.
Args:
layer: The attention layer instance to initialize.
quant_config: Optional quantization configuration.
prefix: Layer name prefix for quantization method lookup.
kv_cache_dtype: The KV cache data type string.
calculate_kv_scales: Whether to calculate KV scales dynamically.
"""
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
layer.kv_cache_dtype = kv_cache_dtype
layer.calculate_kv_scales = calculate_kv_scales
layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
layer._q_scale_float = 1.0
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
layer._o_scale_float = None
quant_method = (
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
)
if quant_method is not None and not isinstance(
quant_method, UnquantizedLinearMethod
):
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
layer.quant_method = quant_method
layer.quant_method.create_weights(layer)
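

# Illustrative summary of the state _init_kv_cache_quant leaves on a layer
# (a sketch of the defaults set above; checkpoint-loaded scales or a later
# calc_kv_scales call may overwrite them):
#
#   layer._q_scale == layer._k_scale == layer._v_scale == tensor(1.0)
#   layer._prob_scale == tensor(1.0)
#   layer._q_scale_float == layer._k_scale_float == layer._v_scale_float == 1.0
#   layer._o_scale_float is None
#   layer.quant_method is set only when quant_config yields a BaseKVCacheMethod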
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
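
    Illustrative usage (a minimal sketch; real layers are constructed by the
    model definitions with a CacheConfig, a QuantizationConfig, and a unique
    ``prefix``, and the KV cache is bound by the engine before the first
    forward pass):

        attn = Attention(num_heads=32, head_size=128, scale=128**-0.5,
                         num_kv_heads=8, prefix="model.layers.0.self_attn.attn")
        out = attn(q, k, v)  # q: [num_tokens, 32 * 128], k/v: [num_tokens, 8 * 128]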
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
logits_soft_cap: float | None = None,
per_layer_sliding_window: int | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
**extra_impl_args,
) -> None:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
"""
super().__init__()
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
elif cache_config is not None:
# model-level sliding window
sliding_window = cache_config.sliding_window
else:
sliding_window = None
vllm_config = get_current_vllm_config()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, (
f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
)
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
if attn_backend is None:
self.attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=False,
has_sink=self.has_sink,
)
else:
self.attn_backend = attn_backend
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**extra_impl_args,
)
self.backend = backend_name_to_enum(self.attn_backend.get_name())
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
if kv_sharing_target_layer_name is not None:
validate_kv_sharing_target(
prefix,
kv_sharing_target_layer_name,
compilation_config.static_forward_context,
)
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        # Use a placeholder KV cache tensor during init; it is replaced by
        # bind_kv_cache once the engine allocates the real KV cache tensors.
self.kv_cache = [
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
# for attn backends supporting query quantization
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input()
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: torch.Size | None = None,
) -> torch.Tensor:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
Attention metadata (`attn_metadata`) is set using a context manager in
the model runner's `execute_model` method. It is accessed via forward
context using
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
# torch.compile to fuse this into previous ops
# which reduces overheads during decoding.
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
# check if query quantization is supported
if self.impl.supports_quant_query_input():
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata, output=output
)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name
)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata
)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name
)
def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
self.calculate_kv_scales = False
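        # A rough sketch of how these per-tensor scales are used downstream
        # (assuming a per-tensor FP8 KV cache / query quantization scheme):
        # a tensor x is stored approximately as
        #     x_fp8 = clamp(x / x_scale, fp8_min, fp8_max)
        # so x_scale = max|x| / RANGE maps the observed dynamic range of x into
        # the representable FP8 range, where RANGE is the corresponding
        # Q/K/V_SCALE_CONSTANT environment value (leaving configurable headroom).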
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}"
return s
def process_weights_after_loading(self, act_dtype: torch.dtype):
self.impl.process_weights_after_loading(act_dtype)
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported with sliding window"
            )
return SlidingWindowSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
sliding_window=self.sliding_window,
)
else:
return FullAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
)
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
# Determine the attention backend
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
# Some auto-selected backends can be upgraded
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False
if current_platform.is_xpu():
# currently, only torch_sdpa is supported on xpu
self.attn_backend = _Backend.TORCH_SDPA
else:
self.attn_backend = (
backend
if backend
in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.PALLAS,
_Backend.ROCM_AITER_FA,
_Backend.FLASH_ATTN,
}
else _Backend.TORCH_SDPA
)
self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
self.attn_backend = _Backend.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
}
        # This assignment only ensures that the use_upstream_fa value reported
        # in the log message below is accurate.
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
use_upstream_fa = True
logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
f"use_upstream_fa: {use_upstream_fa}"
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
if self.is_flash_attn_backend:
cu_seqlens_q = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
)
cu_seqlens_k = torch.arange(
0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
)
out = self._flash_attn_varlen_func(
query.flatten(0, 1),
key.flatten(0, 1),
value.flatten(0, 1),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale
)
elif self.attn_backend == _Backend.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
else:
            # ViT attention does not support this backend yet
            raise NotImplementedError(
                f"ViT attention does not support the {self.attn_backend} backend yet."
            )
return out.reshape(bsz, q_len, -1)
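

# Illustrative usage of MultiHeadAttention (a minimal sketch; ViT encoders
# construct one instance per block and feed it already-projected q/k/v):
#
#   mha = MultiHeadAttention(num_heads=16, head_size=64, scale=64**-0.5)
#   # query/key/value: [batch, seq_len, 16 * 64] (or [batch, seq_len, 16, 64])
#   ctx = mha(query, key, value)  # -> [batch, seq_len, 16 * 64]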
class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.
This class takes query, and compressed key/value tensors as input.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
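
    The KV cache holds one compressed entry per token: the normalized latent of
    width ``kv_lora_rank`` together with the decoupled RoPE key of width
    ``qk_rope_head_dim``, so the per-token cache width is
    ``kv_lora_rank + qk_rope_head_dim`` (``self.head_size``).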
"""
def __init__(
self,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
kv_b_proj: ColumnParallelLinear,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_sparse: bool = False,
indexer: object | None = None,
**extra_impl_args,
):
super().__init__()
self.num_heads = num_heads
self.scale = scale
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.head_size = kv_lora_rank + qk_rope_head_dim
self.layer_name = prefix
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
)
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
)
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
num_heads=self.num_heads,
head_size=self.head_size,
scale=self.scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=self.kv_cache_dtype,
logits_soft_cap=None,
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=None,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=kv_b_proj,
indexer=indexer,
**extra_impl_args,
)
self.use_direct_call = not current_platform.opaque_attention_op()
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
self.use_sparse = use_sparse
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# Mirror Attention.forward scale calculation path
if self.calculate_kv_scales and getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
self.calc_kv_scales(q, kv_c_normed, k_pe)
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
output=output,
)
return output
else:
return self.impl.forward(
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
)
else:
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
)
return output
else:
                # We can still access the forward context to check the
                # scale-calculation flag.
if self.calculate_kv_scales:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
if getattr(attn_metadata, "enable_kv_scales_calculation", False):
self.calc_kv_scales(q, kv_c_normed, k_pe)
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
k_pe,
self.layer_name,
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
def calc_kv_scales(
self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
) -> None:
"""Optional scale calculation for MLA inputs.
Mirrors Attention.calc_kv_scales. Not all MLA backends require this
"""
# Use safe defaults if ranges are not present
q_range = getattr(self, "q_range", torch.tensor(1.0))
k_range = getattr(self, "k_range", torch.tensor(1.0))
v_range = getattr(self, "v_range", torch.tensor(1.0))
self._q_scale.copy_(torch.abs(q).max() / q_range)
# kv_c_normed is the compressed KV representation; use it for k/v
kv_abs_max = torch.abs(kv_c_normed).max()
self._k_scale.copy_(kv_abs_max / k_range)
self._v_scale.copy_(kv_abs_max / v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
self.calculate_kv_scales = False
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
kv_cache_dtype = kv_cache_dtype_str_to_dtype(
self.kv_cache_dtype, vllm_config.model_config
)
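        # MLA caches a single shared latent per token rather than per-head K/V,
        # hence num_kv_heads=1 in the spec below.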
return MLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype,
)
def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: list[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
if attn_metadata is None or not getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
return
self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
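

# Each direct_register_custom_op call below pairs a real implementation with a
# "fake" one: the fake implementation runs under torch.compile's tracing and
# only needs to produce outputs of the right shape/dtype (or nothing, for ops
# that write into a preallocated output), while the real implementation runs
# at execution time.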
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=["query", "key", "value"],
fake_impl=maybe_calc_kv_scales_fake,
)
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
fake_impl=unified_attention_fake,
)
def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
)
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
def unified_mla_attention_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
direct_register_custom_op(
op_name="unified_mla_attention",
op_func=unified_mla_attention,
mutates_args=[],
fake_impl=unified_mla_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
def unified_mla_attention_with_output(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_mla_attention_with_output_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_mla_attention_with_output",
op_func=unified_mla_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_mla_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)