mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 13:45:45 +08:00
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: isotr0py <2037008807@qq.com>
265 lines
8.5 KiB
Python
265 lines
8.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.attention.backends.abstract import AttentionBackend
|
|
|
|
import torch
|
|
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
from vllm.forward_context import ForwardContext, get_forward_context
|
|
from vllm.model_executor.custom_op import CustomOp
|
|
from vllm.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
|
MambaStateDtypeCalculator,
|
|
MambaStateShapeCalculator,
|
|
)
|
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
|
causal_conv1d_fn,
|
|
causal_conv1d_update,
|
|
)
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata
|
|
|
|
|
|
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
    """Gated short causal convolution layer (mamba type ``"short_conv"``).

    ``in_proj`` maps the input to three equal-width chunks ``B``, ``C`` and
    ``x``. The product ``B * x`` is passed through a causal 1-D convolution
    whose kernel length is ``config.conv_L_cache``. The convolution output is
    gated by ``C`` and projected back with ``out_proj``.

    The convolution state is read from ``self.kv_cache`` and addressed per
    request through the indices carried by ``ShortConvAttentionMetadata``.
    Its shape and dtype are reported via ``get_state_shape`` /
    ``get_state_dtype``.
    """

    def __init__(
        self,
        config,
        dim: int,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections and register the layer by ``prefix``.

        Args:
            config: Model config; ``conv_L_cache`` (kernel length) and
                ``conv_bias`` (bias flag for all projections) are read.
            dim: Hidden size of the layer's input and output.
            layer_idx: Index of this layer in the model.
            model_config: Used by ``get_state_dtype``.
            cache_config: Used by ``get_state_dtype``.
            prefix: Unique layer name. It is also the key used to register
                the layer in ``static_forward_context`` and to look up its
                attention metadata.

        Raises:
            ValueError: If another layer already registered the same
                ``prefix``.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.conv_dim = dim
        self.L_cache = config.conv_L_cache
        self.bias = config.conv_bias

        self.conv = ColumnParallelLinear(
            input_size=self.L_cache,
            output_size=dim,
            bias=self.bias,
            prefix=f"{prefix}.conv1d",
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv.weight.data = self.conv.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(
            input_size=dim,
            output_sizes=[dim] * 3,
            bias=self.bias,
            prefix=f"{prefix}.in_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            bias=self.bias,
            prefix=f"{prefix}.out_proj",
        )

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # Placeholder. forward_cuda indexes this as
        # kv_cache[virtual_engine][0], so it presumably gets replaced with
        # the real state tensors before the first non-profile forward.
        # TODO(review): confirm where it is bound.
        self.kv_cache = (torch.tensor([]),)

        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """No native implementation: this is a no-op that returns ``None``."""
        return

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """Invoke the layer through the ``torch.ops.vllm.short_conv`` op.

        The op is keyed by ``self.prefix`` and fills ``output``.
        """
        torch.ops.vllm.short_conv(
            hidden_states,
            output,
            self.prefix,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """Run the gated short convolution and write the result to ``output``.

        Args:
            hidden_states: Input activations. Dim 0 is the token dimension,
                laid out with decode tokens first, then prefill tokens.
            output: Destination. Rows ``[:num_actual_tokens]`` are
                overwritten (on the profile run, the leading rows matching
                the projected result).

        Returns:
            The projected activations on the profile run (no attention
            metadata), otherwise ``None``.
        """
        forward_context = get_forward_context()
        # ShortConvAttentionMetadata contains metadata necessary for the
        # short_conv triton kernels to operate in continuous batching and in
        # chunked prefill modes; they are computed at top-level model forward
        # since they stay the same and reused for all mamba layers in the same
        # iteration.
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, ShortConvAttentionMetadata)
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            # Swap the last two dims of the stored state; presumably the
            # layout the causal_conv1d kernels expect (TODO confirm).
            conv_state = self_kv_cache[0].transpose(-1, -2)
            state_indices_tensor = attn_metadata.state_indices_tensor
            has_initial_states_p = attn_metadata.has_initial_states_p

        BCx, _ = self.in_proj(hidden_states)

        B, C, x = BCx.chunk(3, dim=-1)

        # Drop the singleton dim added by unsqueeze(1) in __init__ to get a
        # 2-D (channels, L_cache) weight for the conv kernels.
        conv_weights = self.conv.weight.view(
            self.conv.weight.size(0), self.conv.weight.size(2)
        )

        if attn_metadata is None:
            # V1 profile run: no metadata/state, so the convolution is
            # skipped and only the projections and gating are exercised.
            Bx = (B * x).contiguous()
            hidden_states = C * Bx
            contextualized_states, _ = self.out_proj(hidden_states)
            # The op is registered as mutating `output`, and the
            # `short_conv` wrapper discards this method's return value.
            # Fill `output` here too so callers never read stale or
            # uninitialized memory on this path.
            output[: contextualized_states.shape[0]] = contextualized_states
            return contextualized_states

        num_prefills = attn_metadata.num_prefills  # request count
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        has_prefill = num_prefills > 0
        has_decode = num_decodes > 0
        num_actual_tokens = num_decodes + num_prefill_tokens

        # NOTE: V1 puts decode before prefill
        # Separate prefill and decode by splitting varlen input
        # Split along token dimension
        B_d, B_p = torch.split(
            B[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        C_d, C_p = torch.split(
            C[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        x_d, x_p = torch.split(
            x[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        # Split along batch dimension
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor,
            [num_decodes, num_prefills],
            dim=0,
        )
        # Rebase the prefill requests' start offsets so they index into the
        # prefill-only slice (decode tokens come first and are removed).
        query_start_loc_p = (
            attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
            if has_prefill
            else None
        )

        conv_output_list = []

        if has_prefill:
            # The varlen prefill kernel takes channels-first input; transpose
            # in and back out.
            Bx_p = (B_p * x_p).transpose(0, 1)
            Bx = causal_conv1d_fn(
                Bx_p,
                conv_weights,
                self.conv.bias,
                activation=None,
                conv_states=conv_state,
                has_initial_state=has_initial_states_p,
                cache_indices=state_indices_tensor_p,
                metadata=attn_metadata,
                query_start_loc=query_start_loc_p,
            ).transpose(0, 1)[:num_prefill_tokens]

            y = C_p * Bx
            conv_output_list.append(y)

        if has_decode:
            Bx_d = (B_d * x_d).contiguous()
            Bx = causal_conv1d_update(
                Bx_d,
                conv_state,
                conv_weights,
                self.conv.bias,
                activation=None,
                conv_state_indices=state_indices_tensor_d,
            )
            y = C_d * Bx
            # Decode tokens precede prefill tokens in the batch layout.
            conv_output_list.insert(0, y)

        # Merge prefill and decode outputs before passing to gated MLP
        hidden_states = torch.vstack(conv_output_list)

        # Final linear projection
        output[:num_actual_tokens], _ = self.out_proj(hidden_states)

    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """Return the conv-state dtype(s) for this layer.

        Derived from the model dtype and ``cache_config.mamba_cache_dtype``;
        both configs must have been supplied at construction.
        """
        assert self.model_config is not None
        assert self.cache_config is not None
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...]]:
        """Return the conv-state shape(s) for this layer on this TP rank."""
        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=get_tensor_model_parallel_world_size(),
            intermediate_size=self.conv_dim,
            conv_kernel=self.L_cache,
        )

    @property
    def mamba_type(self) -> str:
        """Identifier of this mamba layer variant."""
        return "short_conv"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Return the attention backend class used for short-conv layers."""
        # Imported lazily, presumably to avoid an import cycle with the v1
        # backends package.
        from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend

        return ShortConvAttentionBackend
|
|
|
|
|
|
def short_conv(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Eager implementation behind ``torch.ops.vllm.short_conv``.

    Resolves the layer registered under ``layer_name`` in the active forward
    context and delegates to its ``forward_cuda``, which is expected to write
    its result into ``output``. Nothing is returned.
    """
    context: ForwardContext = get_forward_context()
    layer = context.no_compile_layers[layer_name]
    layer.forward_cuda(
        hidden_states=hidden_states,
        output=output,
    )
|
|
|
|
|
|
def short_conv_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``vllm::short_conv`` used during tracing.

    The real op only mutates ``output`` and returns nothing, so there is
    nothing to compute or allocate here.
    """
    del hidden_states, output, layer_name  # intentionally unused
    return None
|
|
|
|
|
|
# Expose `short_conv` as `torch.ops.vllm.short_conv`. `output` is declared as
# mutated because the op writes its result there instead of returning it;
# `short_conv_fake` stands in for the real implementation during tracing.
direct_register_custom_op(
    op_name="short_conv",
    op_func=short_conv,
    fake_impl=short_conv_fake,
    mutates_args=["output"],
)
|