Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-22 20:15:46 +08:00)

Separate MLAAttention class from Attention (#25103)

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

Parent: 2a03f93de9
Commit: e614ab7806
@@ -6,6 +6,7 @@ from typing import Generic, Optional, Protocol, TypeVar
 
 import torch
 
+from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
 
 
@@ -184,6 +185,31 @@ class AttentionImpl(ABC, Generic[T]):
 
 
 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
+    @abstractmethod
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[list[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        logits_soft_cap: Optional[float],
+        attn_type: str,
+        kv_sharing_target_layer_name: Optional[str],
+        # MLA Specific Arguments
+        q_lora_rank: Optional[int],
+        kv_lora_rank: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        qk_head_dim: int,
+        v_head_dim: int,
+        kv_b_proj: ColumnParallelLinear,
+        indexer: Optional[object] = None,
+    ) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def forward(
         self,
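Note: as a reading aid, here is a minimal self-contained sketch of the interface pattern this hunk introduces. All class and argument names below are stand-ins, not vLLM code; the real definition is the MLAAttentionImpl shown above, which makes the constructor itself abstract so every MLA backend must accept the MLA-specific dimensions.

from abc import ABC, abstractmethod
from typing import Any, Optional


class AttentionImplSketch(ABC):
    """Stand-in for the generic per-backend attention implementation."""

    @abstractmethod
    def forward(self, layer: Any, q, k, v, kv_cache, attn_metadata):
        ...


class MLAAttentionImplSketch(AttentionImplSketch):
    """Stand-in mirroring MLAAttentionImpl: the abstract __init__ pins down
    the MLA-specific constructor arguments for every backend."""

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,          # kv_lora_rank + qk_rope_head_dim (see MLAAttention below)
        scale: float,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        kv_b_proj: Any,          # projection used to expand the compressed latent KV
    ) -> None:
        raise NotImplementedError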
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""
 
-from typing import Callable, Optional
+from typing import Callable, Optional, cast
 
 import torch
 import torch.nn as nn
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 
 import vllm.envs as envs
 from vllm.attention import AttentionType
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.attention.backends.registry import _Backend, backend_name_to_enum
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
@@ -23,7 +23,10 @@ from vllm.distributed.kv_transfer import (
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    UnquantizedLinearMethod,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -131,8 +134,6 @@ class Attention(nn.Module, AttentionLayerBase):
         quant_config: Optional[QuantizationConfig] = None,
         logits_soft_cap: Optional[float] = None,
         per_layer_sliding_window: Optional[int] = None,
-        use_mla: bool = False,
-        use_sparse: bool = False,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
@@ -192,8 +193,6 @@ class Attention(nn.Module, AttentionLayerBase):
         # the quant op after this attention layer.
         self._o_scale_float: Optional[float] = None
 
-        self.use_mla = use_mla
-        self.use_sparse = use_sparse
         self.num_heads = num_heads
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
@@ -229,9 +228,8 @@ class Attention(nn.Module, AttentionLayerBase):
                 dtype,
                 kv_cache_dtype,
                 block_size,
-                use_mla=use_mla,
+                use_mla=False,
                 has_sink=self.has_sink,
-                use_sparse=use_sparse,
             )
         else:
             self.attn_backend = attn_backend
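Note: with use_mla and use_sparse removed from Attention, backend selection splits by layer class. Attention always asks the selector for a non-MLA backend (use_mla=False above), while the new MLAAttention added later in this diff asks with use_mla=True and is the only caller that still forwards use_sparse. A toy sketch of that split; the function and its return labels are made up for illustration and do not stand for vLLM's selector logic:

def pick_backend_family(use_mla: bool, use_sparse: bool = False) -> str:
    # Illustrative only: get_attn_backend receives these flags per the
    # calls shown in this diff, but its real logic is backend-specific.
    if use_mla:
        return "sparse-mla" if use_sparse else "mla"
    return "standard"


assert pick_backend_family(use_mla=False) == "standard"                    # Attention
assert pick_backend_family(use_mla=True, use_sparse=True) == "sparse-mla"  # MLAAttention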
@@ -349,19 +347,15 @@ class Attention(nn.Module, AttentionLayerBase):
         output_shape = output_shape if output_shape is not None else query.shape
         output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
         hidden_size = output_shape[-1]
-        # We skip reshaping query, key and value tensors for the MLA
-        # backend since these tensors have different semantics and are
-        # processed differently.
-        if not self.use_mla:
-            # Reshape the query, key, and value tensors.
-            # NOTE(woosuk): We do this outside the custom op to minimize the
-            # CPU overheads from the non-CUDA-graph regions.
-            query = query.view(-1, self.num_heads, self.head_size)
-            output = output.view(-1, self.num_heads, self.head_size)
-            if key is not None:
-                key = key.view(-1, self.num_kv_heads, self.head_size)
-            if value is not None:
-                value = value.view(-1, self.num_kv_heads, self.head_size)
+        # Reshape the query, key, and value tensors.
+        # NOTE(woosuk): We do this outside the custom op to minimize the
+        # CPU overheads from the non-CUDA-graph regions.
+        query = query.view(-1, self.num_heads, self.head_size)
+        output = output.view(-1, self.num_heads, self.head_size)
+        if key is not None:
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+        if value is not None:
+            value = value.view(-1, self.num_kv_heads, self.head_size)
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
             attn_metadata = forward_context.attn_metadata
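Note: since MLA no longer flows through Attention.forward, the query/key/value reshape is unconditional again. A tiny runnable illustration of the view it performs, with arbitrary sizes:

import torch

num_tokens, num_heads, head_size = 8, 4, 64
query = torch.randn(num_tokens, num_heads * head_size)
# Same .view(-1, num_heads, head_size) reshape as Attention.forward above.
query = query.view(-1, num_heads, head_size)
assert query.shape == (num_tokens, num_heads, head_size)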
@@ -570,6 +564,218 @@ class MultiHeadAttention(nn.Module):
         return out.reshape(bsz, q_len, -1)
 
 
+class MLAAttention(nn.Module, AttentionLayerBase):
+    """Multi-Head Latent Attention layer.
+
+    This class takes query, and compressed key/value tensors as input.
+    The class does the following:
+
+    1. Store the input key and value tensors in the KV cache.
+    2. Perform (multi-head/multi-query/grouped-query) attention.
+    3. Return the output tensor.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        scale: float,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: Optional[int],
+        kv_lora_rank: int,
+        kv_b_proj: ColumnParallelLinear,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_sparse: bool = False,
+        indexer: Optional[object] = None,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.scale = scale
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.head_size = kv_lora_rank + qk_rope_head_dim
+        self.layer_name = prefix
+
+        if cache_config is not None:
+            kv_cache_dtype = cache_config.cache_dtype
+            block_size = cache_config.block_size
+            calculate_kv_scales = cache_config.calculate_kv_scales
+        else:
+            kv_cache_dtype = "auto"
+            block_size = 16
+            calculate_kv_scales = False
+        self.kv_cache_dtype = kv_cache_dtype
+
+        dtype = torch.get_default_dtype()
+        self.attn_backend = get_attn_backend(
+            self.head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            use_mla=True,
+            use_sparse=use_sparse,
+        )
+        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
+        self.impl = impl_cls(
+            num_heads=self.num_heads,
+            head_size=self.head_size,
+            scale=self.scale,
+            num_kv_heads=1,
+            alibi_slopes=None,
+            sliding_window=None,
+            kv_cache_dtype=self.kv_cache_dtype,
+            logits_soft_cap=None,
+            attn_type=AttentionType.DECODER,
+            kv_sharing_target_layer_name=None,
+            # MLA Args
+            q_lora_rank=self.q_lora_rank,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
+            v_head_dim=self.v_head_dim,
+            kv_b_proj=kv_b_proj,
+            indexer=indexer,
+        )
+
+        self.use_direct_call = not current_platform.opaque_attention_op()
+
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
+
+        self.kv_cache = [
+            torch.tensor([])
+            for _ in range(
+                get_current_vllm_config().parallel_config.pipeline_parallel_size
+            )
+        ]
+
+        # Align with Attention's scale attributes for MLA backends.
+
+        self.calculate_kv_scales = calculate_kv_scales
+        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
+
+        # Host-side mirrors used by some attention backends
+        self._q_scale_float = 1.0
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0
+        self._o_scale_float: Optional[float] = None
+
+        self.use_sparse = use_sparse
+
+        # Initialize q/k/v range constants.
+        try:
+            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
+            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
+            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
+        except torch.cuda.OutOfMemoryError:
+            # Keep defaults if allocation fails; not critical for init.
+            pass
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        output_shape: Optional[torch.Size] = None,
+    ) -> torch.Tensor:
+        if self.use_direct_call:
+            forward_context: ForwardContext = get_forward_context()
+            attn_metadata = forward_context.attn_metadata
+            if isinstance(attn_metadata, dict):
+                attn_metadata = attn_metadata[self.layer_name]
+            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+
+            # Mirror Attention.forward scale calculation path
+            if self.calculate_kv_scales and getattr(
+                attn_metadata, "enable_kv_scales_calculation", False
+            ):
+                self.calc_kv_scales(q, kv_c_normed, k_pe)
+
+            if self.attn_backend.accept_output_buffer:
+                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                self.impl.forward(
+                    self,
+                    q,
+                    kv_c_normed,
+                    k_pe,
+                    self_kv_cache,
+                    attn_metadata,
+                    output=output,
+                )
+                return output
+            else:
+                return self.impl.forward(
+                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
+                )
+        else:
+            if self.attn_backend.accept_output_buffer:
+                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                torch.ops.vllm.unified_mla_attention_with_output(
+                    q,
+                    kv_c_normed,
+                    k_pe,
+                    output,
+                    self.layer_name,
+                )
+                return output
+            else:
+                # We can still access forward context to check calculation flag
+                if self.calculate_kv_scales:
+                    forward_context = get_forward_context()
+                    attn_metadata = forward_context.attn_metadata
+                    if isinstance(attn_metadata, dict):
+                        attn_metadata = attn_metadata[self.layer_name]
+                    if getattr(attn_metadata, "enable_kv_scales_calculation", False):
+                        self.calc_kv_scales(q, kv_c_normed, k_pe)
+                return torch.ops.vllm.unified_mla_attention(
+                    q,
+                    kv_c_normed,
+                    k_pe,
+                    self.layer_name,
+                )
+
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        if hasattr(self.impl, "process_weights_after_loading"):
+            self.impl.process_weights_after_loading(act_dtype)
+
+    def calc_kv_scales(
+        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
+    ) -> None:
+        """Optional scale calculation for MLA inputs.
+
+        Mirrors Attention.calc_kv_scales. Not all MLA backends require this
+        """
+        # Use safe defaults if ranges are not present
+        q_range = getattr(self, "q_range", torch.tensor(1.0))
+        k_range = getattr(self, "k_range", torch.tensor(1.0))
+        v_range = getattr(self, "v_range", torch.tensor(1.0))
+
+        self._q_scale.copy_(torch.abs(q).max() / q_range)
+        # kv_c_normed is the compressed KV representation; use it for k/v
+        kv_abs_max = torch.abs(kv_c_normed).max()
+        self._k_scale.copy_(kv_abs_max / k_range)
+        self._v_scale.copy_(kv_abs_max / v_range)
+        self._q_scale_float = self._q_scale.item()
+        self._k_scale_float = self._k_scale.item()
+        self._v_scale_float = self._v_scale.item()
+        self.calculate_kv_scales = False
+
+    def get_attn_backend(self) -> type[AttentionBackend]:
+        return self.attn_backend
+
+
 def wait_for_kv_layer_from_connector(layer_name: str):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
         return
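Note: a worked example of the dimensions MLAAttention now derives internally instead of receiving from callers. The formulas come from the __init__ above; the concrete numbers are assumed DeepSeek-V3-style values, used only for illustration:

# What gets cached per token is the compressed latent plus the decoupled
# rotary part, so the KV-cache "head size" is their sum, and num_kv_heads
# is fixed at 1 (see impl_cls(...) above).
kv_lora_rank = 512       # assumed size of the compressed latent kv_c
qk_rope_head_dim = 64    # assumed size of the rotary part k_pe
qk_nope_head_dim = 128   # assumed
head_size = kv_lora_rank + qk_rope_head_dim
assert head_size == 576

# Per-head query/key width handed to the backend implementation:
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
assert qk_head_dim == 192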
@@ -726,3 +932,93 @@ direct_register_custom_op(
     fake_impl=unified_attention_with_output_fake,
     tags=tag_cudagraph_unsafe,
 )
+
+
+def unified_mla_attention(
+    q: torch.Tensor,
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    layer_name: str,
+) -> torch.Tensor:
+    wait_for_kv_layer_from_connector(layer_name)
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+    self: MLAAttention = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
+
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    return output
+
+
+def unified_mla_attention_fake(
+    q: torch.Tensor,
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    layer_name: str,
+) -> torch.Tensor:
+    return torch.empty_like(q).contiguous()
+
+
+direct_register_custom_op(
+    op_name="unified_mla_attention",
+    op_func=unified_mla_attention,
+    mutates_args=[],
+    fake_impl=unified_mla_attention_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
+
+def unified_mla_attention_with_output(
+    q: torch.Tensor,
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    output_scale: Optional[torch.Tensor] = None,
+    output_block_scale: Optional[torch.Tensor] = None,
+) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+    self: MLAAttention = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    self.impl.forward(
+        self,
+        q,
+        kv_c_normed,
+        k_pe,
+        kv_cache,
+        attn_metadata,
+        output=output,
+        output_scale=output_scale,
+        output_block_scale=output_block_scale,
+    )
+
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+
+
+def unified_mla_attention_with_output_fake(
+    q: torch.Tensor,
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    output_scale: Optional[torch.Tensor] = None,
+    output_block_scale: Optional[torch.Tensor] = None,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_mla_attention_with_output",
+    op_func=unified_mla_attention_with_output,
+    mutates_args=["output", "output_block_scale"],
+    fake_impl=unified_mla_attention_with_output_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
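Note: the unified_mla_attention ops above take only tensors plus a layer_name string, and the stateful layer is recovered from the forward context (no_compile_layers). A minimal stand-in of that indirection using a plain dict registry; the class and layer name here are hypothetical, not vLLM's ForwardContext machinery:

import torch

_LAYER_REGISTRY: dict[str, "MiniMLALayer"] = {}


class MiniMLALayer:
    """Hypothetical stand-in for an MLAAttention-like module."""

    def __init__(self, name: str, scale: float) -> None:
        self.scale = scale
        _LAYER_REGISTRY[name] = self  # mirrors registration at construction time

    def run(self, q: torch.Tensor) -> torch.Tensor:
        return q * self.scale


def unified_mla_attention_sketch(q: torch.Tensor, layer_name: str) -> torch.Tensor:
    # The op body needs only tensors and a string; the stateful layer is
    # looked up at call time rather than being captured in the traced graph.
    layer = _LAYER_REGISTRY[layer_name]
    return layer.run(q)


layer = MiniMLALayer("model.layers.0.self_attn.attn", scale=0.5)
out = unified_mla_attention_sketch(torch.ones(2, 4), "model.layers.0.self_attn.attn")
assert torch.allclose(out, torch.full((2, 4), 0.5))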
@@ -380,6 +380,8 @@ class CompilationConfig:
     _attention_ops: ClassVar[list[str]] = [
         "vllm.unified_attention",
         "vllm.unified_attention_with_output",
+        "vllm.unified_mla_attention",
+        "vllm.unified_mla_attention_with_output",
         "vllm.mamba_mixer2",
         "vllm.mamba_mixer",
         "vllm.short_conv",
@@ -5,7 +5,7 @@ from typing import Optional
 
 import torch
 
-from vllm.attention import Attention
+from vllm.attention.layer import MLAAttention
 from vllm.config import CacheConfig
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -30,8 +30,9 @@ class MLAModules:
 
 
 @CustomOp.register("multi_head_latent_attention")
-class MultiHeadLatentAttention(CustomOp):
-    """MLA layer registered as CustomOp.
+class MultiHeadLatentAttentionWrapper(CustomOp):
+    """MLA layer registered as CustomOp to allow OOT backends to add
+    custom implementations of the outer MLA layer (including rope & o_proj).
     Note that currently MLA ignores the enable/disable mechanism of CustomOp
     because there is only one in-tree implementation in forward_native.
     TODO: implement this with a new PluggableLayer mechanism.
@@ -87,30 +88,19 @@ class MultiHeadLatentAttention(CustomOp):
             self.topk_tokens = self.indexer.topk_tokens
             self.topk_indices_buffer = mla_modules.topk_indices_buffer
 
-        # In the MLA backend, kv_cache includes both k_c and
-        # pe (i.e. decoupled position embeddings). In particular,
-        # the concat_and_cache_mla op requires
-        # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
-        # i.e.
-        # kv_lora_rank + qk_rope_head_dim == head_size
-        self.mla_attn = Attention(
+        self.mla_attn = MLAAttention(
             num_heads=self.num_heads,
-            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
             scale=scale,
-            num_kv_heads=1,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            v_head_dim=self.v_head_dim,
+            q_lora_rank=self.q_lora_rank,
+            kv_lora_rank=self.kv_lora_rank,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
-            use_mla=True,
-            use_sparse=mla_modules.is_sparse,
-            # MLA Args
-            q_lora_rank=self.q_lora_rank,
-            kv_lora_rank=self.kv_lora_rank,
-            qk_nope_head_dim=self.qk_nope_head_dim,
-            qk_rope_head_dim=self.qk_rope_head_dim,
-            qk_head_dim=self.qk_head_dim,
-            v_head_dim=self.v_head_dim,
            kv_b_proj=self.kv_b_proj,
+            use_sparse=self.is_sparse,
            indexer=self.indexer,
        )
 
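Note: the renamed MultiHeadLatentAttentionWrapper stays registered as a CustomOp so that, per its docstring, out-of-tree backends can swap in their own outer MLA layer. A minimal stand-in of that kind of name-based registration; the decorator and registry below are hypothetical and much simpler than vLLM's CustomOp machinery:

from typing import Callable, Type

_CUSTOM_LAYERS: dict[str, Type] = {}


def register_layer(name: str) -> Callable[[Type], Type]:
    """Hypothetical decorator standing in for CustomOp.register."""

    def wrap(cls: Type) -> Type:
        # An out-of-tree backend could re-register the same name with its own
        # implementation of the outer MLA layer (rope and o_proj included).
        _CUSTOM_LAYERS[name] = cls
        return cls

    return wrap


@register_layer("multi_head_latent_attention")
class MultiHeadLatentAttentionWrapperSketch:
    pass


assert _CUSTOM_LAYERS["multi_head_latent_attention"] is MultiHeadLatentAttentionWrapperSketch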
@@ -14,6 +14,7 @@ from torch import nn
 from typing_extensions import assert_never
 
 from vllm.attention import Attention
+from vllm.attention.layer import MLAAttention
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import QKVCrossParallelLinear
@@ -122,11 +123,10 @@ def process_weights_after_loading(
             with device_loading_context(module, target_device):
                 quant_method.process_weights_after_loading(module)
 
-    # Currently only used by MLA.
-    # NOTE: This intentionally happens after other modules so we can easily
-    # decompress the weights for MLA.
+    # Initialize post-load attention weights for both Attention and MLA.
+    # NOTE: Happens after other modules so we can easily decompress weights.
     for _, module in model.named_modules():
-        if isinstance(module, Attention) and hasattr(
+        if isinstance(module, (Attention, MLAAttention)) and hasattr(
            module, "process_weights_after_loading"
        ):
            # TODO(lucas): see if there is a way to unify the signatures
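Note: the loader hook now treats both attention flavors uniformly because both expose process_weights_after_loading. A self-contained sketch of that isinstance dispatch; the classes and strings below are dummies standing in for Attention and MLAAttention, not vLLM code:

class AttentionStub:
    def process_weights_after_loading(self, act_dtype) -> str:
        return "attention scales finalized"


class MLAAttentionStub:
    def process_weights_after_loading(self, act_dtype) -> str:
        return "mla latent projections decompressed"


def finalize_attention_weights(modules, act_dtype="bfloat16"):
    done = []
    for module in modules:
        # Same shape as the loader check above, with dummy classes.
        if isinstance(module, (AttentionStub, MLAAttentionStub)) and hasattr(
            module, "process_weights_after_loading"
        ):
            done.append(module.process_weights_after_loading(act_dtype))
    return done


assert len(finalize_attention_weights([AttentionStub(), MLAAttentionStub(), object()])) == 2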
@@ -58,7 +58,7 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
+from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
@@ -1038,7 +1038,7 @@ class DeepseekV2MLAAttention(nn.Module):
             topk_indices_buffer=topk_indices_buffer,
         )
 
-        self.mla_attn = MultiHeadLatentAttention(
+        self.mla_attn = MultiHeadLatentAttentionWrapper(
             self.hidden_size,
             self.num_local_heads,
             self.scaling,
@@ -32,11 +32,11 @@ if TYPE_CHECKING:
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
-from vllm.attention.layer import Attention
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout,
 )
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.ubatch_utils import UBatchSlice
 
@@ -408,7 +408,7 @@ def get_per_layer_parameters(
     to use during `plan`.
     """
 
-    layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
+    layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
    per_layer_params: dict[str, PerLayerParameters] = {}
 
    for key, layer in layers.items():
@@ -9,11 +9,11 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
@@ -880,7 +880,7 @@ class EagleProposer:
     def load_model(self, target_model: nn.Module) -> None:
         draft_model_config = self.vllm_config.speculative_config.draft_model_config
         target_attn_layer_names = set(
-            get_layers_from_vllm_config(self.vllm_config, Attention).keys()
+            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
         )
         # FIXME: support hybrid kv for draft model
         target_indexer_layer_names = set(
@@ -897,7 +897,7 @@ class EagleProposer:
         )
 
         draft_attn_layer_names = (
-            get_layers_from_vllm_config(self.vllm_config, Attention).keys()
+            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
             - target_attn_layer_names
         )
         indexer_layers = get_layers_from_vllm_config(
@@ -20,6 +20,7 @@ from typing_extensions import TypeAlias
 import vllm.envs as envs
 from vllm.attention import Attention, AttentionType
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.layer import MLAAttention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
@@ -4388,98 +4389,100 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         use_mla = self.vllm_config.model_config.use_mla
         cache_dtype_str = self.vllm_config.cache_config.cache_dtype
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         for layer_name, attn_module in attn_layers.items():
-            if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
-                # The layer doesn't need its own KV cache and will use that of
-                # the target layer. We skip creating a KVCacheSpec for it, so
-                # that KV cache management logic will act as this layer does
-                # not exist, and doesn't allocate KV cache for the layer. This
-                # enables the memory saving of cross-layer kv sharing, allowing
-                # a given amount of memory to accommodate longer context lengths
-                # or enable more requests to be processed simultaneously.
-                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
-                continue
+            if isinstance(attn_module, Attention):
+                if (
+                    kv_tgt_layer := attn_module.kv_sharing_target_layer_name
+                ) is not None:
+                    # The layer doesn't need its own KV cache and will use that of
+                    # the target layer. We skip creating a KVCacheSpec for it, so
+                    # that KV cache management logic will act as this layer does
+                    # not exist, and doesn't allocate KV cache for the layer. This
+                    # enables the memory saving of cross-layer kv sharing, allowing
+                    # a given amount of memory to accommodate longer context lengths
+                    # or enable more requests to be processed simultaneously.
+                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                    continue
 
                # TODO(lucas): move the attention specs into the model layers like
                # the attention backends
                if attn_module.attn_type == AttentionType.DECODER:
                    if attn_module.sliding_window is not None:
                        assert not use_mla, "MLA is not supported for slidingwindow"
                        kv_cache_spec[layer_name] = SlidingWindowSpec(
+                            block_size=block_size,
+                            num_kv_heads=attn_module.num_kv_heads,
+                            head_size=attn_module.head_size,
+                            dtype=self.kv_cache_dtype,
+                            sliding_window=attn_module.sliding_window,
+                        )
+                    elif self.attention_chunk_size is not None and isinstance(
+                        attn_module, ChunkedLocalAttention
+                    ):
+                        kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
+                            block_size=block_size,
+                            num_kv_heads=attn_module.num_kv_heads,
+                            head_size=attn_module.head_size,
+                            dtype=self.kv_cache_dtype,
+                            attention_chunk_size=self.attention_chunk_size,
+                        )
+                    else:
+                        kv_cache_spec[layer_name] = FullAttentionSpec(
+                            block_size=block_size,
+                            num_kv_heads=attn_module.num_kv_heads,
+                            head_size=attn_module.head_size,
+                            dtype=self.kv_cache_dtype,
+                        )
+                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+                    kv_cache_spec[layer_name] = CrossAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
-                        sliding_window=attn_module.sliding_window,
                    )
-                elif use_mla:
-                    kv_cache_spec[layer_name] = MLAAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        cache_dtype_str=cache_dtype_str,
-                    )
-                elif self.attention_chunk_size is not None and isinstance(
-                    attn_module, ChunkedLocalAttention
+                elif attn_module.attn_type in (
+                    AttentionType.ENCODER,
+                    AttentionType.ENCODER_ONLY,
                ):
-                    kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        attention_chunk_size=self.attention_chunk_size,
-                    )
+                    # encoder-only attention does not need KV cache.
+                    continue
                else:
-                    kv_cache_spec[layer_name] = FullAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                    )
-            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                kv_cache_spec[layer_name] = CrossAttentionSpec(
+                    raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
+            elif isinstance(attn_module, MLAAttention):
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
                    block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
+                    num_kv_heads=1,
                    head_size=attn_module.head_size,
                    dtype=self.kv_cache_dtype,
+                    cache_dtype_str=cache_dtype_str,
                )
-            elif attn_module.attn_type in (
-                AttentionType.ENCODER,
-                AttentionType.ENCODER_ONLY,
-            ):
-                # encoder-only attention does not need KV cache.
-                continue
-            else:
-                raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
 
-        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
-        if len(mamba_layers) > 0:
-            if (
-                self.vllm_config.speculative_config is not None
-                and self.vllm_config.model_config.hf_config.model_type
-                not in ["qwen3_next"]
-            ):
-                raise NotImplementedError(
-                    "Mamba with speculative decoding is not supported yet."
-                )
-            mamba_block_size = self.vllm_config.cache_config.mamba_block_size
-            page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
+            elif isinstance(attn_module, MambaBase):
+                if (
+                    self.vllm_config.speculative_config is not None
+                    and self.vllm_config.model_config.hf_config.model_type
+                    not in ["qwen3_next"]
+                ):
+                    raise NotImplementedError(
+                        "Mamba with speculative decoding is not supported yet."
+                    )
+                mamba_block_size = self.vllm_config.cache_config.mamba_block_size
+                page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
 
-            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
-                    shapes=mamba_module.get_state_shape(),
-                    dtypes=mamba_module.get_state_dtype(),
+                    shapes=attn_module.get_state_shape(),
+                    dtypes=attn_module.get_state_dtype(),
                    block_size=mamba_block_size,
                    page_size_padded=page_size_padded,
-                    mamba_type=mamba_module.mamba_type,
+                    mamba_type=attn_module.mamba_type,
                    num_speculative_blocks=(
                        self.speculative_config.num_speculative_tokens
                        if self.speculative_config
                        else 0
                    ),
                )
 
        ds_indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
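Note: the runner now iterates every AttentionLayerBase and picks a KV cache spec by concrete type, with MLA layers always declaring a single latent "head". A compact stand-in of that dispatch; the dataclasses, stub layers, and sizes are illustrative, not vLLM's spec classes:

from dataclasses import dataclass


@dataclass
class FullSpecSketch:
    block_size: int
    num_kv_heads: int
    head_size: int


@dataclass
class MLASpecSketch:
    block_size: int
    head_size: int
    num_kv_heads: int = 1  # MLA caches one latent vector per token


class AttnLayerStub:
    num_kv_heads, head_size = 8, 128


class MLALayerStub:
    head_size = 576  # kv_lora_rank + qk_rope_head_dim, e.g. 512 + 64 (assumed sizes)


def build_spec(layer, block_size: int = 16):
    if isinstance(layer, AttnLayerStub):
        return FullSpecSketch(block_size, layer.num_kv_heads, layer.head_size)
    if isinstance(layer, MLALayerStub):
        return MLASpecSketch(block_size, layer.head_size)
    raise ValueError(f"unsupported layer type: {type(layer).__name__}")


assert build_spec(AttnLayerStub()).num_kv_heads == 8
assert build_spec(MLALayerStub()).num_kv_heads == 1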
@@ -19,6 +19,7 @@ import torch_xla.runtime as xr
 import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import MLAAttention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import (
@@ -32,6 +33,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.tpu import TPUModelLoader
 from vllm.model_executor.models.interfaces import (
@@ -63,6 +65,7 @@ from vllm.v1.kv_cache_interface import (
     FullAttentionSpec,
     KVCacheConfig,
     KVCacheSpec,
+    MLAAttentionSpec,
     SlidingWindowSpec,
 )
 from vllm.v1.outputs import (
@@ -561,52 +564,71 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         format. Layers that do not need KV cache are not included.
         """
 
-        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         block_size = self.vllm_config.cache_config.block_size
+        cache_dtype_str = self.vllm_config.cache_config.cache_dtype
 
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         for layer_name, attn_module in layers.items():
-            if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
-                # The layer doesn't need its own KV cache and will use that of
-                # the target layer. We skip creating a KVCacheSpec for it, so
-                # that KV cache management logic will act as this layer does
-                # not exist, and doesn't allocate KV cache for the layer. This
-                # enables the memory saving of cross-layer kv sharing, allowing
-                # a given amount of memory to accommodate longer context lengths
-                # or enable more requests to be processed simultaneously.
-                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
-                continue
+            # Classic Attention path
+            if isinstance(attn_module, Attention):
+                if (
+                    kv_tgt_layer := attn_module.kv_sharing_target_layer_name
+                ) is not None:
+                    # The layer doesn't need its own KV cache and will use that of
+                    # the target layer. We skip creating a KVCacheSpec for it, so
+                    # that KV cache management logic will act as this layer does
+                    # not exist, and doesn't allocate KV cache for the layer. This
+                    # enables the memory saving of cross-layer kv sharing, allowing
+                    # a given amount of memory to accommodate longer context lengths
+                    # or enable more requests to be processed simultaneously.
+                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                    continue
 
                if attn_module.attn_type == AttentionType.DECODER:
                    if isinstance(attn_module, ChunkedLocalAttention):
                        logger.warning_once(
                            "Using irope in Pallas is not supported yet, it "
                            "will fall back to global attention for long context."
                        )
                    if attn_module.sliding_window is not None:
                        kv_cache_spec[layer_name] = SlidingWindowSpec(
                            block_size=block_size,
                            num_kv_heads=attn_module.num_kv_heads,
                            head_size=attn_module.head_size,
                            dtype=self.kv_cache_dtype,
                            sliding_window=attn_module.sliding_window,
                        )
+                    else:
+                        kv_cache_spec[layer_name] = FullAttentionSpec(
+                            block_size=block_size,
+                            num_kv_heads=attn_module.num_kv_heads,
+                            head_size=attn_module.head_size,
+                            dtype=self.kv_cache_dtype,
+                        )
+                elif attn_module.attn_type in (
+                    AttentionType.ENCODER,
+                    AttentionType.ENCODER_ONLY,
+                ):
+                    # encoder-only attention does not need KV cache.
+                    continue
+                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+                    raise NotImplementedError
                else:
-                    kv_cache_spec[layer_name] = FullAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                    )
-            elif attn_module.attn_type in (
-                AttentionType.ENCODER,
-                AttentionType.ENCODER_ONLY,
-            ):
-                # encoder-only attention does not need KV cache.
-                continue
-            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                raise NotImplementedError
+                    raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
+            # MLAAttention path
+            elif isinstance(attn_module, MLAAttention):
+                if layer_name in kv_cache_spec:
+                    continue
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=1,
+                    head_size=attn_module.head_size,
+                    dtype=self.kv_cache_dtype,
+                    cache_dtype_str=cache_dtype_str,
+                )
            else:
-                raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
+                continue
 
        return kv_cache_spec