Separate MLAAttention class from Attention (#25103)

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Naveenraj Kamalakannan 2025-10-08 20:11:11 -04:00 committed by GitHub
parent 2a03f93de9
commit e614ab7806
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 502 additions and 163 deletions

View File

@ -6,6 +6,7 @@ from typing import Generic, Optional, Protocol, TypeVar
import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
@ -184,6 +185,31 @@ class AttentionImpl(ABC, Generic[T]):
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
indexer: Optional[object] = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward(
self,

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import Callable, Optional
from typing import Callable, Optional, cast
import torch
import torch.nn as nn
@ -10,7 +10,7 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
@ -23,7 +23,10 @@ from vllm.distributed.kv_transfer import (
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -131,8 +134,6 @@ class Attention(nn.Module, AttentionLayerBase):
quant_config: Optional[QuantizationConfig] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
use_sparse: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
@ -192,8 +193,6 @@ class Attention(nn.Module, AttentionLayerBase):
# the quant op after this attention layer.
self._o_scale_float: Optional[float] = None
self.use_mla = use_mla
self.use_sparse = use_sparse
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
@ -229,9 +228,8 @@ class Attention(nn.Module, AttentionLayerBase):
dtype,
kv_cache_dtype,
block_size,
use_mla=use_mla,
use_mla=False,
has_sink=self.has_sink,
use_sparse=use_sparse,
)
else:
self.attn_backend = attn_backend
@ -349,19 +347,15 @@ class Attention(nn.Module, AttentionLayerBase):
output_shape = output_shape if output_shape is not None else query.shape
output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
# backend since these tensors have different semantics and are
# processed differently.
if not self.use_mla:
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@ -570,6 +564,218 @@ class MultiHeadAttention(nn.Module):
return out.reshape(bsz, q_len, -1)
class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.
This class takes query, and compressed key/value tensors as input.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
kv_b_proj: ColumnParallelLinear,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_sparse: bool = False,
indexer: Optional[object] = None,
):
super().__init__()
self.num_heads = num_heads
self.scale = scale
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.head_size = kv_lora_rank + qk_rope_head_dim
self.layer_name = prefix
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.kv_cache_dtype = kv_cache_dtype
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
)
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
num_heads=self.num_heads,
head_size=self.head_size,
scale=self.scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=self.kv_cache_dtype,
logits_soft_cap=None,
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=None,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=kv_b_proj,
indexer=indexer,
)
self.use_direct_call = not current_platform.opaque_attention_op()
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
# Align with Attention's scale attributes for MLA backends.
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# Host-side mirrors used by some attention backends
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
self._o_scale_float: Optional[float] = None
self.use_sparse = use_sparse
# Initialize q/k/v range constants.
try:
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
except torch.cuda.OutOfMemoryError:
# Keep defaults if allocation fails; not critical for init.
pass
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# Mirror Attention.forward scale calculation path
if self.calculate_kv_scales and getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
self.calc_kv_scales(q, kv_c_normed, k_pe)
if self.attn_backend.accept_output_buffer:
output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
output=output,
)
return output
else:
return self.impl.forward(
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
)
else:
if self.attn_backend.accept_output_buffer:
output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
)
return output
else:
# We can still access forward context to check calculation flag
if self.calculate_kv_scales:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
if getattr(attn_metadata, "enable_kv_scales_calculation", False):
self.calc_kv_scales(q, kv_c_normed, k_pe)
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
k_pe,
self.layer_name,
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
def calc_kv_scales(
self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
) -> None:
"""Optional scale calculation for MLA inputs.
Mirrors Attention.calc_kv_scales. Not all MLA backends require this
"""
# Use safe defaults if ranges are not present
q_range = getattr(self, "q_range", torch.tensor(1.0))
k_range = getattr(self, "k_range", torch.tensor(1.0))
v_range = getattr(self, "v_range", torch.tensor(1.0))
self._q_scale.copy_(torch.abs(q).max() / q_range)
# kv_c_normed is the compressed KV representation; use it for k/v
kv_abs_max = torch.abs(kv_c_normed).max()
self._k_scale.copy_(kv_abs_max / k_range)
self._v_scale.copy_(kv_abs_max / v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
self.calculate_kv_scales = False
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
@ -726,3 +932,93 @@ direct_register_custom_op(
fake_impl=unified_attention_with_output_fake,
tags=tag_cudagraph_unsafe,
)
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
def unified_mla_attention_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
direct_register_custom_op(
op_name="unified_mla_attention",
op_func=unified_mla_attention,
mutates_args=[],
fake_impl=unified_mla_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
def unified_mla_attention_with_output(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_mla_attention_with_output_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_mla_attention_with_output",
op_func=unified_mla_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_mla_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -380,6 +380,8 @@ class CompilationConfig:
_attention_ops: ClassVar[list[str]] = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.unified_mla_attention",
"vllm.unified_mla_attention_with_output",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",

View File

@ -5,7 +5,7 @@ from typing import Optional
import torch
from vllm.attention import Attention
from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -30,8 +30,9 @@ class MLAModules:
@CustomOp.register("multi_head_latent_attention")
class MultiHeadLatentAttention(CustomOp):
"""MLA layer registered as CustomOp.
class MultiHeadLatentAttentionWrapper(CustomOp):
"""MLA layer registered as CustomOp to allow OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently MLA ignores the enable/disable mechanism of CustomOp
because there is only one in-tree implementation in forward_native.
TODO: implement this with a new PluggableLayer mechanism.
@ -87,30 +88,19 @@ class MultiHeadLatentAttention(CustomOp):
self.topk_tokens = self.indexer.topk_tokens
self.topk_indices_buffer = mla_modules.topk_indices_buffer
# In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular,
# the concat_and_cache_mla op requires
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self.mla_attn = Attention(
self.mla_attn = MLAAttention(
num_heads=self.num_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scale,
num_kv_heads=1,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=mla_modules.is_sparse,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=self.kv_b_proj,
use_sparse=self.is_sparse,
indexer=self.indexer,
)

View File

@ -14,6 +14,7 @@ from torch import nn
from typing_extensions import assert_never
from vllm.attention import Attention
from vllm.attention.layer import MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
@ -122,11 +123,10 @@ def process_weights_after_loading(
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
# Currently only used by MLA.
# NOTE: This intentionally happens after other modules so we can easily
# decompress the weights for MLA.
# Initialize post-load attention weights for both Attention and MLA.
# NOTE: Happens after other modules so we can easily decompress weights.
for _, module in model.named_modules():
if isinstance(module, Attention) and hasattr(
if isinstance(module, (Attention, MLAAttention)) and hasattr(
module, "process_weights_after_loading"
):
# TODO(lucas): see if there is a way to unify the signatures

View File

@ -58,7 +58,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
@ -1038,7 +1038,7 @@ class DeepseekV2MLAAttention(nn.Module):
topk_indices_buffer=topk_indices_buffer,
)
self.mla_attn = MultiHeadLatentAttention(
self.mla_attn = MultiHeadLatentAttentionWrapper(
self.hidden_size,
self.num_local_heads,
self.scaling,

View File

@ -32,11 +32,11 @@ if TYPE_CHECKING:
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice
@ -408,7 +408,7 @@ def get_per_layer_parameters(
to use during `plan`.
"""
layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():

View File

@ -9,11 +9,11 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
@ -880,7 +880,7 @@ class EagleProposer:
def load_model(self, target_model: nn.Module) -> None:
draft_model_config = self.vllm_config.speculative_config.draft_model_config
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys()
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
)
# FIXME: support hybrid kv for draft model
target_indexer_layer_names = set(
@ -897,7 +897,7 @@ class EagleProposer:
)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys()
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
- target_attn_layer_names
)
indexer_layers = get_layers_from_vllm_config(

View File

@ -20,6 +20,7 @@ from typing_extensions import TypeAlias
import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
@ -4388,98 +4389,100 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
use_mla = self.vllm_config.model_config.use_mla
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
if isinstance(attn_module, Attention):
if (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
if attn_module.sliding_window is not None:
assert not use_mla, "MLA is not supported for slidingwindow"
kv_cache_spec[layer_name] = SlidingWindowSpec(
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
if attn_module.sliding_window is not None:
assert not use_mla, "MLA is not supported for slidingwindow"
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
)
elif self.attention_chunk_size is not None and isinstance(
attn_module, ChunkedLocalAttention
):
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
attention_chunk_size=self.attention_chunk_size,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
kv_cache_spec[layer_name] = CrossAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
)
elif use_mla:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=cache_dtype_str,
)
elif self.attention_chunk_size is not None and isinstance(
attn_module, ChunkedLocalAttention
elif attn_module.attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
attention_chunk_size=self.attention_chunk_size,
)
# encoder-only attention does not need KV cache.
continue
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
kv_cache_spec[layer_name] = CrossAttentionSpec(
raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
elif isinstance(attn_module, MLAAttention):
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=cache_dtype_str,
)
elif attn_module.attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
# encoder-only attention does not need KV cache.
continue
else:
raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
if len(mamba_layers) > 0:
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]
):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet."
)
mamba_block_size = self.vllm_config.cache_config.mamba_block_size
page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
for layer_name, mamba_module in mamba_layers.items():
elif isinstance(attn_module, MambaBase):
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]
):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet."
)
mamba_block_size = self.vllm_config.cache_config.mamba_block_size
page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtypes=mamba_module.get_state_dtype(),
shapes=attn_module.get_state_shape(),
dtypes=attn_module.get_state_dtype(),
block_size=mamba_block_size,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type,
mamba_type=attn_module.mamba_type,
num_speculative_blocks=(
self.speculative_config.num_speculative_tokens
if self.speculative_config
else 0
),
)
ds_indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
)

View File

@ -19,6 +19,7 @@ import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (
@ -32,6 +33,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.model_executor.models.interfaces import (
@ -63,6 +65,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import (
@ -561,52 +564,71 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
block_size = self.vllm_config.cache_config.block_size
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in layers.items():
if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
# Classic Attention path
if isinstance(attn_module, Attention):
if (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
if attn_module.attn_type == AttentionType.DECODER:
if isinstance(attn_module, ChunkedLocalAttention):
logger.warning_once(
"Using irope in Pallas is not supported yet, it "
"will fall back to global attention for long context."
)
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
)
if attn_module.attn_type == AttentionType.DECODER:
if isinstance(attn_module, ChunkedLocalAttention):
logger.warning_once(
"Using irope in Pallas is not supported yet, it "
"will fall back to global attention for long context."
)
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
# MLAAttention path
elif isinstance(attn_module, MLAAttention):
if layer_name in kv_cache_spec:
continue
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=cache_dtype_str,
)
else:
raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
continue
return kv_cache_spec