[V1] v1 engine + full CUDA graph support for PLaMo2 (#23998)

Signed-off-by: Hemmi Shinichi <shemmi@preferred.jp>
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
Co-authored-by: Hemmi Shinichi <shemmi@preferred.jp>
Co-authored-by: Thomas Parnell <tom.parnell@gmail.com>
Author: nopperl, 2025-09-04 00:24:02 +09:00, committed by GitHub
parent 6d80ae83e1
commit fa4311d85f
6 changed files with 349 additions and 125 deletions
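
The headline change is that PLaMo2 now runs on the V1 engine and can be captured with full CUDA graphs. A minimal offline-inference sketch of the new path (assumes a CUDA GPU, access to the `pfnet/plamo-2-1b` checkpoint, and a vLLM build containing this commit; the `full_cuda_graph` knob is the `CompilationConfig` field exercised by the FULL_CUDA_GRAPH_MODELS tests below and may be named differently in other versions):

```python
# Hedged sketch, not part of the diff: exercise the V1 + full CUDA graph path.
from vllm import LLM, SamplingParams

llm = LLM(
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,
    # Ask the compiler to capture the whole forward pass in one CUDA graph
    # instead of piecewise capture around attention/mamba ops.
    compilation_config={"full_cuda_graph": True},
)

outputs = llm.generate(
    ["PLaMo2 is a hybrid Mamba/attention model that"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```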

View File

@@ -395,7 +395,7 @@ th {
 | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
 | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
-| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | |
+| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | ✅︎ |
 | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |

View File

@@ -110,7 +110,7 @@ Models using selective state-space mechanisms instead of standard transformer at
 Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`,`FalconMambaForCausalLM`) are supported.
 Hybrid models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
-`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`).
+`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`, `Plamo2ForCausalLM`).
 Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`, `Lfm2ForCausalLM`).

View File

@@ -25,8 +25,7 @@ SSM_MODELS = [
 HYBRID_MODELS = [
     "ai21labs/Jamba-tiny-dev",
-    # skipping until vLLM implementation issues are resolved
-    # "pfnet/plamo-2-1b",
+    "pfnet/plamo-2-1b",
     "Zyphra/Zamba2-1.2B-instruct",
     "hmellor/tiny-random-BambaForCausalLM",
     "ibm-granite/granite-4.0-tiny-preview",
@@ -37,6 +36,7 @@ HYBRID_MODELS = [
 V1_SUPPORTED_MODELS = [
     "state-spaces/mamba-130m-hf",
     "ai21labs/Jamba-tiny-dev",
+    "pfnet/plamo-2-1b",
     "yujiepan/mamba2-codestral-v0.1-tiny-random",
     "Zyphra/Zamba2-1.2B-instruct",
     "hmellor/tiny-random-BambaForCausalLM",
@@ -47,6 +47,7 @@ V1_SUPPORTED_MODELS = [
 FULL_CUDA_GRAPH_MODELS = [
     "ai21labs/Jamba-tiny-dev",
+    "pfnet/plamo-2-1b",
     "Zyphra/Zamba2-1.2B-instruct",
 ]
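
With `pfnet/plamo-2-1b` back in these lists, the hybrid-model suite exercises PLaMo2 on V0, on V1, and with full CUDA graphs. A hedged way to run only the PLaMo2 cases locally (the test-file path is assumed from the list names above and may differ in your checkout):

```python
# Hedged sketch: run the hybrid-model tests for PLaMo2 only.
# The path below is an assumption; point it at test_hybrid.py in your checkout.
import sys

import pytest

sys.exit(
    pytest.main([
        "tests/models/language/generation/test_hybrid.py",
        "-k", "plamo",
        "-v",
    ]))
```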

View File

@@ -287,8 +287,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
-                                         max_transformers_version="4.53",
-                                         transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings",  # noqa: E501
                                          trust_remote_code=True),
     "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
                                        max_transformers_version="4.53",

View File

@@ -340,6 +340,7 @@ class CompilationConfig:
         "vllm.mamba_mixer",
         "vllm.short_conv",
         "vllm.linear_attention",
+        "vllm.plamo2_mamba_mixer",
     ]

     def compute_hash(self) -> str:
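
Adding `vllm.plamo2_mamba_mixer` here marks the new custom op as a graph-split point, so the stateful mixer runs eagerly between the compiled pieces, just like the existing mamba and linear-attention ops. A hedged sketch of setting split points explicitly through the compilation config (normally unnecessary, since the defaults already include these ops on builds with this change; the dict field and op names are assumptions to the extent they differ across versions):

```python
# Hedged sketch: pass an explicit splitting_ops list so the compiler breaks the
# graph at attention and mamba-style ops, leaving them outside piecewise capture.
from vllm import LLM

llm = LLM(
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,
    compilation_config={
        "splitting_ops": [
            "vllm.unified_attention",
            "vllm.unified_attention_with_output",
            "vllm.mamba_mixer2",
            "vllm.plamo2_mamba_mixer",  # the op registered by this commit
        ],
    },
)
```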

View File

@@ -3,19 +3,24 @@
 """Inference-only PLaMo2 model."""
 from collections.abc import Iterable
 from itertools import islice
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 import torch
 from torch import nn
-from transformers import PretrainedConfig, PreTrainedModel
+from transformers import PretrainedConfig

+from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -23,8 +28,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
-    Mamba2Metadata, prepare_mamba2_metadata)
+    Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -39,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
-                                                    SupportsPP, SupportsV0Only)
+                                                    SupportsPP)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                      MambaCacheParams)
 from vllm.model_executor.models.utils import (
@@ -47,8 +55,10 @@ from vllm.model_executor.models.utils import (
     make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import LayerBlockType
+from vllm.utils import LayerBlockType, direct_register_custom_op
+from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata

 # Only used for type hinting.
@@ -73,20 +83,6 @@ class Plamo2Config(PretrainedConfig):  # type: ignore
     vocab_size: int


-class Plamo2PreTrainedModel(PreTrainedModel):  # type: ignore
-
-    def _init_weights(self, module: torch.nn.Module) -> None:
-        std = 0.02
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
-
 def is_mamba(config: Plamo2Config, i: int) -> bool:
     assert config.mamba_step > 1
@@ -99,7 +95,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
 # Adapted from:
 # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
 # transformers.models.mamba.modeling_mamba.MambaMixer
-class Plamo2MambaMixer(nn.Module):
+@CustomOp.register(name="plamo2_mamba_mixer")
+class Plamo2MambaMixer(MambaBase, CustomOp):

     def __init__(self,
                  vllm_config: VllmConfig,
@@ -108,6 +105,8 @@ class Plamo2MambaMixer(nn.Module):
                  **kwargs) -> None:
         super().__init__()
         self.config = vllm_config.model_config.hf_config
+        self.cache_config = vllm_config.cache_config
+        self.model_config = vllm_config.model_config
         self.quant_config = vllm_config.quant_config
         self.hidden_size = self.config.hidden_size
         self.ssm_state_size = self.config.mamba_d_state
@@ -115,8 +114,6 @@ class Plamo2MambaMixer(nn.Module):
         self.intermediate_size = (self.config.mamba_num_heads *
                                   self.config.hidden_size_per_head)
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.intermediate_size_per_tp_worker = \
-            self.intermediate_size // self.tp_size
         self.head_dim = self.config.hidden_size_per_head
         self.num_heads = self.config.mamba_num_heads
         self.time_step_rank = max(64, self.hidden_size // 16)
@@ -197,6 +194,22 @@ class Plamo2MambaMixer(nn.Module):
         self.C_norm = RMSNorm(self.ssm_state_size,
                               eps=self.config.rms_norm_eps)

+        self.chunk_size = self.config.mamba_chunk_size
+
+        if envs.VLLM_USE_V1:
+            compilation_config = get_current_vllm_config().compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+            # The outer list is for v0 PP virtual engine. Though this code path
+            # only runs for v1, we have to do this to unify with the interface
+            # of Attention + v0 PP.
+            # The inner tuple is (conv_state, ssm_state)
+            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+            assert self.chunk_size != -1, "chunk_size must be set for v1"
+
+        self.prefix = prefix
+
     def _project_ssm_parameters(self, hidden_states):
         ssm_parameters = self.bcdt_proj(hidden_states)
         B, C, time_step = torch.split(
@@ -212,25 +225,76 @@ class Plamo2MambaMixer(nn.Module):
         dt = self.dt_proj(time_step)
         return B, C, dt

-    def forward(
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
+        output: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         mamba2_metadata: Mamba2Metadata,
         **kwargs,
-    ) -> torch.Tensor:
+    ):
+        pass
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        **kwargs,
+    ):
+        if not envs.VLLM_USE_V1:
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
+                             mamba2_metadata)
+        else:
+            torch.ops.vllm.plamo2_mamba_mixer(
+                hidden_states,
+                output,
+                self.prefix,
+            )
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        **kwargs,
+    ):
+        forward_context = get_forward_context()
         # mamba2_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
         # stay the same and reused for all mamba layers in the same iteration
-        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-
-        num_prefills = attn_metadata.num_prefills  # request count
-        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
-        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
-        has_prefill = num_prefills > 0
-        has_decode = num_decodes > 0
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                assert isinstance(attn_metadata, dict)
+                attn_metadata = attn_metadata[self.prefix]
+                mamba2_metadata = attn_metadata
+                assert isinstance(attn_metadata, Mamba2AttentionMetadata)
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                # conv_state = (..., dim, width-1) yet contiguous along 'dim'
+                conv_state = self_kv_cache[0].transpose(-1, -2)
+                ssm_state = self_kv_cache[1]
+                state_indices_tensor = attn_metadata.state_indices_tensor
+                has_initial_states_p = attn_metadata.has_initial_states_p
+                prep_initial_states = attn_metadata.prep_initial_states
+                chunk_size = attn_metadata.chunk_size
+                seq_idx_p = attn_metadata.seq_idx_p
+                chunk_indices_p = attn_metadata.chunk_indices_p
+                chunk_offsets_p = attn_metadata.chunk_offsets_p
+        else:
+            conv_state = mamba_cache_params.conv_state
+            ssm_state = mamba_cache_params.ssm_state
+            state_indices_tensor = mamba_cache_params.state_indices_tensor
+            has_initial_states_p = mamba2_metadata.has_initial_states
+            prep_initial_states = mamba2_metadata.prep_initial_states
+            chunk_size = mamba2_metadata.chunk_size
+            seq_idx_p = mamba2_metadata.seq_idx
+            chunk_indices_p = mamba2_metadata.chunk_indices
+            chunk_offsets_p = mamba2_metadata.chunk_offsets

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)
@@ -240,22 +304,58 @@ class Plamo2MambaMixer(nn.Module):
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))

+        if envs.VLLM_USE_V1 and attn_metadata is None:
+            # V1 profile run
+            hidden_states = (hidden_states.transpose(0, 1).clone().transpose(
+                0, 1)).contiguous()
+            output[:] = self.out_proj(hidden_states)
+            return
+
+        num_prefills = attn_metadata.num_prefills  # request count
+        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        has_prefill = num_prefills > 0
+        has_decode = num_decodes > 0
+        num_actual_tokens = num_prefill_tokens + num_decodes
+
+        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        hidden_states_p, hidden_states_d = torch.split(
-            hidden_states,
-            [num_prefill_tokens, num_decodes],
-            dim=0,
-        )
-        gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes],
-                                     dim=0)
-
-        # Split along batch dimension
-        state_indices_tensor_p, state_indices_tensor_d = torch.split(
-            mamba_cache_params.state_indices_tensor,
-            [num_prefills, num_decodes],
-            dim=0,
-        )
-        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
-                             if has_prefill else None)
+        if envs.VLLM_USE_V1:
+            hidden_states_d, hidden_states_p = torch.split(
+                hidden_states[:num_actual_tokens],
+                [num_decodes, num_prefill_tokens],
+                dim=0,
+            )
+            gate_d, gate_p = torch.split(gate[:num_actual_tokens],
+                                         [num_decodes, num_prefill_tokens],
+                                         dim=0)
+            # Split along batch dimension
+            state_indices_tensor_d, state_indices_tensor_p = torch.split(
+                state_indices_tensor,
+                [num_decodes, num_prefills],
+                dim=0,
+            )
+            query_start_loc_p = (
+                attn_metadata.query_start_loc[-num_prefills - 1:] -
+                num_decodes if has_prefill else None)
+        else:
+            hidden_states_p, hidden_states_d = torch.split(
+                hidden_states,
+                [num_prefill_tokens, num_decodes],
+                dim=0,
+            )
+            gate_p, gate_d = torch.split(gate,
+                                         [num_prefill_tokens, num_decodes],
+                                         dim=0)
+            # Split along batch dimension
+            state_indices_tensor_p, state_indices_tensor_d = torch.split(
+                state_indices_tensor,
+                [num_prefills, num_decodes],
+                dim=0,
+            )
+            query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
+                                                               1]
+                                 if has_prefill else None)

         # Preallocate output tensor to avoid memcpy cost for merging prefill
@@ -268,6 +368,13 @@ class Plamo2MambaMixer(nn.Module):
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
-        preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
-            preallocated_ssm_out,
-            [num_prefill_tokens, num_decodes],
+        if envs.VLLM_USE_V1:
+            preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
+                preallocated_ssm_out,
+                [num_decodes, num_prefill_tokens],
+                dim=0,
+            )
+        else:
+            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
+                preallocated_ssm_out,
+                [num_prefill_tokens, num_decodes],
@@ -278,15 +385,21 @@ class Plamo2MambaMixer(nn.Module):
         if has_prefill:
             # 2. Convolution sequence transformation
             # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "mamba_cache_params.state_indices_tensor"
+            #   pointed to by "state_indices_tensor"
+            x = hidden_states_p.transpose(
+                0, 1)  # this is the form that causal-conv see
+            if mamba2_metadata.cu_seqlen is None:
+                mamba2_metadata = update_metadata(x, query_start_loc_p,
+                                                  mamba2_metadata)
             hidden_states_p = causal_conv1d_fn(
-                hidden_states_p.transpose(0, 1),
+                x,
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=mamba2_metadata.has_initial_states,
+                conv_states=conv_state,
+                has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
+                metadata=mamba2_metadata,
                 query_start_loc=query_start_loc_p)
             hidden_states_p = hidden_states_p.transpose(0, 1)
             hidden_states_p = hidden_states_p[:num_prefill_tokens]
@@ -299,12 +412,16 @@ class Plamo2MambaMixer(nn.Module):

             # 3. State Space Model sequence transformation
             initial_states = None
-            if (mamba2_metadata.has_initial_states is not None
-                    and mamba2_metadata.prep_initial_states):
+            if has_initial_states_p is not None and prep_initial_states:
                 # making a copy of the states
-                initial_states = torch.where(
-                    mamba2_metadata.has_initial_states[:, None, None, None],
-                    mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
+                if envs.VLLM_USE_V1:
+                    initial_states = torch.where(
+                        has_initial_states_p[:, None, None, None],
+                        ssm_state[state_indices_tensor_p], 0)
+                else:
+                    initial_states = torch.where(
+                        has_initial_states_p[:num_prefills, None, None, None],
+                        ssm_state[state_indices_tensor_p], 0)
             varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
                                      self.num_heads // self.tp_size,
@@ -313,15 +430,15 @@ class Plamo2MambaMixer(nn.Module):
                 self.A,
                 B.view(1, num_prefill_tokens, 1, -1),
                 C.view(1, num_prefill_tokens, 1, -1),
-                chunk_size=mamba2_metadata.chunk_size,
+                chunk_size=chunk_size,
                 D=self.D,
                 z=gate_p.view(1, num_prefill_tokens,
                               self.num_heads // self.tp_size, self.head_dim),
                 dt_bias=self.dt_bias,
-                seq_idx=mamba2_metadata.seq_idx,
-                chunk_indices=mamba2_metadata.chunk_indices,
-                chunk_offsets=mamba2_metadata.chunk_offsets,
-                cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
+                seq_idx=seq_idx_p,
+                chunk_indices=chunk_indices_p,
+                chunk_offsets=chunk_offsets_p,
+                cu_seqlens=query_start_loc_p,
                 initial_states=initial_states,
                 return_varlen_states=True,
                 return_final_states=False,
@@ -329,18 +446,19 @@ class Plamo2MambaMixer(nn.Module):
                 dt_limit=(0.0, float("inf")),
                 out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
                                                 self.head_dim),
+                state_dtype=ssm_state.dtype,
             )

             # update ssm states
             # - varlen state is a (batch, nheads, headdim, dstate) tensor
-            mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
+            ssm_state[state_indices_tensor_p] = varlen_state

         # Process decode requests
         if has_decode:
             # 2. Convolution sequence transformation
             hidden_states_d = causal_conv1d_update(
                 hidden_states_d,
-                mamba_cache_params.conv_state,
+                conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
@@ -363,8 +481,10 @@ class Plamo2MambaMixer(nn.Module):
             # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
             #   using state_indices_tensor_d
+            # NOTE: final output is an in-place update of out tensor
             selective_state_update(
-                mamba_cache_params.ssm_state,
+                ssm_state,
                 hidden_states_d,
                 dt,
                 A,
@@ -378,11 +498,68 @@ class Plamo2MambaMixer(nn.Module):
                 out=preallocated_ssm_out_d.view(num_decodes, -1,
                                                 self.head_dim),
             )
-        assert self.num_heads % self.tp_size == 0

         # 4. Final linear projection
-        out = self.out_proj(preallocated_ssm_out)
-        return out
+        output[:num_actual_tokens] = self.out_proj(preallocated_ssm_out)
+
+    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+            self.cache_config.mamba_ssm_cache_dtype,
+        )
+
+    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+        return MambaStateShapeCalculator.mamba2_state_shape(
+            intermediate_size=self.intermediate_size,
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            n_groups=0,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            state_size=self.ssm_state_size,
+            conv_kernel=self.conv_kernel_size,
+        )
+
+    @property
+    def mamba_type(self) -> str:
+        return "mamba2"
+
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.mamba2_attn import (
+            Mamba2AttentionBackend)
+        return Mamba2AttentionBackend
+
+
+def plamo2_mamba_mixer(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None,
+                      mamba2_metadata=None)
+
+
+def plamo2_mamba_mixer_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="plamo2_mamba_mixer",
+    op_func=plamo2_mamba_mixer,
+    mutates_args=["output"],
+    fake_impl=plamo2_mamba_mixer_fake,
+    dispatch_key=current_platform.dispatch_key,
+)


 class DenseMLP(nn.Module):
@@ -418,7 +595,6 @@ class DenseMLP(nn.Module):
         return self.down_proj(h)


-@support_torch_compile
 class Plamo2AttentionMixer(nn.Module):

     def __init__(self,
@@ -575,12 +751,24 @@ class Plamo2DecoderLayer(nn.Module):
             hidden_states, residual = self.pre_mixer_norm(
                 hidden_states, residual)

+        if self.is_mamba:
+            # Plamo2MambaMixer writes output to this tensor
+            output = torch.empty_like(hidden_states)
+            mixer_kwargs = {
+                "output": output,
+                "mamba_cache_params": mamba_cache_params,
+                "mamba2_metadata": mamba2_metadata,
+            }
+        else:
+            mixer_kwargs = {
+                "positions": positions,
+            }
+
         hidden_states = self.mixer(
-            positions=positions,
             hidden_states=hidden_states,
-            mamba_cache_params=mamba_cache_params,
-            mamba2_metadata=mamba2_metadata,
+            **mixer_kwargs,
         )
+        if self.is_mamba:
+            hidden_states = output
         hidden_states = self.post_mixer_norm(hidden_states)
         # Fully Connected
         hidden_states, residual = self.pre_mlp_norm(hidden_states, residual)
@@ -591,7 +779,7 @@ class Plamo2DecoderLayer(nn.Module):

 class Plamo2Decoder(torch.nn.Module):

-    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
         extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
@@ -617,7 +805,7 @@ class Plamo2Decoder(torch.nn.Module):
         mamba_cache_index = 0
         for layer in islice(self.layers, self.start_layer, self.end_layer):
             layer_mamba_cache_params = None
-            if layer.is_mamba:
+            if layer.is_mamba and mamba_cache_params is not None:
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     mamba_cache_index)
                 mamba_cache_index += 1
@@ -632,10 +820,11 @@ class Plamo2Decoder(torch.nn.Module):
         return hidden_states, residual


-class Plamo2Model(Plamo2PreTrainedModel):
+@support_torch_compile
+class Plamo2Model(torch.nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config.model_config.hf_config)
+        super().__init__()

         config = vllm_config.model_config.hf_config
@@ -653,9 +842,9 @@ class Plamo2Model(Plamo2PreTrainedModel):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
-        self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers")
+        self.layers = Plamo2Decoder(vllm_config=vllm_config,
+                                    prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_init()

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -679,11 +868,16 @@ class Plamo2Model(Plamo2PreTrainedModel):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.mamba_chunk_size,
-            attn_metadata=attn_metadata,
-        )
+        if not envs.VLLM_USE_V1:
+            attn_metadata: AttentionMetadata = get_forward_context(
+            ).attn_metadata
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.mamba_chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # v1 get mamba2_metadata from forward_context
+            mamba2_metadata = None

         hidden_states, residual = self.layers(
             positions=positions,
@@ -701,8 +895,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
         return hidden_states


-class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
-                        IsHybrid, SupportsV0Only):
+class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -712,12 +905,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
     }

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__()
         config = vllm_config.model_config.hf_config
         scheduler_config = vllm_config.scheduler_config
-        assert not vllm_config.cache_config.enable_prefix_caching, \
-            "PLaMo2 currently does not support prefix caching"
-
-        super().__init__(config)
         self.config = config
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -751,8 +942,6 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-        # Initialize weights and apply final processing
-        self.post_init()

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -763,19 +952,27 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
-        if self.mamba_cache is None:
-            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
-                self.vllm_config.parallel_config, LayerBlockType.mamba)
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config,
-                num_mamba_layers,
-                *self._get_mamba_cache_shape(),
-                self.lm_head.weight.dtype,
-                self.lm_head.weight.dtype,
-            )
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                num_mamba_layers = (
+                    self.model_config.get_num_layers_by_block_type(
+                        self.vllm_config.parallel_config,
+                        LayerBlockType.mamba))
+                mamba_state_shape = self.get_mamba_state_shape_from_config(
+                    self.vllm_config, use_v1=False)
+                mamba_state_dtype = \
+                    self.get_mamba_state_dtype_from_config(
+                        self.vllm_config)
+                self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                     num_mamba_layers,
+                                                     *mamba_state_shape,
+                                                     *mamba_state_dtype)
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        else:
+            # NOTE: mamba_cache_params is not needed for v1
+            mamba_cache_params = None

         hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)
@@ -788,21 +985,48 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = (self.config.mamba_num_heads *
-                       self.config.hidden_size_per_head)
-        conv_state_shape = (
-            hidden_size // world_size,
-            self.config.mamba_d_conv - 1,
-        )
-        temporal_state_shape = (
-            divide(self.config.mamba_num_heads, world_size),
-            self.config.hidden_size_per_head,
-            self.config.mamba_d_state,
-        )
-        return conv_state_shape, temporal_state_shape
+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        intermediate_size = \
+            hf_config.mamba_num_heads * hf_config.hidden_size_per_head
+
+        return MambaStateShapeCalculator.mamba2_state_shape(
+            intermediate_size=intermediate_size,
+            tp_world_size=parallel_config.tensor_parallel_size,
+            n_groups=0,
+            num_heads=hf_config.mamba_num_heads,
+            head_dim=hf_config.hidden_size_per_head,
+            state_size=hf_config.mamba_d_state,
+            conv_kernel=hf_config.mamba_d_conv,
+            use_v1=use_v1,
+        )

     def compute_logits(
         self,