[v1] - Mamba1 Attention Metadata (#21249)

Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>

parent 31f09c615f
commit 46a13949d5
@@ -45,6 +45,9 @@ struct SSMParamsBase {
     index_t out_d_stride;
     index_t out_z_batch_stride;
     index_t out_z_d_stride;
+    index_t ssm_states_batch_stride;
+    index_t ssm_states_dim_stride;
+    index_t ssm_states_dstate_stride;

     // Common data pointers.
     void *__restrict__ A_ptr;
@@ -132,8 +132,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
-   input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate;
+   input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
+       cache_index * params.ssm_states_batch_stride +
+       dim_id * kNRows * params.ssm_states_dim_stride;

    float D_val[kNRows] = {0};
    if (params.D_ptr != nullptr) {
#pragma unroll
@@ -248,7 +250,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
        }
        // Initialize running total

-       scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0);
+       scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);

        SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
        typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -259,7 +261,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
        if (threadIdx.x == 0) {
            smem_running_prefix[state_idx] = prefix_op.running_prefix;
            if (chunk == n_chunks - 1) {
-               ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
+               ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
            }
        }
#pragma unroll
@@ -481,6 +483,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
    params.out_batch_stride = out.stride(1);
    params.out_d_stride = out.stride(0);
+
+   params.ssm_states_batch_stride = ssm_states.stride(0);
+   params.ssm_states_dim_stride = ssm_states.stride(1);
+   params.ssm_states_dstate_stride = ssm_states.stride(2);

  }
  else{
    if (!is_variable_B) {
@@ -509,6 +515,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
    }
    params.out_batch_stride = out.stride(0);
    params.out_d_stride = out.stride(1);
+
+   params.ssm_states_batch_stride = ssm_states.stride(0);
+   params.ssm_states_dim_stride = ssm_states.stride(1);
+   params.ssm_states_dstate_stride = ssm_states.stride(2);
  }
}
@@ -370,9 +370,9 @@ th {
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
-| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | |
+| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | |
+| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
| **Decoder-only Models** | <nobr>🚀 Optimized</nobr> |
| **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> |
| **Embedding Models** | <nobr>🟢 Functional</nobr> |
-| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟡 (Mamba-1)</nobr> |
+| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> |
| **Multimodal Models** | <nobr>🟢 Functional</nobr> |

vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol.
@@ -104,13 +104,11 @@ to enable simultaneous generation and embedding using the same engine instance i

#### Mamba Models

-Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
-Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
-(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
-disabling prefix caching in V1.
+Models using selective state-space mechanisms instead of standard transformer attention are supported.
+Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.

-Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
-`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
+Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
+`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that
these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.

#### Encoder-Decoder Models
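In practice, the constraints documented above translate into something like the following sketch (hypothetical usage, not part of this commit; it assumes vLLM's public offline-inference API and uses one of the Mamba-1 checkpoints listed in the supported-models table):

```python
# Hypothetical usage sketch for a Mamba-1 model on the V1 engine.
import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine

from vllm import LLM, SamplingParams

# Per the docs change above: Mamba-1 models currently need prefix caching
# disabled and eager mode enabled in V1.
llm = LLM(
    model="state-spaces/mamba-130m-hf",
    enable_prefix_caching=False,
    enforce_eager=True,
)

outputs = llm.generate(["The state-space model computes"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```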
@@ -53,6 +53,8 @@ HF_UNSUPPORTED_MODELS = [
]

V1_SUPPORTED_MODELS = [
+    "state-spaces/mamba-130m-hf",
+    "ai21labs/Jamba-tiny-dev",
    "mistralai/Mamba-Codestral-7B-v0.1",
    "ibm-ai-platform/Bamba-9B-v1",
    "Zyphra/Zamba2-1.2B-instruct",
@@ -12,7 +12,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1 = [
    "openai/whisper-large-v3",  # transcription
    "facebook/bart-large-cnn",  # encoder decoder
-    "state-spaces/mamba-130m-hf",  # mamba1
]

MODEL = "meta-llama/Llama-3.2-1B-Instruct"
@@ -1,30 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch
from torch import nn
from torch.nn.parameter import Parameter

-from vllm.attention.backends.abstract import AttentionMetadata
+from vllm import envs
+from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
+from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer")
-class MambaMixer(CustomOp):
+class MambaMixer(MambaBase, CustomOp):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
@@ -47,13 +54,16 @@ class MambaMixer(CustomOp):
                 rms_norm_has_weight: bool = True,
                 rms_norm_eps: float = 1e-5,
                 activation="silu",
-                is_lora_enabled: bool = False):
+                is_lora_enabled: bool = False,
+                prefix: str = ""):
        super().__init__()
        self.time_step_rank = time_step_rank
        self.ssm_state_size = ssm_state_size
        self.use_rms_norm = use_rms_norm
        self.activation = activation
        self.is_lora_enabled = is_lora_enabled
+        self.conv_kernel_size = conv_kernel_size
+        self.intermediate_size = intermediate_size

        self.conv1d = ColumnParallelLinear(
            input_size=conv_kernel_size,
@@ -131,14 +141,62 @@ class MambaMixer(CustomOp):
            has_weight=rms_norm_has_weight,
        ) if use_rms_norm else None

-    def forward_native(self, hidden_states: torch.Tensor,
-                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
+        if envs.VLLM_USE_V1:
+            compilation_config = get_current_vllm_config().compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+            # The outer list is for v0 PP virtual engine. Though this code path
+            # only runs for v1, we have to do this to unify with the interface
+            # of Attention + v0 PP.
+            # The inner tuple is (conv_state, ssm_state)
+            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+
+        self.prefix = prefix
+
+    def forward(self,
+                hidden_states: torch.Tensor,
+                mamba_cache_params: Optional[MambaCacheParams] = None):
+        if not envs.VLLM_USE_V1:
+            return CustomOp.forward(self, hidden_states, mamba_cache_params)
+        else:
+            return self.forward_cuda(hidden_states, mamba_cache_params)
+
+    def forward_native(self,
+                       hidden_states: torch.Tensor,
+                       mamba_cache_params: Optional[MambaCacheParams] = None):
        pass

-    def forward_cuda(self, hidden_states: torch.Tensor,
-                     mamba_cache_params: MambaCacheParams):
+    def forward_cuda(self,
+                     hidden_states: torch.Tensor,
+                     mamba_cache_params: Optional[MambaCacheParams] = None):

-        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
+        forward_context: ForwardContext = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                assert isinstance(attn_metadata, dict)
+                attn_metadata = attn_metadata[self.prefix]
+                mamba1_metadata = attn_metadata
+                assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
+                query_start_loc = mamba1_metadata.query_start_loc
+                state_indices_tensor = mamba1_metadata.state_indices_tensor
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                conv_state = self_kv_cache[0].transpose(-1, -2)
+                ssm_state = self_kv_cache[1]
+                has_initial_state = mamba1_metadata.has_initial_states
+                context_lens_tensor = mamba1_metadata.context_lens_tensor
+        else:
+            assert mamba_cache_params is not None
+            conv_state = mamba_cache_params.conv_state
+            ssm_state = mamba_cache_params.ssm_state
+            state_indices_tensor = mamba_cache_params.state_indices_tensor
+            query_start_loc = attn_metadata.query_start_loc
+            context_lens_tensor = attn_metadata.context_lens_tensor
+
+            if context_lens_tensor is not None:
+                has_initial_state = context_lens_tensor > 0

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -148,8 +206,12 @@ class MambaMixer(CustomOp):
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

-        if attn_metadata.query_start_loc is not None \
-            and attn_metadata.context_lens_tensor is not None:
+        if envs.VLLM_USE_V1 and attn_metadata is None:
+            # V1 profile run
+            hidden_states = hidden_states.contiguous()
+            return self.out_proj(hidden_states.transpose(-2, -1))[0]
+
+        if query_start_loc is not None and context_lens_tensor is not None:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
@@ -161,18 +223,18 @@ class MambaMixer(CustomOp):
                conv_weights,
                bias=self.conv1d.bias,
                activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=attn_metadata.context_lens_tensor > 0,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc)
+                conv_states=conv_state,
+                has_initial_state=has_initial_state,
+                cache_indices=state_indices_tensor,
+                query_start_loc=query_start_loc)
        else:
            hidden_states = causal_conv1d_update(
                hidden_states.transpose(0, 1),
-                mamba_cache_params.conv_state,
+                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor)
+                conv_state_indices=state_indices_tensor)
            hidden_states = hidden_states.transpose(0, 1)

        # 3. State Space Model sequence transformation
@@ -203,11 +265,10 @@ class MambaMixer(CustomOp):
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)

-        if attn_metadata.query_start_loc is not None \
-            and attn_metadata.context_lens_tensor is not None:
+        if query_start_loc is not None and context_lens_tensor is not None:
            scan_outputs = selective_scan_fn(
                hidden_states,
-                mamba_cache_params.ssm_state,
+                ssm_state,
                discrete_time_step,
                self.A,
                B.transpose(-2, -1),
@@ -216,24 +277,23 @@ class MambaMixer(CustomOp):
                gate,
                time_proj_bias,
                delta_softplus=True,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                has_initial_state=attn_metadata.context_lens_tensor > 0,
-                query_start_loc=attn_metadata.query_start_loc)
+                cache_indices=state_indices_tensor,
+                has_initial_state=has_initial_state,
+                query_start_loc=query_start_loc)
        else:
            scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
-            selective_state_update(
-                mamba_cache_params.ssm_state,
-                hidden_states.transpose(0, 1),
-                discrete_time_step.transpose(0, 1),
-                self.A,
-                B,
-                C,
-                self.D,
-                gate.transpose(0, 1),
-                time_proj_bias,
-                dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor,
-                out=scan_outputs)
+            selective_state_update(ssm_state,
+                                   hidden_states.transpose(0, 1),
+                                   discrete_time_step.transpose(0, 1),
+                                   self.A,
+                                   B,
+                                   C,
+                                   self.D,
+                                   gate.transpose(0, 1),
+                                   time_proj_bias,
+                                   dt_softplus=True,
+                                   state_batch_indices=state_indices_tensor,
+                                   out=scan_outputs)
            scan_outputs = scan_outputs.transpose(0, 1)

        # 4. Final linear projection
@@ -245,3 +305,15 @@ class MambaMixer(CustomOp):
        contextualized_states = self.out_proj(
            scan_outputs.transpose(-2, -1))[0]
        return contextualized_states
+
+    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+        return MambaStateShapeCalculator.mamba1_state_shape(
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            intermediate_size=self.intermediate_size,
+            state_size=self.ssm_state_size,
+            conv_kernel=self.conv_kernel_size,
+        )
+
+    @property
+    def mamba_type(self) -> str:
+        return "mamba1"
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                              update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
-    extra_groups_for_head_shards, get_mamba_state_shape)
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -278,8 +278,9 @@ class MambaMixer2(MambaBase, CustomOp):
        # - for TP we shard conv_dim by sharding on n_groups,
        # - but if n_groups cannot divide tp_size, we need to
        #   extend some extra groups
-        self.n_groups = n_groups + extra_groups_for_head_shards(
+        groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
            n_groups, self.tp_size)
+        self.n_groups = n_groups + groups

        self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
        self.conv1d = ColumnParallelLinear(
@@ -732,7 +733,7 @@ class MambaMixer2(MambaBase, CustomOp):
        output[:num_actual_tokens], _ = self.out_proj(hidden_states)

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=self.intermediate_size,
            tp_world_size=get_tensor_model_parallel_world_size(),
            n_groups=self.n_groups,
@@ -3,53 +3,70 @@
from vllm.distributed import divide


-def extra_groups_for_head_shards(ngroups: int, tp_size: int):
-    """Compute the increase in group numbers to account for
-    replication in order to accompany the head shards."""
+class MambaStateShapeCalculator:

-    # in the case ngoups % tp_size == 0, this will be zero
-    if ngroups % tp_size == 0:
-        return 0
+    @classmethod
+    def mamba1_state_shape(
+        cls,
+        tp_world_size: int,
+        intermediate_size: int,
+        state_size: int,
+        conv_kernel: int,
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int]]:
+        conv_state_shape = (divide(intermediate_size,
+                                   tp_world_size), conv_kernel - 1)

-    # for n_groups == 1, this is exactly tp_size - n_groups
-    return tp_size - ngroups
+        temporal_state_shape = (divide(intermediate_size,
+                                       tp_world_size), state_size)

+        # In V0, the conv_state shape was swapped during allocation in
+        # MambaCacheManager, but in V1 it needs to be determined here at the
+        # calculation level
+        if use_v1:
+            conv_state_shape = conv_state_shape[1], conv_state_shape[0]

-def get_mamba_state_shape(
-    intermediate_size: int,
-    tp_world_size: int,
-    n_groups: int,
-    num_heads: int,
-    head_dim: int,
-    state_size: int,
-    conv_kernel: int,
-    use_v1: bool = True,
-) -> tuple[tuple[int, int], tuple[int, int, int]]:
-    """ Get the shape of mamba state."""
+        return conv_state_shape, temporal_state_shape

-    # if n_groups is not divisible by world_size, need to extend the shards
-    # to ensure all groups needed by a head is sharded along with it
-    n_groups = (n_groups +
-                extra_groups_for_head_shards(n_groups, tp_world_size))
+    @classmethod
+    def mamba2_state_shape(
+        cls,
+        tp_world_size: int,
+        intermediate_size: int,
+        n_groups: int,
+        num_heads: int,
+        head_dim: int,
+        state_size: int,
+        conv_kernel: int,
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        # if n_groups is not divisible by world_size, need to extend the shards
+        # to ensure all groups needed by a head is sharded along with it
+        n_groups = n_groups + cls.extra_groups_for_head_shards(
+            n_groups, tp_world_size)
+        # heads and n_groups are TP-ed
+        conv_dim = intermediate_size + 2 * n_groups * state_size

-    # - heads and n_groups are TP-ed
-    conv_dim = (intermediate_size + 2 * n_groups * state_size)
+        # contiguous along 'dim' axis
-    conv_state_shape = (
-        conv_kernel - 1,
-        divide(conv_dim, tp_world_size),
-    )
-    # contiguous along 'dim' axis
+        conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
+        if not use_v1:
+            conv_state_shape = conv_state_shape[1], conv_state_shape[0]

-    if not use_v1:
-        conv_state_shape = (conv_state_shape[1], conv_state_shape[0])
+        # These are not TP-ed as they depend on A, dt_bias, D
+        # - they are typically small
+        # e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
+        temporal_state_shape = (divide(num_heads,
+                                       tp_world_size), head_dim, state_size)
+        return conv_state_shape, temporal_state_shape

-    # These are not TP-ed as they depend on A, dt_bias, D
-    # - they are typically small
-    # e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
-    temporal_state_shape = (
-        divide(num_heads, tp_world_size),
-        head_dim,
-        state_size,
-    )
+    @classmethod
+    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
+        """Compute the increase in group numbers to account for
+        replication in order to accompany the head shards."""

-    return conv_state_shape, temporal_state_shape
+        # in the case ngoups % tp_size == 0, this will be zero
+        if ngroups % tp_size == 0:
+            return 0

+        # for n_groups == 1, this is exactly tp_size - n_groups
+        return tp_size - ngroups
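To make the V1 conv-state convention above concrete, here is a self-contained sketch (illustrative only, not the vLLM code itself) that mirrors what `MambaStateShapeCalculator.mamba1_state_shape` computes; the example sizes assume a mamba-130m-style config (intermediate_size=1536, state_size=16, conv_kernel=4):

```python
# Sketch mirroring the mamba1 shape logic shown in the hunk above.
def mamba1_state_shape(tp_world_size, intermediate_size, state_size,
                       conv_kernel, use_v1=True):
    conv_state_shape = (intermediate_size // tp_world_size, conv_kernel - 1)
    temporal_state_shape = (intermediate_size // tp_world_size, state_size)
    if use_v1:
        # V1 allocates the conv state transposed relative to V0's layout.
        conv_state_shape = (conv_state_shape[1], conv_state_shape[0])
    return conv_state_shape, temporal_state_shape


# Single GPU, mamba-130m-style sizes (assumed values):
print(mamba1_state_shape(1, 1536, 16, 4, use_v1=True))   # ((3, 1536), (1536, 16))
print(mamba1_state_shape(1, 1536, 16, 4, use_v1=False))  # ((1536, 3), (1536, 16))
```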
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -457,7 +458,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size

-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.mamba_n_groups,
@@ -24,7 +24,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -543,7 +544,7 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                             if hf_config.mamba_d_ssm is None else
                             hf_config.mamba_d_ssm)

-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.mamba_n_groups,
@@ -23,7 +23,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -547,7 +548,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size

-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.mamba_n_groups,
@@ -8,6 +8,7 @@ import torch
from torch import nn
from transformers import JambaConfig

+from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size

@@ -19,6 +20,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
                                               PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig

@@ -32,8 +35,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType

-from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsV0Only)
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

@@ -112,7 +114,8 @@ class JambaMambaDecoderLayer(nn.Module):
            use_rms_norm=True,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
-            is_lora_enabled = self.is_lora_enabled
+            is_lora_enabled = self.is_lora_enabled,
+            prefix=f"{prefix}.mixer",
        )

        num_experts = config.layers_num_experts[layer_idx]

@@ -344,7 +347,8 @@ class JambaModel(nn.Module):
            layer_mamba_cache_params = None
            if isinstance(layer, JambaAttentionDecoderLayer):
                kv_cache_index += 1
-            if isinstance(layer, JambaMambaDecoderLayer):
+            if isinstance(layer,
+                          JambaMambaDecoderLayer) and mamba_cache_params:
                current_state_layer = mamba_cache_index
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    current_state_layer)

@@ -442,7 +446,7 @@ class JambaModel(nn.Module):


class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid, SupportsV0Only):
+                       IsHybrid):
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
        ".self_attn.": ".",
        ".A_log": ".A"

@@ -509,14 +513,19 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
-        if self.mamba_cache is None:
-            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
-                self.vllm_config.parallel_config, LayerBlockType.mamba)
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+        # NOTE: mamba_cache_params is not needed for v1
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                num_layers = self.model_config.get_num_layers_by_block_type(
+                    self.vllm_config.parallel_config, LayerBlockType.mamba)
+                state_shape = self.get_mamba_state_shape_from_config(
+                    self.vllm_config)
+                self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                     self.lm_head.weight.dtype,
+                                                     num_layers, *state_shape)

-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                   intermediate_tensors, inputs_embeds)

@@ -529,19 +538,22 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = self.config.hidden_size
-        conv_state_shape = (
-            self.config.mamba_expand * hidden_size // world_size,
-            self.config.mamba_d_conv - 1,
-        )
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[tuple[int, int], tuple[int, int]]:
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        hidden_size = hf_config.hidden_size
+
+        return MambaStateShapeCalculator.mamba1_state_shape(
+            tp_world_size=parallel_config.tensor_parallel_size,
+            intermediate_size=hf_config.mamba_expand * hidden_size,
+            state_size=hf_config.mamba_d_state,
+            conv_kernel=hf_config.mamba_d_conv,
+            use_v1=envs.VLLM_USE_V1,
+        )
-        temporal_state_shape = (
-            self.config.mamba_expand * hidden_size // world_size,
-            self.config.mamba_d_state,
-        )
-        return conv_state_shape, temporal_state_shape

    def compute_logits(
        self,
@@ -8,20 +8,21 @@ import torch
from torch import nn
from transformers import MambaConfig

+from vllm import envs
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                   IsAttentionFree, SupportsPP,
-                                                   SupportsV0Only)
+                                                   IsAttentionFree, SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata

@@ -41,7 +42,8 @@ class MambaDecoderLayer(nn.Module):
                 config: MambaConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
-                 is_lora_enabled: Optional[bool] = False) -> None:
+                 is_lora_enabled: Optional[bool] = False,
+                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.is_falcon_mamba = config.model_type == "falcon_mamba"

@@ -58,7 +60,8 @@ class MambaDecoderLayer(nn.Module):
            rms_norm_has_weight=not self.is_falcon_mamba,
            rms_norm_eps=mixer_rms_eps,
            activation=config.hidden_act,
-            is_lora_enabled=self.is_lora_enabled)
+            is_lora_enabled=self.is_lora_enabled,
+            prefix=f"{prefix}.mixer")

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -107,7 +110,8 @@ class MambaModel(nn.Module):
            lambda prefix: MambaDecoderLayer(config,
                                             cache_config=cache_config,
                                             quant_config=quant_config,
-                                             is_lora_enabled=is_lora_enabled),
+                                             is_lora_enabled=is_lora_enabled,
+                                             prefix=prefix),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,

@@ -123,7 +127,7 @@ class MambaModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
+        mamba_cache_params: Optional[MambaCacheParams] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

@@ -140,12 +144,17 @@ class MambaModel(nn.Module):

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
+
+            layer_cache_params = None
+            if mamba_cache_params is not None:
+                layer_cache_params = mamba_cache_params.at_layer_idx(
+                    i - self.start_layer)
+
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
-                mamba_cache_params=mamba_cache_params.at_layer_idx(
-                    i - self.start_layer))
+                mamba_cache_params=layer_cache_params)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,

@@ -176,8 +185,7 @@ class MambaModel(nn.Module):
        return loaded_params


-class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
-                       SupportsV0Only):
+class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config

@@ -227,20 +235,40 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
-        if self.mamba_cache is None:
-            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
-                self.vllm_config.parallel_config, LayerBlockType.mamba)
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
-
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                num_layers = self.model_config.get_num_layers_by_block_type(
+                    self.vllm_config.parallel_config, LayerBlockType.mamba)
+                state_shape = self.get_mamba_state_shape_from_config(
+                    self.vllm_config)
+                self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                     self.lm_head.weight.dtype,
+                                                     num_layers, *state_shape)

+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
                                      intermediate_tensors, inputs_embeds)

        return hidden_states

+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[tuple[int, int], tuple[int, int]]:
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+
+        return MambaStateShapeCalculator.mamba1_state_shape(
+            tp_world_size=parallel_config.tensor_parallel_size,
+            intermediate_size=hf_config.intermediate_size,
+            state_size=hf_config.state_size,
+            conv_kernel=hf_config.conv_kernel,
+            use_v1=envs.VLLM_USE_V1)
+
    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

@@ -248,19 +276,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        conv_state_shape = (
-            self.config.intermediate_size // world_size,
-            self.config.conv_kernel - 1,
-        )
-        temporal_state_shape = (
-            self.config.intermediate_size // world_size,
-            self.config.state_size,
-        )
-        return conv_state_shape, temporal_state_shape
-
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
@@ -19,7 +19,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -220,7 +221,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.expand * hf_config.hidden_size

-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
@@ -39,7 +39,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)

@@ -482,7 +483,7 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.expand * hf_config.hidden_size

-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
@@ -32,7 +32,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -869,7 +870,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size

-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.mamba_ngroups,
vllm/v1/attention/backends/mamba1_attn.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import ClassVar

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


class Mamba1AttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
        return Mamba1AttentionMetadataBuilder


@dataclass
class Mamba1AttentionMetadata:
    query_start_loc: torch.Tensor
    context_lens_tensor: torch.Tensor
    state_indices_tensor: torch.Tensor
    has_initial_states: torch.Tensor


class Mamba1AttentionMetadataBuilder(
        AttentionMetadataBuilder[Mamba1AttentionMetadata]):

    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        vllm_config: VllmConfig,
        device: torch.device,
        layer_names: list[str],
    ):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec
        self.device = device
        self.vllm_config = vllm_config
        self.layer_names = layer_names

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> Mamba1AttentionMetadata:
        query_start_loc = common_attn_metadata.query_start_loc

        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
        context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
            query_start_loc.device)
        has_initial_states = (context_lens_tensor > 0)

        return Mamba1AttentionMetadata(
            query_start_loc=query_start_loc,
            context_lens_tensor=context_lens_tensor,
            has_initial_states=has_initial_states,
            state_indices_tensor=state_indices_tensor,
        )
@@ -1,10 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import AttentionBackend
+from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend


def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
+    if mamba_type == "mamba1":
+        return Mamba1AttentionBackend
+
    if mamba_type == "mamba2":
        return Mamba2AttentionBackend
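The selector above is what ties `MambaMixer.mamba_type == "mamba1"` to the new backend and its metadata builder; a minimal sketch of the dispatch (illustrative only, assuming `get_mamba_attn_backend` and the Mamba1 classes are importable as shown in the hunks above):

```python
# Illustrative sketch of the dispatch path added above; not part of the commit.
backend = get_mamba_attn_backend("mamba1")   # resolves to Mamba1AttentionBackend
builder_cls = backend.get_builder_cls()      # -> Mamba1AttentionMetadataBuilder

# During a V1 step, the builder's Mamba1AttentionMetadata is stored per layer in
# forward_context.attn_metadata[layer_prefix]; MambaMixer.forward_cuda reads it
# to recover query_start_loc, state_indices_tensor, and has_initial_states.
print(backend.__name__, builder_cls.__name__)
```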