# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from einops import rearrange
from torch import nn

from vllm.attention import AttentionBackend
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

from .fla.ops.kda import (
    FusedRMSNormGated,
    chunk_kda,
    fused_kda_gate,
    fused_recurrent_kda,
)
from .linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from .mamba.abstract import MambaBase
from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator
from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .quantization.base_config import QuantizationConfig

logger = init_logger(__name__)
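

# `kda_attention` is exposed as the custom op `torch.ops.vllm.kda_attention`
# (registered below via `direct_register_custom_op`).  The wrapper resolves
# the owning `KimiDeltaAttention` layer by name from the forward context and
# delegates to its `_forward`, which writes the result into `core_attn_out`
# in place.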
def kda_attention(
    q_proj_states: torch.Tensor,
    k_proj_states: torch.Tensor,
    v_proj_states: torch.Tensor,
    g1: torch.Tensor,
    g2: torch.Tensor,
    beta: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self._forward(
        q_proj_states=q_proj_states,
        k_proj_states=k_proj_states,
        v_proj_states=v_proj_states,
        g1=g1,
        g2=g2,
        beta=beta,
        core_attn_out=core_attn_out,
    )
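

# Fake (meta) implementation used when the custom op is traced, e.g. under
# torch.compile.  It is a no-op because the real op returns nothing and only
# writes into the preallocated `core_attn_out` buffer.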
def kda_attention_fake(
    q_proj_states: torch.Tensor,
    k_proj_states: torch.Tensor,
    v_proj_states: torch.Tensor,
    g1: torch.Tensor,
    g2: torch.Tensor,
    beta: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    return

direct_register_custom_op(
    op_name="kda_attention",
    op_func=kda_attention,
    mutates_args=["core_attn_out"],
    fake_impl=kda_attention_fake,
)

class KimiDeltaAttention(nn.Module, MambaBase):
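    """Kimi Delta Attention (KDA) mixer layer.

    Combines per-projection short causal convolutions, a data-dependent decay
    gate, and a delta-rule style recurrent state, followed by a gated RMSNorm
    and an output projection.  The layer registers itself as a
    "linear_attention" Mamba-style layer so vLLM manages its convolution and
    recurrent state caches through the GDN attention backend.
    """
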
    @property
    def mamba_type(self) -> str:
        return "linear_attention"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend

        return GDNAttentionBackend

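    # The per-layer cache holds four tensors: a short-conv state for each of
    # the q/k/v streams plus the recurrent (delta-rule) state.  Their dtypes
    # and shapes come from the shared Mamba state helpers below.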
    def get_state_dtype(
        self,
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
        if self.model_config is None or self.cache_config is None:
            raise ValueError("model_config and cache_config must be set")
        return MambaStateDtypeCalculator.kda_state_dtype(
            self.model_config.dtype, self.cache_config.mamba_cache_dtype
        )

    def get_state_shape(
        self,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        return MambaStateShapeCalculator.kda_state_shape(
            self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size
        )

    def __init__(
        self,
        layer_idx: int,
        hidden_size: int,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        model_config: ModelConfig | None = None,
        rms_norm_eps: float = 1e-5,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = hidden_size
        self.model_config = model_config
        self.cache_config = cache_config
        if model_config is None:
            raise ValueError("model_config must be provided")
        kda_config = model_config.linear_attn_config
        self.head_dim = kda_config["head_dim"]
        self.num_heads = kda_config["num_heads"]
        self.layer_idx = layer_idx
        self.prefix = prefix
        assert self.num_heads % self.tp_size == 0
        self.local_num_heads = divide(self.num_heads, self.tp_size)

        projection_size = self.head_dim * self.num_heads
        self.conv_size = kda_config["short_conv_kernel_size"]

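        # q/k/v projections are column-parallel, so the attention heads are
        # sharded evenly across tensor-parallel ranks.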
        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        self.k_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.k_proj",
        )
        self.v_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.v_proj",
        )

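        # Low-rank gate branch: f_a_proj -> f_b_proj produces the input to
        # `fused_kda_gate`, which together with A_log and dt_bias yields the
        # data-dependent decay gate g1 consumed by the KDA kernels.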
        self.f_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_a_proj",
        )

        self.f_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_b_proj",
        )
        self.dt_bias = nn.Parameter(
            torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
        )

        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.b_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.b_proj",
        )

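        # The short causal convolutions are stored as ColumnParallelLinear
        # modules so tensor-parallel sharding and weight loading are reused;
        # their weights are reshaped below to the conv1d layout.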
        self.q_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.q_conv1d",
        )
        self.k_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.k_conv1d",
        )
        self.v_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.v_conv1d",
        )
        # Unsqueeze to fit the conv1d weight shape into the linear weight
        # shape.  This can't be done in `weight_loader` since one already
        # exists on `ColumnParallelLinear` and `set_weight_attrs` doesn't
        # allow overriding it.
        self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
        self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
        self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)

        self.A_log = nn.Parameter(
            torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
        )
        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})

        self.g_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_a_proj",
        )
        self.g_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_b_proj",
        )
        self.o_norm = FusedRMSNormGated(
            self.head_dim, eps=rms_norm_eps, activation="sigmoid"
        )
        self.o_proj = RowParallelLinear(
            projection_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

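        # Register the layer in the static forward context under its prefix so
        # the `kda_attention` custom op can look this instance up by name at
        # runtime (see `kda_attention` above).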
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
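        """Project the hidden states, compute the KDA gates, run the KDA
        kernels through the `kda_attention` custom op, and write the attention
        output into `output` in place."""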
        num_tokens = hidden_states.size(0)
        q = self.q_proj(hidden_states)[0]
        k = self.k_proj(hidden_states)[0]
        v = self.v_proj(hidden_states)[0]

        beta = self.b_proj(hidden_states)[0].float().sigmoid()
        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
        beta = beta.unsqueeze(0)
        g1 = g1.unsqueeze(0)

        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)

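        # Preallocate the kernel output; the custom op fills it in place (it
        # is registered with mutates_args=["core_attn_out"]), which keeps the
        # mutation visible to torch.compile.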
        core_attn_out = torch.zeros(
            (1, num_tokens, self.local_num_heads, self.head_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        torch.ops.vllm.kda_attention(
            q,
            k,
            v,
            g1,
            g2,
            beta,
            core_attn_out,
            self.prefix,
        )
        core_attn_out = self.o_norm(core_attn_out, g2)
        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
        output[:] = self.o_proj(core_attn_out)[0]

    def _forward(
        self,
        q_proj_states: torch.Tensor,
        k_proj_states: torch.Tensor,
        v_proj_states: torch.Tensor,
        g1: torch.Tensor,
        g2: torch.Tensor,
        beta: torch.Tensor,
        core_attn_out: torch.Tensor,
    ) -> None:
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
            # V1 profile run
            return

        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)
        has_initial_state = attn_metadata.has_initial_state
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
        constant_caches = self.kv_cache[forward_context.virtual_engine]

        (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
        # deal with strides
        conv_state_q = conv_state_q.transpose(-1, -2)
        conv_state_k = conv_state_k.transpose(-1, -2)
        conv_state_v = conv_state_v.transpose(-1, -2)

        q_conv_weights = self.q_conv1d.weight.view(
            self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
        )
        k_conv_weights = self.k_conv1d.weight.view(
            self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
        )
        v_conv_weights = self.v_conv1d.weight.view(
            self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
        )
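        # Prefill batches run the variable-length causal-conv kernel, which
        # also refreshes the conv caches; decode-only batches take the
        # single-token update path.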
        if attn_metadata.num_prefills > 0:
            q_proj_states = q_proj_states.transpose(0, 1)
            k_proj_states = k_proj_states.transpose(0, 1)
            v_proj_states = v_proj_states.transpose(0, 1)
            q = causal_conv1d_fn(
                q_proj_states,
                q_conv_weights,
                self.q_conv1d.bias,
                activation="silu",
                conv_states=conv_state_q,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
            k = causal_conv1d_fn(
                k_proj_states,
                k_conv_weights,
                self.k_conv1d.bias,
                activation="silu",
                conv_states=conv_state_k,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
            v = causal_conv1d_fn(
                v_proj_states,
                v_conv_weights,
                self.v_conv1d.bias,
                activation="silu",
                conv_states=conv_state_v,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
        else:
            decode_conv_indices = non_spec_state_indices_tensor[
                : attn_metadata.num_decodes
            ]
            q = causal_conv1d_update(
                q_proj_states,
                conv_state_q,
                q_conv_weights,
                self.q_conv1d.bias,
                activation="silu",
                conv_state_indices=decode_conv_indices,
                validate_data=True,
            )
            k = causal_conv1d_update(
                k_proj_states,
                conv_state_k,
                k_conv_weights,
                self.k_conv1d.bias,
                activation="silu",
                conv_state_indices=decode_conv_indices,
                validate_data=True,
            )
            v = causal_conv1d_update(
                v_proj_states,
                conv_state_v,
                v_conv_weights,
                self.v_conv1d.bias,
                activation="silu",
                conv_state_indices=decode_conv_indices,
                validate_data=True,
            )

        q, k, v = map(
            lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
        )

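        # Prefill uses the chunked KDA kernel and writes the final recurrent
        # state back into the cache; decode uses the fused recurrent kernel,
        # which reads and updates the cached state through `ssm_state_indices`.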
        if attn_metadata.num_prefills > 0:
            zero_idx = non_spec_state_indices_tensor[~has_initial_state]
            recurrent_state[zero_idx] = 0
            initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = chunk_kda(
                q=q,
                k=k,
                v=v,
                g=g1,
                beta=beta,
                initial_state=initial_state,
                output_final_state=True,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=non_spec_query_start_loc,
            )
            # Store the final recurrent state back into the cache.
            recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
        else:
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = fused_recurrent_kda(
                q=q,
                k=k,
                v=v,
                g=g1,
                beta=beta,
                initial_state=recurrent_state,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=non_spec_query_start_loc,
                ssm_state_indices=non_spec_state_indices_tensor,
            )
        assert core_attn_out_non_spec.shape == core_attn_out.shape
        core_attn_out[:] = core_attn_out_non_spec