"""Attention layer."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can contain either prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

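    # Illustrative usage sketch (head counts, scale, prefix, and config
    # objects below are placeholder assumptions, not taken from a specific
    # model): a decoder layer typically constructs this module once and then
    # calls it each step with flattened token tensors.
    #
    #   self.attn = Attention(num_heads=32,
    #                         head_size=128,
    #                         scale=128**-0.5,
    #                         num_kv_heads=8,
    #                         cache_config=cache_config,
    #                         quant_config=quant_config,
    #                         prefix=f"{prefix}.attn")
    #   ...
    #   attn_output = self.attn(q, k, v, kv_cache, attn_metadata)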

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self._k_scale = 1.0
        self._v_scale = 1.0
        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that they can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
                                        block_size, is_attention_free,
                                        blocksparse_params is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap, attn_type)
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.backend = backend_name_to_enum(attn_backend.get_name())

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        # For some attention backends, we allocate an output tensor before
        # calling the custom op. When piecewise cudagraph is enabled, this
        # makes sure the output tensor is allocated inside the cudagraph.
        self.use_output = self.backend == _Backend.FLASH_ATTN or \
            self.backend == _Backend.FLASH_ATTN_VLLM_V1
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type
        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:

        if self.use_direct_call:
            return self.impl.forward(query, key, value, kv_cache,
                                     attn_metadata, self._k_scale,
                                     self._v_scale)
        elif self.use_output:
            output = torch.empty_like(query)
            hidden_size = query.size(-1)
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            return torch.ops.vllm.unified_attention(query, key, value,
                                                    self.layer_name)

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s


class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

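    # Illustrative usage sketch (the head count, head size, and tensor shapes
    # below are placeholder assumptions): a vision-encoder block calls this
    # layer with [batch_size, seq_len, hidden_size] tensors and no KV cache
    # or attention metadata.
    #
    #   self.attn = MultiHeadAttention(num_heads=16,
    #                                  head_size=64,
    #                                  scale=64**-0.5)
    #   out = self.attn(q, k, v)  # q/k/v: [batch_size, seq_len, 16 * 64]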

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        attn_backend = backend_name_to_enum(attn_backend.get_name())
        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
            attn_backend = _Backend.XFORMERS

        self.attn_backend = attn_backend if attn_backend in {
            _Backend.TORCH_SDPA, _Backend.XFORMERS
        } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size"""
        # TODO(Isotr0py): Use existing backend implementations and support FA2
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        return out.view(bsz, q_len, -1)


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    self = forward_context.attn_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                             self._k_scale, self._v_scale)


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    return torch.empty_like(query).contiguous()


direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=[],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)
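
# Once registered, the eager implementation above is reached through
# torch.ops.vllm.unified_attention(query, key, value, layer_name) (see
# Attention.forward); unified_attention_fake only supplies the output
# shape/dtype that torch.compile needs when tracing the op as an opaque node.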


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    self = forward_context.attn_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      self._k_scale,
                      self._v_scale,
                      output=output)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    return


direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)