[V1] v1 engine + full CUDA graph support for PLaMo2 (#23998)

Signed-off-by: Hemmi Shinichi <shemmi@preferred.jp>
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
Co-authored-by: Hemmi Shinichi <shemmi@preferred.jp>
Co-authored-by: Thomas Parnell <tom.parnell@gmail.com>
nopperl 2025-09-04 00:24:02 +09:00 committed by GitHub
parent 6d80ae83e1
commit fa4311d85f
6 changed files with 349 additions and 125 deletions


@ -395,7 +395,7 @@ th {
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-instruct`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | ✅︎ |
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |


@ -110,7 +110,7 @@ Models using selective state-space mechanisms instead of standard transformer at
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`, `FalconMambaForCausalLM`) are supported.
Hybrid models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`).
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`, `Plamo2ForCausalLM`).
Hybrid models with mechanisms different from Mamba are also supported (e.g., `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`, `Lfm2ForCausalLM`).
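
As a quick illustration of the newly enabled path, the hedged sketch below runs offline generation with `pfnet/plamo-2-1b` on the v1 engine. It assumes a vLLM build from around this commit; exact flags and defaults can differ between releases.

```python
# Minimal sketch (not part of this commit): PLaMo2 offline generation on the v1 engine.
import os

os.environ.setdefault("VLLM_USE_V1", "1")  # select the v1 engine (default in recent builds)

from vllm import LLM, SamplingParams

llm = LLM(
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,       # PLaMo2 ships custom config/tokenizer code on the Hub
    enable_prefix_caching=False,  # PLaMo2 does not support prefix caching (asserted in the model code)
)
outputs = llm.generate(
    ["PLaMo2 is a hybrid Mamba-attention model that"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```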


@ -25,8 +25,7 @@ SSM_MODELS = [
HYBRID_MODELS = [
"ai21labs/Jamba-tiny-dev",
# skipping until vLLM implementation issues are resolved
# "pfnet/plamo-2-1b",
"pfnet/plamo-2-1b",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview",
@ -37,6 +36,7 @@ HYBRID_MODELS = [
V1_SUPPORTED_MODELS = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
"pfnet/plamo-2-1b",
"yujiepan/mamba2-codestral-v0.1-tiny-random",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
@ -47,6 +47,7 @@ V1_SUPPORTED_MODELS = [
FULL_CUDA_GRAPH_MODELS = [
"ai21labs/Jamba-tiny-dev",
"pfnet/plamo-2-1b",
"Zyphra/Zamba2-1.2B-instruct",
]


@ -287,8 +287,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
max_transformers_version="4.53",
transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
max_transformers_version="4.53",


@ -340,6 +340,7 @@ class CompilationConfig:
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
]
def compute_hash(self) -> str:
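
The `vllm.plamo2_mamba_mixer` entry added above joins the ops at which the piecewise `torch.compile` graph is split, so the mixer's Mamba kernels run outside the compiled partitions; with the mixer registered this way, full CUDA graph capture can also be requested through the compilation config. A hedged sketch, assuming the `full_cuda_graph` field name used around this release (the compilation-config schema changes between vLLM versions):

```python
# Hypothetical usage sketch (not part of this commit): ask for full CUDA graph
# capture when running PLaMo2 on the v1 engine. Field names may vary by release.
from vllm import LLM

llm = LLM(
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,
    compilation_config={"full_cuda_graph": True},
)
```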


@ -3,19 +3,24 @@
"""Inference-only PLaMo2 model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers import PretrainedConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -23,8 +28,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@ -39,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsPP, SupportsV0Only)
SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import (
@ -47,8 +55,10 @@ from vllm.model_executor.models.utils import (
make_layers, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from vllm.utils import LayerBlockType, direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Only used for type hinting.
@ -73,20 +83,6 @@ class Plamo2Config(PretrainedConfig): # type: ignore
vocab_size: int
class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore
def _init_weights(self, module: torch.nn.Module) -> None:
std = 0.02
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def is_mamba(config: Plamo2Config, i: int) -> bool:
assert config.mamba_step > 1
@ -99,7 +95,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
# Adapted from:
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# transformers.models.mamba.modeling_mamba.MambaMixer
class Plamo2MambaMixer(nn.Module):
@CustomOp.register(name="plamo2_mamba_mixer")
class Plamo2MambaMixer(MambaBase, CustomOp):
def __init__(self,
vllm_config: VllmConfig,
@ -108,6 +105,8 @@ class Plamo2MambaMixer(nn.Module):
**kwargs) -> None:
super().__init__()
self.config = vllm_config.model_config.hf_config
self.cache_config = vllm_config.cache_config
self.model_config = vllm_config.model_config
self.quant_config = vllm_config.quant_config
self.hidden_size = self.config.hidden_size
self.ssm_state_size = self.config.mamba_d_state
@ -115,8 +114,6 @@ class Plamo2MambaMixer(nn.Module):
self.intermediate_size = (self.config.mamba_num_heads *
self.config.hidden_size_per_head)
self.tp_size = get_tensor_model_parallel_world_size()
self.intermediate_size_per_tp_worker = \
self.intermediate_size // self.tp_size
self.head_dim = self.config.hidden_size_per_head
self.num_heads = self.config.mamba_num_heads
self.time_step_rank = max(64, self.hidden_size // 16)
@ -197,6 +194,22 @@ class Plamo2MambaMixer(nn.Module):
self.C_norm = RMSNorm(self.ssm_state_size,
eps=self.config.rms_norm_eps)
self.chunk_size = self.config.mamba_chunk_size
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert self.chunk_size != -1, "chunk_size must be set for v1"
self.prefix = prefix
def _project_ssm_parameters(self, hidden_states):
ssm_parameters = self.bcdt_proj(hidden_states)
B, C, time_step = torch.split(
@ -212,25 +225,76 @@ class Plamo2MambaMixer(nn.Module):
dt = self.dt_proj(time_step)
return B, C, dt
def forward(
def forward_native(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
) -> torch.Tensor:
):
pass
def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
mamba2_metadata)
else:
torch.ops.vllm.plamo2_mamba_mixer(
hidden_states,
output,
self.prefix,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
else:
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
has_initial_states_p = mamba2_metadata.has_initial_states
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx
chunk_indices_p = mamba2_metadata.chunk_indices
chunk_offsets_p = mamba2_metadata.chunk_offsets
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)
@ -240,23 +304,59 @@ class Plamo2MambaMixer(nn.Module):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
hidden_states = (hidden_states.transpose(0, 1).clone().transpose(
0, 1)).contiguous()
output[:] = self.out_proj(hidden_states)
return
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_p, hidden_states_d = torch.split(
hidden_states,
[num_prefill_tokens, num_decodes],
dim=0,
)
gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes],
dim=0)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
mamba_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)
if envs.VLLM_USE_V1:
hidden_states_d, hidden_states_p = torch.split(
hidden_states[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
gate_d, gate_p = torch.split(gate[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
else:
hidden_states_p, hidden_states_d = torch.split(
hidden_states,
[num_prefill_tokens, num_decodes],
dim=0,
)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decodes],
dim=0)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
@ -268,25 +368,38 @@ class Plamo2MambaMixer(nn.Module):
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
if envs.VLLM_USE_V1:
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out,
[num_decodes, num_prefill_tokens],
dim=0,
)
else:
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
# pointed to by "state_indices_tensor"
x = hidden_states_p.transpose(
0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(x, query_start_loc_p,
mamba2_metadata)
hidden_states_p = causal_conv1d_fn(
hidden_states_p.transpose(0, 1),
x,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=mamba2_metadata,
query_start_loc=query_start_loc_p)
hidden_states_p = hidden_states_p.transpose(0, 1)
hidden_states_p = hidden_states_p[:num_prefill_tokens]
@ -299,12 +412,16 @@ class Plamo2MambaMixer(nn.Module):
# 3. State Space Model sequence transformation
initial_states = None
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
if has_initial_states_p is not None and prep_initial_states:
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
if envs.VLLM_USE_V1:
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)
else:
initial_states = torch.where(
has_initial_states_p[:num_prefills, None, None, None],
ssm_state[state_indices_tensor_p], 0)
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size,
@ -313,15 +430,15 @@ class Plamo2MambaMixer(nn.Module):
self.A,
B.view(1, num_prefill_tokens, 1, -1),
C.view(1, num_prefill_tokens, 1, -1),
chunk_size=mamba2_metadata.chunk_size,
chunk_size=chunk_size,
D=self.D,
z=gate_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size, self.head_dim),
dt_bias=self.dt_bias,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
seq_idx=seq_idx_p,
chunk_indices=chunk_indices_p,
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
@ -329,18 +446,19 @@ class Plamo2MambaMixer(nn.Module):
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype,
)
# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
ssm_state[state_indices_tensor_p] = varlen_state
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
hidden_states_d = causal_conv1d_update(
hidden_states_d,
mamba_cache_params.conv_state,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
@ -363,8 +481,10 @@ class Plamo2MambaMixer(nn.Module):
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using state_indices_tensor_d
# NOTE: final output is an in-place update of out tensor
selective_state_update(
mamba_cache_params.ssm_state,
ssm_state,
hidden_states_d,
dt,
A,
@ -378,11 +498,68 @@ class Plamo2MambaMixer(nn.Module):
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)
assert self.num_heads % self.tp_size == 0
# 4. Final linear projection
out = self.out_proj(preallocated_ssm_out)
return out
output[:num_actual_tokens] = self.out_proj(preallocated_ssm_out)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.mamba2_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
self.cache_config.mamba_ssm_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=self.intermediate_size,
tp_world_size=get_tensor_model_parallel_world_size(),
n_groups=0,
num_heads=self.num_heads,
head_dim=self.head_dim,
state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel_size,
)
@property
def mamba_type(self) -> str:
return "mamba2"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionBackend)
return Mamba2AttentionBackend
def plamo2_mamba_mixer(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mamba_cache_params=None,
mamba2_metadata=None)
def plamo2_mamba_mixer_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="plamo2_mamba_mixer",
op_func=plamo2_mamba_mixer,
mutates_args=["output"],
fake_impl=plamo2_mamba_mixer_fake,
dispatch_key=current_platform.dispatch_key,
)
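
For readers unfamiliar with the registration pattern used above: exposing the mixer as a registered custom op with `mutates_args=["output"]` and a fake implementation lets `torch.compile` and CUDA-graph capture treat the layer as one opaque in-place op, while the real implementation recovers the live layer from the forward context by name. Below is a generic, self-contained sketch of the underlying PyTorch mechanism using plain `torch.library`; the `demo::inplace_double` op and its logic are made up purely for illustration.

```python
# Generic illustration only: an op that writes its result into a caller-provided
# "output" tensor, plus a fake impl so tracing can proceed without the real kernel.
import torch


@torch.library.custom_op("demo::inplace_double", mutates_args=["output"])
def inplace_double(x: torch.Tensor, output: torch.Tensor) -> None:
    output.copy_(2 * x)


@inplace_double.register_fake
def _(x: torch.Tensor, output: torch.Tensor) -> None:
    # Nothing to compute: the op returns None and only mutates `output`.
    return None


x = torch.ones(4)
out = torch.empty_like(x)
torch.ops.demo.inplace_double(x, out)  # dispatches to the registered implementation
print(out)  # tensor([2., 2., 2., 2.])
```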
class DenseMLP(nn.Module):
@ -418,7 +595,6 @@ class DenseMLP(nn.Module):
return self.down_proj(h)
@support_torch_compile
class Plamo2AttentionMixer(nn.Module):
def __init__(self,
@ -575,12 +751,24 @@ class Plamo2DecoderLayer(nn.Module):
hidden_states, residual = self.pre_mixer_norm(
hidden_states, residual)
if self.is_mamba:
# Plamo2MambaMixer writes output to this tensor
output = torch.empty_like(hidden_states)
mixer_kwargs = {
"output": output,
"mamba_cache_params": mamba_cache_params,
"mamba2_metadata": mamba2_metadata,
}
else:
mixer_kwargs = {
"positions": positions,
}
hidden_states = self.mixer(
positions=positions,
hidden_states=hidden_states,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
**mixer_kwargs,
)
if self.is_mamba:
hidden_states = output
hidden_states = self.post_mixer_norm(hidden_states)
# Fully Connected
hidden_states, residual = self.pre_mlp_norm(hidden_states, residual)
@ -591,7 +779,7 @@ class Plamo2DecoderLayer(nn.Module):
class Plamo2Decoder(torch.nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
@ -617,7 +805,7 @@ class Plamo2Decoder(torch.nn.Module):
mamba_cache_index = 0
for layer in islice(self.layers, self.start_layer, self.end_layer):
layer_mamba_cache_params = None
if layer.is_mamba:
if layer.is_mamba and mamba_cache_params is not None:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
mamba_cache_index)
mamba_cache_index += 1
@ -632,10 +820,11 @@ class Plamo2Decoder(torch.nn.Module):
return hidden_states, residual
class Plamo2Model(Plamo2PreTrainedModel):
@support_torch_compile
class Plamo2Model(torch.nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config.model_config.hf_config)
super().__init__()
config = vllm_config.model_config.hf_config
@ -653,9 +842,9 @@ class Plamo2Model(Plamo2PreTrainedModel):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers")
self.layers = Plamo2Decoder(vllm_config=vllm_config,
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_init()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@ -679,11 +868,16 @@ class Plamo2Model(Plamo2PreTrainedModel):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
attn_metadata: AttentionMetadata = get_forward_context(
).attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
hidden_states, residual = self.layers(
positions=positions,
@ -701,8 +895,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
return hidden_states
class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
IsHybrid, SupportsV0Only):
class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -712,12 +905,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
scheduler_config = vllm_config.scheduler_config
assert not vllm_config.cache_config.enable_prefix_caching, \
"PLaMo2 currently does not support prefix caching"
super().__init__(config)
self.config = config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@ -751,8 +942,6 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@ -763,19 +952,27 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
self.mamba_cache = MambaCacheManager(
self.vllm_config,
num_mamba_layers,
*self._get_mamba_cache_shape(),
self.lm_head.weight.dtype,
self.lm_head.weight.dtype,
)
mamba_state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
num_mamba_layers,
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
@ -788,21 +985,48 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = (self.config.mamba_num_heads *
self.config.hidden_size_per_head)
conv_state_shape = (
hidden_size // world_size,
self.config.mamba_d_conv - 1,
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba2_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
temporal_state_shape = (
divide(self.config.mamba_num_heads, world_size),
self.config.hidden_size_per_head,
self.config.mamba_d_state,
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size =\
hf_config.mamba_num_heads * hf_config.hidden_size_per_head
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=0,
num_heads=hf_config.mamba_num_heads,
head_dim=hf_config.hidden_size_per_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
return conv_state_shape, temporal_state_shape
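
For intuition, here is a worked sketch of the two cache shapes, following the formulas of the removed `_get_mamba_cache_shape` above (which the `MambaStateShapeCalculator` call reproduces); the hyperparameter values are placeholders, not the real `pfnet/plamo-2-1b` configuration.

```python
# Placeholder hyperparameters, purely to illustrate the shape arithmetic.
mamba_num_heads = 32        # placeholder
hidden_size_per_head = 128  # placeholder
mamba_d_state = 64          # placeholder
mamba_d_conv = 4            # placeholder
tp_world_size = 1

intermediate_size = mamba_num_heads * hidden_size_per_head    # 4096
conv_state_shape = (intermediate_size // tp_world_size,       # (4096, 3)
                    mamba_d_conv - 1)
temporal_state_shape = (mamba_num_heads // tp_world_size,     # (32, 128, 64)
                        hidden_size_per_head,
                        mamba_d_state)
# Note: the v1 path stores the conv state transposed and flips it back in
# forward_cuda (see the transpose next to the kv_cache lookup above).
print(conv_state_shape, temporal_state_shape)
```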
def compute_logits(
self,