[Model] Support TP/PP/mamba2 kernel for PLaMo2 (#19674)

Signed-off-by: Shinichi Hemmi <shemmi@preferred.jp>
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
Co-authored-by: Sixue Wang <cecilwang@preferred.jp>
This commit is contained in:
Shinichi Hemmi 2025-07-28 14:00:47 +09:00 committed by GitHub
parent 15a72ac478
commit c7ffe93d9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 376 additions and 224 deletions

View File

@ -389,7 +389,7 @@ th {
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | |
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |

View File

@ -175,6 +175,7 @@ TEXT_GENERATION_MODELS = {
"internlm/internlm2-chat-7b": PPTestSettings.fast(),
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"pfnet/plamo-2-1b": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersForCausalLM
"hmellor/Ilama-3.2-1B": PPTestSettings.fast(),

View File

@ -9,7 +9,7 @@ import pytest
from tests.quantization.utils import is_quant_method_supported
MODELS = ["ai21labs/Jamba-tiny-random"]
MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"]
@pytest.mark.skipif(not is_quant_method_supported("experts_int8"),

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only PLaMo2 model."""
import math
from collections.abc import Iterable
from typing import Optional
@ -11,30 +10,40 @@ from transformers import PretrainedConfig, PreTrainedModel
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
selective_state_update)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsV0Only)
SupportsPP, SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.models.utils import (
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
@ -77,17 +86,6 @@ class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore
module.weight.data[module.padding_idx].zero_()
def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
dt_min = 0.001
dt_max = 0.1
dt = torch.exp(
torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) +
math.log(dt_min))
dt = torch.clamp(dt, 1e-4)
inv_dt = dt + torch.log(-torch.expm1(-dt))
return inv_dt
def is_mamba(config: Plamo2Config, i: int) -> bool:
assert config.mamba_step > 1
@ -97,52 +95,36 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
return (i % config.mamba_step) != (config.mamba_step // 2)
# TODO(Shinichi): Replace this with RMSNorm.
def _rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor,
eps: float) -> torch.Tensor:
input_shape = hidden_states.shape
hidden_states = hidden_states.reshape(input_shape[:-1] + weight.shape)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
hidden_states = hidden_states.to(input_dtype)
hidden_states = weight * hidden_states
return hidden_states.reshape(input_shape)
def _swiglu(h: torch.Tensor) -> torch.Tensor:
h0, h1 = h.chunk(2, dim=-1)
return torch.nn.functional.silu(h0) * h1
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# Adapted from:
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# transformers.models.mamba.modeling_mamba.MambaMixer
class Plamo2MambaMixer(nn.Module):
# TODO(Shinichi): Rebase on Mamba2 implementation.
def __init__(self,
config: Plamo2Config,
cache_config: CacheConfig,
quant_config: QuantizationConfig,
max_model_len: int,
vllm_config: VllmConfig,
*,
prefix: str = "",
**kwargs) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv
self.intermediate_size = (config.mamba_num_heads *
config.hidden_size_per_head)
self.hidden_size_per_head = config.hidden_size_per_head
self.num_heads = config.mamba_num_heads
self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.hidden_size = self.config.hidden_size
self.ssm_state_size = self.config.mamba_d_state
self.conv_kernel_size = self.config.mamba_d_conv
self.intermediate_size = (self.config.mamba_num_heads *
self.config.hidden_size_per_head)
self.tp_size = get_tensor_model_parallel_world_size()
self.intermediate_size_per_tp_worker = \
self.intermediate_size // self.tp_size
self.head_dim = self.config.hidden_size_per_head
self.num_heads = self.config.mamba_num_heads
self.time_step_rank = max(64, self.hidden_size // 16)
self.use_conv_bias = False
self.use_bias = False
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.intermediate_size,
bias=self.use_conv_bias,
bias=False,
prefix=f"{prefix}.conv1d",
return_bias=False,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
@ -153,15 +135,19 @@ class Plamo2MambaMixer(nn.Module):
self.in_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.intermediate_size] * 2,
bias=self.use_bias,
bias=False,
quant_config=self.quant_config,
prefix=f"{prefix}.in_proj",
return_bias=False,
)
# selective projection used to make dt, B and C input dependent
self.bcdt_proj = RowParallelLinear(
self.intermediate_size,
self.time_step_rank + self.ssm_state_size * 2,
bias=False,
quant_config=self.quant_config,
prefix=f"{prefix}.bcdt_proj",
return_bias=False,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
@ -170,154 +156,224 @@ class Plamo2MambaMixer(nn.Module):
self.time_step_rank,
self.num_heads,
bias=False,
quant_config=self.quant_config,
prefix=f"{prefix}.dt_proj",
return_bias=False,
)
self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))
tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter(
torch.empty(
self.intermediate_size // tp_size,
self.ssm_state_size,
divide(self.num_heads, self.tp_size),
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
self.D = nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size)))
self.dt_bias = nn.Parameter(
torch.ones(divide(self.num_heads, self.tp_size)))
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)})
self.out_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=self.use_bias,
bias=False,
input_is_parallel=True,
quant_config=self.quant_config,
prefix=f"{prefix}.out_proj",
return_bias=False,
)
# The activation function is fixed to SiLU.
self.activation = "silu"
self.dt_norm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
self.B_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
self.C_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
self.dt_norm = RMSNorm(self.time_step_rank,
eps=self.config.rms_norm_eps)
self.B_norm = RMSNorm(self.ssm_state_size,
eps=self.config.rms_norm_eps)
self.C_norm = RMSNorm(self.ssm_state_size,
eps=self.config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
**kwargs,
) -> torch.Tensor:
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0]
# Reshaping the projected states as in modeling_plamo.py.
length = len(hidden_states)
projected_states = projected_states.reshape(length, self.num_heads, -1)
gate, hidden_states = torch.split(
projected_states,
[self.hidden_size_per_head, self.hidden_size_per_head],
dim=-1)
hidden_states = hidden_states.reshape(length, -1).transpose(0, 1)
gate = gate.reshape(length, -1).transpose(0, 1)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1))[0]
# Splitting the ssm_parameters as in modeling_plamo.py.
def _project_ssm_parameters(self, hidden_states):
ssm_parameters = self.bcdt_proj(hidden_states)
B, C, time_step = torch.split(
ssm_parameters,
[self.ssm_state_size, self.ssm_state_size, self.time_step_rank],
dim=-1,
)
# vllm._custom_ops.rms_norm requires contiguous input tensors.
time_step = self.dt_norm(time_step.contiguous())
B = self.B_norm(B.contiguous())
C = self.C_norm(C.contiguous())
dt = self.dt_proj(time_step)
return B, C, dt
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_bias.float() if hasattr(
self.dt_proj, "bias") else None)
def forward(
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
) -> torch.Tensor:
# Broadcasting as in modeling_plamo.py.
discrete_time_step = discrete_time_step.transpose(
0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head)
discrete_time_step = discrete_time_step.reshape(
-1, self.intermediate_size).transpose(0, 1)
time_proj_bias = time_proj_bias[...,
None].expand(-1,
self.hidden_size_per_head)
time_proj_bias = time_proj_bias.reshape(self.intermediate_size)
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
discrete_time_step,
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)
gate, hidden_states = projected_states.chunk(2, dim=-1)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_p, hidden_states_d = torch.split(
hidden_states,
[num_prefill_tokens, num_decodes],
dim=0,
)
gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes],
dim=0)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
mamba_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)
ssd_output_list = []
# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
hidden_states_p = causal_conv1d_fn(
hidden_states_p.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p)
hidden_states_p = hidden_states_p.transpose(0, 1)
hidden_states_p = hidden_states_p[:num_prefill_tokens]
# In some instances, the following `bcdt_proj` op
# requires contiguous inputs
# (e.g. if the Marlin kernel is used).
hidden_states_p = hidden_states_p.contiguous()
B, C, dt = self._project_ssm_parameters(hidden_states_p)
# 3. State Space Model sequence transformation
initial_states = None
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt.unsqueeze(0),
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
B.view(1, num_prefill_tokens, 1, -1),
C.view(1, num_prefill_tokens, 1, -1),
chunk_size=mamba2_metadata.chunk_size,
D=self.D,
z=gate_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size, self.head_dim),
dt_bias=self.dt_bias,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
)
# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
# - reshape
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
hidden_states_d = causal_conv1d_update(
hidden_states_d,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d)
B, C, dt = self._project_ssm_parameters(hidden_states_d)
# 3. State Space Model sequence transformation
A = self.A[:, None, ...][:, :,
None].expand(-1, self.head_dim,
self.config.mamba_d_state)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.unsqueeze(1)
C = C.unsqueeze(1)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using state_indices_tensor_d
hidden_states_d = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
hidden_states_d,
dt,
A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
D,
z=gate_d.reshape(num_decodes, -1, self.head_dim),
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)
state_batch_indices=state_indices_tensor_d,
)
assert self.num_heads % self.tp_size == 0
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
# Merge prefill and decode outputs before passing to MLP
hidden_states = torch.vstack(ssd_output_list)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states
out = self.out_proj(hidden_states)
return out
class DenseMLP(nn.Module):
@ -332,33 +388,39 @@ class DenseMLP(nn.Module):
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = MergedColumnParallelLinear(
self.hidden_size, [self.intermediate_size] * 2,
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
prefix=f"{prefix}.gate_up_proj",
quant_config=quant_config)
quant_config=quant_config,
return_bias=False,
)
self.act = SiluAndMul()
self.down_proj = RowParallelLinear(self.intermediate_size,
self.hidden_size,
bias=False,
prefix=f"{prefix}.down_proj",
quant_config=quant_config)
quant_config=quant_config,
return_bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
h = self.gate_up_proj(hidden_states)[0]
h = _swiglu(h)
output, _ = self.down_proj(h)
return output # type: ignore
h = self.gate_up_proj(hidden_states)
h = self.act(h)
return self.down_proj(h)
@support_torch_compile
class Plamo2AttentionMixer(nn.Module):
def __init__(self,
config: Plamo2Config,
cache_config: CacheConfig,
quant_config: QuantizationConfig,
max_model_len: int | None = None,
*,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
@ -396,19 +458,35 @@ class Plamo2AttentionMixer(nn.Module):
"rope_theta") else 10000
self.rope_scaling = config.rope_scaling if hasattr(
config, "rope_scaling") else None
max_position = config.max_position_embeddings
if hasattr(vllm_config.model_config, "max_model_len") and isinstance(
vllm_config.model_config.max_model_len, int):
max_position = min(max_position,
vllm_config.model_config.max_model_len)
assert max_model_len is not None, "max_model_len must be provided"
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_model_len,
max_position=max_position,
base=self.rope_theta,
rope_scaling=self.rope_scaling,
)
self.q_weight = torch.nn.Parameter(
self.q_norm = RMSNorm(config.hidden_size_per_head,
eps=config.rms_norm_eps)
self.q_norm.weight = torch.nn.Parameter(
torch.ones((self.num_heads, config.hidden_size_per_head)))
self.k_weight = torch.nn.Parameter(
set_weight_attrs(self.q_norm.weight,
{"weight_loader": sharded_weight_loader(0)})
self.k_norm = RMSNorm(config.hidden_size_per_head,
eps=config.rms_norm_eps)
self.k_norm.weight = torch.nn.Parameter(
torch.ones((self.num_kv_heads, config.hidden_size_per_head)))
# Tensor-parallelism shards the K norm weights to the tp ranks
# in a head-wise manner. This approach does not work if there is only
# a single KV head, as is the case for PLaMo 2-1B.
if self.total_num_kv_heads != 1:
set_weight_attrs(self.k_norm.weight,
{"weight_loader": sharded_weight_loader(0)})
self.attn = Attention(
self.num_heads,
@ -423,13 +501,18 @@ class Plamo2AttentionMixer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = _rms_norm(q, self.q_weight, 1e-6)
k = _rms_norm(k, self.k_weight, 1e-6)
q_shape = q.shape
q = q.reshape(q_shape[:-1] + self.q_norm.weight.shape)
q = self.q_norm.forward_native(q).reshape(q_shape)
k_shape = k.shape
k = k.reshape(k_shape[:-1] + self.k_norm.weight.shape)
k = self.k_norm.forward_native(k).reshape(k_shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
@ -441,27 +524,18 @@ class Plamo2DecoderLayer(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
layer_idx: int,
max_model_len: int | None = None,
prefix: str = "",
**kwargs) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
max_model_len = vllm_config.scheduler_config.max_model_len
self.is_mamba = is_mamba(config, layer_idx)
if self.is_mamba:
self.mixer = Plamo2MambaMixer(config=config,
cache_config=cache_config,
quant_config=quant_config,
max_model_len=max_model_len,
self.mixer = Plamo2MambaMixer(vllm_config=vllm_config,
prefix=f"{prefix}.mixer")
else:
self.mixer = Plamo2AttentionMixer(config=config,
cache_config=cache_config,
quant_config=quant_config,
max_model_len=max_model_len,
self.mixer = Plamo2AttentionMixer(vllm_config=vllm_config,
prefix=f"{prefix}.mixer")
self.mlp = DenseMLP(config=config,
@ -482,6 +556,7 @@ class Plamo2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
@ -491,10 +566,12 @@ class Plamo2DecoderLayer(nn.Module):
hidden_states, residual = self.pre_mixer_norm(
hidden_states, residual)
hidden_states = self.mixer(positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params)
hidden_states = self.mixer(
positions=positions,
hidden_states=hidden_states,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
hidden_states = self.post_mixer_norm(hidden_states)
# Fully Connected
hidden_states, residual = self.pre_mlp_norm(hidden_states, residual)
@ -507,14 +584,18 @@ class Plamo2Decoder(torch.nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
config = vllm_config.model_config.hf_config
extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
self.layers = nn.ModuleList([
Plamo2DecoderLayer(vllm_config=vllm_config,
layer_idx=i,
prefix=f"{prefix}.layers.{i}")
for i in range(num_hidden_layers)
])
def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
return Plamo2DecoderLayer(vllm_config=vllm_config,
layer_idx=layer_idx,
prefix=prefix,
**extra_kwargs)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
def forward(
self,
@ -522,9 +603,10 @@ class Plamo2Decoder(torch.nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor:
mamba_cache_index = 0
for layer in self.layers:
for layer in self.layers[self.start_layer:self.end_layer]:
layer_mamba_cache_params = None
if layer.is_mamba:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
@ -535,7 +617,9 @@ class Plamo2Decoder(torch.nn.Module):
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params)
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
return hidden_states, residual
@ -557,10 +641,16 @@ class Plamo2Model(Plamo2PreTrainedModel):
org_num_embeddings=config.vocab_size,
prefix=f"{prefix}.embed_tokens",
)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_init()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -569,21 +659,41 @@ class Plamo2Model(Plamo2PreTrainedModel):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO(Shinichi): Implement pipeline parallelism.
hidden_states = self.embed_tokens(input_ids)
residual = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
hidden_states, residual = self.layers(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params)
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
SupportsV0Only):
class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
IsHybrid, SupportsV0Only):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -629,10 +739,15 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
@ -661,7 +776,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
self) -> tuple[tuple[int, int], tuple[int, int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = (self.config.mamba_num_heads *
self.config.hidden_size_per_head)
@ -670,7 +785,8 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
self.config.mamba_d_conv - 1,
)
temporal_state_shape = (
hidden_size // world_size,
divide(self.config.mamba_num_heads, world_size),
self.config.hidden_size_per_head,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
@ -684,6 +800,14 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
@ -703,23 +827,46 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
".B_norm_weight": ".B_norm.weight",
".C_norm_weight": ".C_norm.weight",
".dt_norm_weight": ".dt_norm.weight",
".q_weight": ".q_norm.weight",
".k_weight": ".k_norm.weight",
}
# Apply replacements based on the defined mappings
for old, new in replacements.items():
if old in name:
name = name.replace(old, new)
# Broadcast the loaded weight to match the model's parameter shape.
if ".A" in name:
loaded_weight = loaded_weight[:, None, None].expand(
-1, self.config.hidden_size_per_head,
self.config.mamba_d_state)
# Reshape the in_proj weights to match the shape expected
# by MergedColumnParallelLinear.
# This works both for unquantized weights and
# for quantized weights.
# In the quantized case, the weights are already transposed.
# Also, in addition to the quantized weights,
# the zero points and scales have to be reshaped as well.
# Packing should not be affected by this.
if ".mixer.in_proj.weight" in name \
or "mixer.in_proj.qweight" in name \
or "mixer.in_proj.scales" in name \
or "mixer.in_proj.qzeros" in name:
if "mixer.in_proj.weight" in name:
loaded_weight = loaded_weight.transpose(0, 1)
# for weight:
# loaded_weight.shape[0] == self.config.hidden_size
# for qweight:
# loaded_weight.shape[0] == self.config.hidden_size // param.pack_factor # noqa
# for scales and qzeros:
# loaded_weight.shape[0] == self.config.hidden_size // self.vllm_config.quant_config.group_size # noqa
loaded_weight = loaded_weight.reshape(
-1, self.config.mamba_d_state)
elif ".D" in name:
loaded_weight = loaded_weight[:, None].expand(
-1, self.config.hidden_size_per_head)
loaded_weight = loaded_weight.reshape(-1)
loaded_weight.shape[0], self.config.mamba_num_heads, -1)
gate_weight, hidden_states_weight = loaded_weight.chunk(2,
dim=-1)
gate_weight = gate_weight.reshape(loaded_weight.shape[0], -1)
hidden_states_weight = hidden_states_weight.reshape(
loaded_weight.shape[0], -1)
loaded_weight = torch.cat([gate_weight, hidden_states_weight],
dim=-1)
if "mixer.in_proj.weight" in name:
loaded_weight = loaded_weight.transpose(0, 1)
# Offset parameter with vllm's RMSNorm haven't been supported yet.
if ".pre_mixer_norm" in name:
loaded_weight += 1.0
@ -732,6 +879,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
elif "model.norm.weight" in name:
loaded_weight += 1.0
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)