Separate MLAAttention class from Attention (#25103)

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Authored by Naveenraj Kamalakannan on 2025-10-08 20:11:11 -04:00, committed by GitHub
parent 2a03f93de9
commit e614ab7806
10 changed files with 502 additions and 163 deletions


@@ -6,6 +6,7 @@ from typing import Generic, Optional, Protocol, TypeVar
import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
@@ -184,6 +185,31 @@ class AttentionImpl(ABC, Generic[T]):
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
indexer: Optional[object] = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward(
self,
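For orientation, here is an editor's sketch (not part of this commit) of what a concrete backend would provide against the new MLAAttentionImpl interface; the class name and stored attributes are hypothetical, and only the signatures mirror the abstract methods above:

from typing import Optional
from vllm.attention.backends.abstract import MLAAttentionImpl
from vllm.model_executor.layers.linear import ColumnParallelLinear

class MyMLABackendImpl(MLAAttentionImpl):
    # Skeleton only: a real backend would also set up its kernels here.
    def __init__(self, num_heads: int, head_size: int, scale: float,
                 num_kv_heads: int, alibi_slopes: Optional[list[float]],
                 sliding_window: Optional[int], kv_cache_dtype: str,
                 logits_soft_cap: Optional[float], attn_type: str,
                 kv_sharing_target_layer_name: Optional[str],
                 q_lora_rank: Optional[int], kv_lora_rank: int,
                 qk_nope_head_dim: int, qk_rope_head_dim: int,
                 qk_head_dim: int, v_head_dim: int,
                 kv_b_proj: ColumnParallelLinear,
                 indexer: Optional[object] = None) -> None:
        self.num_heads = num_heads
        self.scale = scale
        self.kv_lora_rank = kv_lora_rank
        self.kv_b_proj = kv_b_proj  # used to expand the latent KV per head

    def forward(self, layer, q, kv_c_normed, k_pe, kv_cache, attn_metadata,
                output=None, output_scale=None, output_block_scale=None):
        raise NotImplementedError  # the backend's MLA attention kernel runs here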


@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
-from typing import Callable, Optional
from typing import Callable, Optional, cast
import torch
import torch.nn as nn
@@ -10,7 +10,7 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
-from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
@@ -23,7 +23,10 @@ from vllm.distributed.kv_transfer import (
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -131,8 +134,6 @@ class Attention(nn.Module, AttentionLayerBase):
quant_config: Optional[QuantizationConfig] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
-use_mla: bool = False,
-use_sparse: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
@@ -192,8 +193,6 @@ class Attention(nn.Module, AttentionLayerBase):
# the quant op after this attention layer.
self._o_scale_float: Optional[float] = None
-self.use_mla = use_mla
-self.use_sparse = use_sparse
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
@@ -229,9 +228,8 @@ class Attention(nn.Module, AttentionLayerBase):
dtype,
kv_cache_dtype,
block_size,
-use_mla=use_mla,
use_mla=False,
has_sink=self.has_sink,
-use_sparse=use_sparse,
)
else:
self.attn_backend = attn_backend
@@ -349,19 +347,15 @@ class Attention(nn.Module, AttentionLayerBase):
output_shape = output_shape if output_shape is not None else query.shape
output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
-# We skip reshaping query, key and value tensors for the MLA
-# backend since these tensors have different semantics and are
-# processed differently.
-if not self.use_mla:
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -570,6 +564,218 @@ class MultiHeadAttention(nn.Module):
return out.reshape(bsz, q_len, -1)
class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.
This class takes query and compressed key/value tensors as input.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
kv_b_proj: ColumnParallelLinear,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_sparse: bool = False,
indexer: Optional[object] = None,
):
super().__init__()
self.num_heads = num_heads
self.scale = scale
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.head_size = kv_lora_rank + qk_rope_head_dim
self.layer_name = prefix
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.kv_cache_dtype = kv_cache_dtype
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
)
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
num_heads=self.num_heads,
head_size=self.head_size,
scale=self.scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=self.kv_cache_dtype,
logits_soft_cap=None,
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=None,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=kv_b_proj,
indexer=indexer,
)
self.use_direct_call = not current_platform.opaque_attention_op()
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
# Align with Attention's scale attributes for MLA backends.
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# Host-side mirrors used by some attention backends
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
self._o_scale_float: Optional[float] = None
self.use_sparse = use_sparse
# Initialize q/k/v range constants.
try:
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
except torch.cuda.OutOfMemoryError:
# Keep defaults if allocation fails; not critical for init.
pass
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# Mirror Attention.forward scale calculation path
if self.calculate_kv_scales and getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
self.calc_kv_scales(q, kv_c_normed, k_pe)
if self.attn_backend.accept_output_buffer:
output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
output=output,
)
return output
else:
return self.impl.forward(
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
)
else:
if self.attn_backend.accept_output_buffer:
output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
)
return output
else:
# We can still access forward context to check calculation flag
if self.calculate_kv_scales:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
if getattr(attn_metadata, "enable_kv_scales_calculation", False):
self.calc_kv_scales(q, kv_c_normed, k_pe)
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
k_pe,
self.layer_name,
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
def calc_kv_scales(
self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
) -> None:
"""Optional scale calculation for MLA inputs.
Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
"""
# Use safe defaults if ranges are not present
q_range = getattr(self, "q_range", torch.tensor(1.0))
k_range = getattr(self, "k_range", torch.tensor(1.0))
v_range = getattr(self, "v_range", torch.tensor(1.0))
self._q_scale.copy_(torch.abs(q).max() / q_range)
# kv_c_normed is the compressed KV representation; use it for k/v
kv_abs_max = torch.abs(kv_c_normed).max()
self._k_scale.copy_(kv_abs_max / k_range)
self._v_scale.copy_(kv_abs_max / v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
self.calculate_kv_scales = False
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
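As a usage sketch (an editor's illustration, not code from this commit): a wrapper module that owns the surrounding projections constructs MLAAttention once and calls it per forward pass. The tensor names and the o_proj layer below are hypothetical, and the shape comments are approximate:

# kv_c_normed: compressed ("latent") KV of shape [num_tokens, kv_lora_rank]
# k_pe:        decoupled rotary key of shape roughly [num_tokens, 1, qk_rope_head_dim]
attn_out = self.mla_attn(
    q,
    kv_c_normed,
    k_pe,
    output_shape=(q.shape[0], self.num_heads * self.v_head_dim),
)
output = self.o_proj(attn_out)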
def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
@@ -726,3 +932,93 @@ direct_register_custom_op(
fake_impl=unified_attention_with_output_fake,
tags=tag_cudagraph_unsafe,
)
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
def unified_mla_attention_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
direct_register_custom_op(
op_name="unified_mla_attention",
op_func=unified_mla_attention,
mutates_args=[],
fake_impl=unified_mla_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
def unified_mla_attention_with_output(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_mla_attention_with_output_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_mla_attention_with_output",
op_func=unified_mla_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_mla_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
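For readers unfamiliar with the custom-op indirection, a rough sketch of how the opaque (non-direct-call) path resolves at runtime; this is an editor's illustration and the layer name is hypothetical:

# MLAAttention.forward invokes the registered op by name; PyTorch dispatches it to
# unified_mla_attention(), which looks the layer up in the forward context and runs
# its backend impl. During tracing, unified_mla_attention_fake supplies the shape.
out = torch.ops.vllm.unified_mla_attention(
    q, kv_c_normed, k_pe, "model.layers.0.self_attn.mla_attn.attn"
)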


@@ -380,6 +380,8 @@ class CompilationConfig:
_attention_ops: ClassVar[list[str]] = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.unified_mla_attention",
"vllm.unified_mla_attention_with_output",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",


@@ -5,7 +5,7 @@ from typing import Optional
import torch
-from vllm.attention import Attention
from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -30,8 +30,9 @@ class MLAModules:
@CustomOp.register("multi_head_latent_attention")
-class MultiHeadLatentAttention(CustomOp):
class MultiHeadLatentAttentionWrapper(CustomOp):
-"""MLA layer registered as CustomOp.
"""MLA layer registered as CustomOp to allow OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently MLA ignores the enable/disable mechanism of CustomOp
because there is only one in-tree implementation in forward_native.
TODO: implement this with a new PluggableLayer mechanism.
@@ -87,30 +88,19 @@ class MultiHeadLatentAttention(CustomOp):
self.topk_tokens = self.indexer.topk_tokens
self.topk_indices_buffer = mla_modules.topk_indices_buffer
-# In the MLA backend, kv_cache includes both k_c and
-# pe (i.e. decoupled position embeddings). In particular,
-# the concat_and_cache_mla op requires
-# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
-# i.e.
-# kv_lora_rank + qk_rope_head_dim == head_size
-self.mla_attn = Attention(
-num_heads=self.num_heads,
-head_size=self.kv_lora_rank + self.qk_rope_head_dim,
-scale=scale,
-num_kv_heads=1,
-cache_config=cache_config,
-quant_config=quant_config,
-prefix=f"{prefix}.attn",
-use_mla=True,
-use_sparse=mla_modules.is_sparse,
-# MLA Args
-q_lora_rank=self.q_lora_rank,
-kv_lora_rank=self.kv_lora_rank,
-qk_nope_head_dim=self.qk_nope_head_dim,
-qk_rope_head_dim=self.qk_rope_head_dim,
-qk_head_dim=self.qk_head_dim,
-v_head_dim=self.v_head_dim,
-kv_b_proj=self.kv_b_proj,
-indexer=self.indexer,
-)
self.mla_attn = MLAAttention(
num_heads=self.num_heads,
scale=scale,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
kv_b_proj=self.kv_b_proj,
use_sparse=self.is_sparse,
indexer=self.indexer,
)
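The size constraint from the removed comment still holds; it is simply computed inside MLAAttention now (self.head_size = kv_lora_rank + qk_rope_head_dim). A quick worked example with assumed DeepSeek-style dimensions, not taken from this diff:

kv_lora_rank = 512      # assumed compressed KV dimension
qk_rope_head_dim = 64   # assumed rotary sub-head dimension
head_size = kv_lora_rank + qk_rope_head_dim
# concat_and_cache_mla stores k_c and k_pe side by side in each cache slot:
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2) == 576
assert head_size == 576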


@@ -14,6 +14,7 @@ from torch import nn
from typing_extensions import assert_never
from vllm.attention import Attention
from vllm.attention.layer import MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
@@ -122,11 +123,10 @@ def process_weights_after_loading(
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
-# Currently only used by MLA.
-# NOTE: This intentionally happens after other modules so we can easily
-# decompress the weights for MLA.
# Initialize post-load attention weights for both Attention and MLA.
# NOTE: Happens after other modules so we can easily decompress weights.
for _, module in model.named_modules():
-if isinstance(module, Attention) and hasattr(
if isinstance(module, (Attention, MLAAttention)) and hasattr(
module, "process_weights_after_loading"
):
# TODO(lucas): see if there is a way to unify the signatures


@@ -58,7 +58,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
@@ -1038,7 +1038,7 @@ class DeepseekV2MLAAttention(nn.Module):
topk_indices_buffer=topk_indices_buffer,
)
-self.mla_attn = MultiHeadLatentAttention(
self.mla_attn = MultiHeadLatentAttentionWrapper(
self.hidden_size,
self.num_local_heads,
self.scaling,


@@ -32,11 +32,11 @@ if TYPE_CHECKING:
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
-from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice
@@ -408,7 +408,7 @@ def get_per_layer_parameters(
to use during `plan`.
"""
-layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
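A hedged sketch of why AttentionLayerBase works as the filter here: both Attention and the new MLAAttention subclass it, so MLA layers are now picked up when collecting per-layer parameters (editor's illustration only):

from vllm.attention.layer import Attention, MLAAttention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase

assert issubclass(Attention, AttentionLayerBase)
assert issubclass(MLAAttention, AttentionLayerBase)
# so get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
# returns MLA layers too, and their backends get planned alongside the rest.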


@@ -9,11 +9,11 @@ import numpy as np
import torch
import torch.nn as nn
-from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
@@ -880,7 +880,7 @@ class EagleProposer:
def load_model(self, target_model: nn.Module) -> None:
draft_model_config = self.vllm_config.speculative_config.draft_model_config
target_attn_layer_names = set(
-get_layers_from_vllm_config(self.vllm_config, Attention).keys()
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
)
# FIXME: support hybrid kv for draft model
target_indexer_layer_names = set(
@@ -897,7 +897,7 @@
)
draft_attn_layer_names = (
-get_layers_from_vllm_config(self.vllm_config, Attention).keys()
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
- target_attn_layer_names
)
indexer_layers = get_layers_from_vllm_config(


@@ -20,6 +20,7 @@ from typing_extensions import TypeAlias
import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
@@ -4388,98 +4389,100 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
use_mla = self.vllm_config.model_config.use_mla
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
kv_cache_spec: dict[str, KVCacheSpec] = {}
-attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
-if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
if isinstance(attn_module, Attention):
if (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
if attn_module.sliding_window is not None:
assert not use_mla, "MLA is not supported for slidingwindow"
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
)
-elif use_mla:
-kv_cache_spec[layer_name] = MLAAttentionSpec(
-block_size=block_size,
-num_kv_heads=attn_module.num_kv_heads,
-head_size=attn_module.head_size,
-dtype=self.kv_cache_dtype,
-cache_dtype_str=cache_dtype_str,
-)
elif self.attention_chunk_size is not None and isinstance(
attn_module, ChunkedLocalAttention
):
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
attention_chunk_size=self.attention_chunk_size,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
kv_cache_spec[layer_name] = CrossAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
# encoder-only attention does not need KV cache.
continue
else:
raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
elif isinstance(attn_module, MLAAttention):
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=cache_dtype_str,
)
-mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
-if len(mamba_layers) > 0:
elif isinstance(attn_module, MambaBase):
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]
):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet."
)
mamba_block_size = self.vllm_config.cache_config.mamba_block_size
page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
-for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
-shapes=mamba_module.get_state_shape(),
-dtypes=mamba_module.get_state_dtype(),
shapes=attn_module.get_state_shape(),
dtypes=attn_module.get_state_dtype(),
block_size=mamba_block_size,
page_size_padded=page_size_padded,
-mamba_type=mamba_module.mamba_type,
mamba_type=attn_module.mamba_type,
num_speculative_blocks=(
self.speculative_config.num_speculative_tokens
if self.speculative_config
else 0
),
)
ds_indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
)
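To make the new MLA branch concrete, an editor's illustration of the spec an MLA layer ends up with; the dimensions, dtype, and cache dtype string are assumed DeepSeek-style values, not taken from this diff:

import torch
from vllm.v1.kv_cache_interface import MLAAttentionSpec

spec = MLAAttentionSpec(
    block_size=16,
    num_kv_heads=1,        # the compressed KV behaves like a single head
    head_size=512 + 64,    # kv_lora_rank + qk_rope_head_dim
    dtype=torch.bfloat16,
    cache_dtype_str="auto",
)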


@@ -19,6 +19,7 @@ import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (
@@ -32,6 +33,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.model_executor.models.interfaces import (
@@ -63,6 +65,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import (
@@ -561,52 +564,71 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
-layers = get_layers_from_vllm_config(self.vllm_config, Attention)
layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
block_size = self.vllm_config.cache_config.block_size
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in layers.items():
-if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
# Classic Attention path
if isinstance(attn_module, Attention):
if (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
if attn_module.attn_type == AttentionType.DECODER:
if isinstance(attn_module, ChunkedLocalAttention):
logger.warning_once(
"Using irope in Pallas is not supported yet, it "
"will fall back to global attention for long context."
)
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
# MLAAttention path
elif isinstance(attn_module, MLAAttention):
if layer_name in kv_cache_spec:
continue
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=cache_dtype_str,
)
else:
continue
return kv_cache_spec