[Model] Support TP/PP/mamba2 kernel for PLaMo2 (#19674)
Signed-off-by: Shinichi Hemmi <shemmi@preferred.jp>
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
Co-authored-by: Sixue Wang <cecilwang@preferred.jp>
parent 15a72ac478
commit c7ffe93d9c
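As a usage illustration (not part of the commit): a minimal sketch of serving PLaMo2 with the tensor/pipeline parallelism this change enables. The parallel sizes and prompt below are placeholders, and `trust_remote_code=True` is assumed to be needed for the PLaMo2 Hugging Face config.

```python
# Hedged sketch: assumes a multi-GPU host and a vLLM build that includes this
# commit. The parallel sizes are illustrative, not recommendations.
from vllm import LLM, SamplingParams

llm = LLM(
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,       # PLaMo2 ships a custom HF config
    tensor_parallel_size=2,       # TP support added by this commit
    pipeline_parallel_size=2,     # PP support added by this commit
)
outputs = llm.generate(["The capital of Japan is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```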
@@ -389,7 +389,7 @@ th {
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
-| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
+| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | |
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -175,6 +175,7 @@ TEXT_GENERATION_MODELS = {
    "internlm/internlm2-chat-7b": PPTestSettings.fast(),
    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
+   "pfnet/plamo-2-1b": PPTestSettings.fast(),
    "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
    # Tests TransformersForCausalLM
    "hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
@@ -9,7 +9,7 @@ import pytest

from tests.quantization.utils import is_quant_method_supported

-MODELS = ["ai21labs/Jamba-tiny-random"]
+MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"]


@pytest.mark.skipif(not is_quant_method_supported("experts_int8"),
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only PLaMo2 model."""
import math
from collections.abc import Iterable
from typing import Optional
@@ -11,30 +10,40 @@ from transformers import PretrainedConfig, PreTrainedModel

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
    selective_state_update)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsV0Only)
                                                    SupportsPP, SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.models.utils import (
    is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
    make_layers, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
@@ -77,17 +86,6 @@ class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore
            module.weight.data[module.padding_idx].zero_()


def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
    dt_min = 0.001
    dt_max = 0.1
    dt = torch.exp(
        torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) +
        math.log(dt_min))
    dt = torch.clamp(dt, 1e-4)
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    return inv_dt


def is_mamba(config: Plamo2Config, i: int) -> bool:
    assert config.mamba_step > 1
@@ -97,52 +95,36 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
    return (i % config.mamba_step) != (config.mamba_step // 2)


# TODO(Shinichi): Replace this with RMSNorm.
def _rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor,
              eps: float) -> torch.Tensor:
    input_shape = hidden_states.shape
    hidden_states = hidden_states.reshape(input_shape[:-1] + weight.shape)
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    hidden_states = hidden_states.to(input_dtype)
    hidden_states = weight * hidden_states
    return hidden_states.reshape(input_shape)


def _swiglu(h: torch.Tensor) -> torch.Tensor:
    h0, h1 = h.chunk(2, dim=-1)
    return torch.nn.functional.silu(h0) * h1


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# Adapted from:
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# transformers.models.mamba.modeling_mamba.MambaMixer
class Plamo2MambaMixer(nn.Module):
    # TODO(Shinichi): Rebase on Mamba2 implementation.

    def __init__(self,
                 config: Plamo2Config,
                 cache_config: CacheConfig,
                 quant_config: QuantizationConfig,
                 max_model_len: int,
                 vllm_config: VllmConfig,
                 *,
                 prefix: str = "",
                 **kwargs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = (config.mamba_num_heads *
                                  config.hidden_size_per_head)
        self.hidden_size_per_head = config.hidden_size_per_head
        self.num_heads = config.mamba_num_heads
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.hidden_size = self.config.hidden_size
        self.ssm_state_size = self.config.mamba_d_state
        self.conv_kernel_size = self.config.mamba_d_conv
        self.intermediate_size = (self.config.mamba_num_heads *
                                  self.config.hidden_size_per_head)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.intermediate_size_per_tp_worker = \
            self.intermediate_size // self.tp_size
        self.head_dim = self.config.hidden_size_per_head
        self.num_heads = self.config.mamba_num_heads
        self.time_step_rank = max(64, self.hidden_size // 16)
        self.use_conv_bias = False
        self.use_bias = False
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.intermediate_size,
            bias=self.use_conv_bias,
            bias=False,
            prefix=f"{prefix}.conv1d",
            return_bias=False,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
@@ -153,15 +135,19 @@ class Plamo2MambaMixer(nn.Module):
        self.in_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=self.use_bias,
            bias=False,
            quant_config=self.quant_config,
            prefix=f"{prefix}.in_proj",
            return_bias=False,
        )
        # selective projection used to make dt, B and C input dependent
        self.bcdt_proj = RowParallelLinear(
            self.intermediate_size,
            self.time_step_rank + self.ssm_state_size * 2,
            bias=False,
            quant_config=self.quant_config,
            prefix=f"{prefix}.bcdt_proj",
            return_bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
@@ -170,154 +156,224 @@ class Plamo2MambaMixer(nn.Module):
            self.time_step_rank,
            self.num_heads,
            bias=False,
            quant_config=self.quant_config,
            prefix=f"{prefix}.dt_proj",
            return_bias=False,
        )
        self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))

        tp_size = get_tensor_model_parallel_world_size()
        self.A = nn.Parameter(
            torch.empty(
                self.intermediate_size // tp_size,
                self.ssm_state_size,
                divide(self.num_heads, self.tp_size),
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
        self.D = nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size)))
        self.dt_bias = nn.Parameter(
            torch.ones(divide(self.num_heads, self.tp_size)))

        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
        a_weight_loader = composed_weight_loader(
            sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
        set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
        set_weight_attrs(self.dt_bias,
                         {"weight_loader": sharded_weight_loader(0)})

        self.out_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=self.use_bias,
            bias=False,
            input_is_parallel=True,
            quant_config=self.quant_config,
            prefix=f"{prefix}.out_proj",
            return_bias=False,
        )
        # The activation function is fixed to SiLU.
        self.activation = "silu"

        self.dt_norm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
        self.B_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
        self.C_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
        self.dt_norm = RMSNorm(self.time_step_rank,
                               eps=self.config.rms_norm_eps)
        self.B_norm = RMSNorm(self.ssm_state_size,
                              eps=self.config.rms_norm_eps)
        self.C_norm = RMSNorm(self.ssm_state_size,
                              eps=self.config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ) -> torch.Tensor:

        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0]
        # Reshaping the projected states as in modeling_plamo.py.
        length = len(hidden_states)
        projected_states = projected_states.reshape(length, self.num_heads, -1)
        gate, hidden_states = torch.split(
            projected_states,
            [self.hidden_size_per_head, self.hidden_size_per_head],
            dim=-1)
        hidden_states = hidden_states.reshape(length, -1).transpose(0, 1)
        gate = gate.reshape(length, -1).transpose(0, 1)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|
            hidden_states = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            hidden_states = causal_conv1d_update(
                hidden_states.transpose(0, 1),
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)
            hidden_states = hidden_states.transpose(0, 1)

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1))[0]

        # Splitting the ssm_parameters as in modeling_plamo.py.
    def _project_ssm_parameters(self, hidden_states):
        ssm_parameters = self.bcdt_proj(hidden_states)
        B, C, time_step = torch.split(
            ssm_parameters,
            [self.ssm_state_size, self.ssm_state_size, self.time_step_rank],
            dim=-1,
        )

        # vllm._custom_ops.rms_norm requires contiguous input tensors.
        time_step = self.dt_norm(time_step.contiguous())
        B = self.B_norm(B.contiguous())
        C = self.C_norm(C.contiguous())
        dt = self.dt_proj(time_step)
        return B, C, dt

        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = (self.dt_bias.float() if hasattr(
            self.dt_proj, "bias") else None)
    def forward(
        self,
        hidden_states: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        **kwargs,
    ) -> torch.Tensor:

        # Broadcasting as in modeling_plamo.py.
        discrete_time_step = discrete_time_step.transpose(
            0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head)
        discrete_time_step = discrete_time_step.reshape(
            -1, self.intermediate_size).transpose(0, 1)
        time_proj_bias = time_proj_bias[...,
                                        None].expand(-1,
                                                     self.hidden_size_per_head)
        time_proj_bias = time_proj_bias.reshape(self.intermediate_size)
        # mamba2_metadata contains metadata necessary for the mamba2 triton
        # kernels to operate in continuous batching and in chunked prefill
        # modes; they are computed at top-level model forward since they
        # stay the same and reused for all mamba layers in the same iteration
        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            scan_outputs = selective_scan_fn(
                hidden_states,
                mamba_cache_params.ssm_state,
                discrete_time_step,
        num_prefills = attn_metadata.num_prefills  # request count
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        has_prefill = num_prefills > 0
        has_decode = num_decodes > 0

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)
        gate, hidden_states = projected_states.chunk(2, dim=-1)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        # Separate prefill and decode by splitting varlen input
        # Split along token dimension
        hidden_states_p, hidden_states_d = torch.split(
            hidden_states,
            [num_prefill_tokens, num_decodes],
            dim=0,
        )
        gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes],
                                     dim=0)
        # Split along batch dimension
        state_indices_tensor_p, state_indices_tensor_d = torch.split(
            mamba_cache_params.state_indices_tensor,
            [num_prefills, num_decodes],
            dim=0,
        )
        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
                             if has_prefill else None)

        ssd_output_list = []

        # Process prefill requests
        if has_prefill:
            # 2. Convolution sequence transformation
            # - "cache_indices" updates the conv_state cache in positions
            #   pointed to by "mamba_cache_params.state_indices_tensor"
            hidden_states_p = causal_conv1d_fn(
                hidden_states_p.transpose(0, 1),
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=mamba2_metadata.has_initial_states,
                cache_indices=state_indices_tensor_p,
                query_start_loc=query_start_loc_p)
            hidden_states_p = hidden_states_p.transpose(0, 1)
            hidden_states_p = hidden_states_p[:num_prefill_tokens]
            # In some instances, the following `bcdt_proj` op
            # requires contiguous inputs
            # (e.g. if the Marlin kernel is used).
            hidden_states_p = hidden_states_p.contiguous()

            B, C, dt = self._project_ssm_parameters(hidden_states_p)

            # 3. State Space Model sequence transformation
            initial_states = None
            if (mamba2_metadata.has_initial_states is not None
                    and mamba2_metadata.prep_initial_states):
                # making a copy of the states
                initial_states = torch.where(
                    mamba2_metadata.has_initial_states[:, None, None, None],
                    mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
            scan_output, varlen_state = mamba_chunk_scan_combined(
                hidden_states_p.view(1, num_prefill_tokens,
                                     self.num_heads // self.tp_size,
                                     self.head_dim),
                dt.unsqueeze(0),
                self.A,
                B.transpose(-2, -1),
                C.transpose(-2, -1),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                cache_indices=mamba_cache_params.state_indices_tensor,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            scan_outputs = selective_state_update(
                B.view(1, num_prefill_tokens, 1, -1),
                C.view(1, num_prefill_tokens, 1, -1),
                chunk_size=mamba2_metadata.chunk_size,
                D=self.D,
                z=gate_p.view(1, num_prefill_tokens,
                              self.num_heads // self.tp_size, self.head_dim),
                dt_bias=self.dt_bias,
                seq_idx=mamba2_metadata.seq_idx,
                chunk_indices=mamba2_metadata.chunk_indices,
                chunk_offsets=mamba2_metadata.chunk_offsets,
                cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
                initial_states=initial_states,
                return_varlen_states=True,
                return_final_states=False,
                dt_softplus=True,
                dt_limit=(0.0, float("inf")),
            )

            # update ssm states
            # - varlen state is a (batch, nheads, headdim, dstate) tensor
            mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state

            # - reshape
            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))

        # Process decode requests
        if has_decode:
            # 2. Convolution sequence transformation
            hidden_states_d = causal_conv1d_update(
                hidden_states_d,
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=state_indices_tensor_d)

            B, C, dt = self._project_ssm_parameters(hidden_states_d)

            # 3. State Space Model sequence transformation
            A = self.A[:, None, ...][:, :,
                                     None].expand(-1, self.head_dim,
                                                  self.config.mamba_d_state)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.unsqueeze(1)
            C = C.unsqueeze(1)
            hidden_states_d = hidden_states_d.view(
                -1, self.num_heads // self.tp_size, self.head_dim)

            # - the hidden is reshaped into (bs, num_heads, head_dim)
            # - mamba_cache_params.ssm_state's slots will be selected
            #   using state_indices_tensor_d

            hidden_states_d = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states.transpose(0, 1),
                discrete_time_step.transpose(0, 1),
                self.A,
                hidden_states_d,
                dt,
                A,
                B,
                C,
                self.D,
                gate.transpose(0, 1),
                time_proj_bias,
                D,
                z=gate_d.reshape(num_decodes, -1, self.head_dim),
                dt_bias=dt_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor)
        scan_outputs = scan_outputs.transpose(0, 1)
                state_batch_indices=state_indices_tensor_d,
            )
            assert self.num_heads % self.tp_size == 0
            ssd_output_list.append(
                hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
                                     self.head_dim))

        # Merge prefill and decode outputs before passing to MLP
        hidden_states = torch.vstack(ssd_output_list)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
                                                                     -1))[0]
        return contextualized_states
        out = self.out_proj(hidden_states)
        return out


class DenseMLP(nn.Module):
@@ -332,33 +388,39 @@ class DenseMLP(nn.Module):
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size, [self.intermediate_size] * 2,
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            prefix=f"{prefix}.gate_up_proj",
            quant_config=quant_config)
            quant_config=quant_config,
            return_bias=False,
        )
        self.act = SiluAndMul()
        self.down_proj = RowParallelLinear(self.intermediate_size,
                                           self.hidden_size,
                                           bias=False,
                                           prefix=f"{prefix}.down_proj",
                                           quant_config=quant_config)
                                           quant_config=quant_config,
                                           return_bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        h = self.gate_up_proj(hidden_states)[0]
        h = _swiglu(h)
        output, _ = self.down_proj(h)
        return output  # type: ignore
        h = self.gate_up_proj(hidden_states)
        h = self.act(h)
        return self.down_proj(h)


@support_torch_compile
class Plamo2AttentionMixer(nn.Module):

    def __init__(self,
                 config: Plamo2Config,
                 cache_config: CacheConfig,
                 quant_config: QuantizationConfig,
                 max_model_len: int | None = None,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 **kwargs) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
@@ -396,19 +458,35 @@ class Plamo2AttentionMixer(nn.Module):
            "rope_theta") else 10000
        self.rope_scaling = config.rope_scaling if hasattr(
            config, "rope_scaling") else None
        max_position = config.max_position_embeddings
        if hasattr(vllm_config.model_config, "max_model_len") and isinstance(
                vllm_config.model_config.max_model_len, int):
            max_position = min(max_position,
                               vllm_config.model_config.max_model_len)

        assert max_model_len is not None, "max_model_len must be provided"
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_model_len,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
        )
        self.q_weight = torch.nn.Parameter(
        self.q_norm = RMSNorm(config.hidden_size_per_head,
                              eps=config.rms_norm_eps)
        self.q_norm.weight = torch.nn.Parameter(
            torch.ones((self.num_heads, config.hidden_size_per_head)))
        self.k_weight = torch.nn.Parameter(
        set_weight_attrs(self.q_norm.weight,
                         {"weight_loader": sharded_weight_loader(0)})
        self.k_norm = RMSNorm(config.hidden_size_per_head,
                              eps=config.rms_norm_eps)
        self.k_norm.weight = torch.nn.Parameter(
            torch.ones((self.num_kv_heads, config.hidden_size_per_head)))
        # Tensor-parallelism shards the K norm weights to the tp ranks
        # in a head-wise manner. This approach does not work if there is only
        # a single KV head, as is the case for PLaMo 2-1B.
        if self.total_num_kv_heads != 1:
            set_weight_attrs(self.k_norm.weight,
                             {"weight_loader": sharded_weight_loader(0)})

        self.attn = Attention(
            self.num_heads,
@@ -423,13 +501,18 @@ class Plamo2AttentionMixer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = _rms_norm(q, self.q_weight, 1e-6)
        k = _rms_norm(k, self.k_weight, 1e-6)

        q_shape = q.shape
        q = q.reshape(q_shape[:-1] + self.q_norm.weight.shape)
        q = self.q_norm.forward_native(q).reshape(q_shape)
        k_shape = k.shape
        k = k.reshape(k_shape[:-1] + self.k_norm.weight.shape)
        k = self.k_norm.forward_native(k).reshape(k_shape)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
@@ -441,27 +524,18 @@ class Plamo2DecoderLayer(nn.Module):
    def __init__(self,
                 vllm_config: VllmConfig,
                 layer_idx: int,
                 max_model_len: int | None = None,
                 prefix: str = "",
                 **kwargs) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        max_model_len = vllm_config.scheduler_config.max_model_len

        self.is_mamba = is_mamba(config, layer_idx)
        if self.is_mamba:
            self.mixer = Plamo2MambaMixer(config=config,
                                          cache_config=cache_config,
                                          quant_config=quant_config,
                                          max_model_len=max_model_len,
            self.mixer = Plamo2MambaMixer(vllm_config=vllm_config,
                                          prefix=f"{prefix}.mixer")
        else:
            self.mixer = Plamo2AttentionMixer(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              max_model_len=max_model_len,
            self.mixer = Plamo2AttentionMixer(vllm_config=vllm_config,
                                              prefix=f"{prefix}.mixer")

        self.mlp = DenseMLP(config=config,
@@ -482,6 +556,7 @@ class Plamo2DecoderLayer(nn.Module):
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        **kwargs,
    ):
        if residual is None:
@@ -491,10 +566,12 @@ class Plamo2DecoderLayer(nn.Module):
            hidden_states, residual = self.pre_mixer_norm(
                hidden_states, residual)

        hidden_states = self.mixer(positions=positions,
                                   hidden_states=hidden_states,
                                   residual=residual,
                                   mamba_cache_params=mamba_cache_params)
        hidden_states = self.mixer(
            positions=positions,
            hidden_states=hidden_states,
            mamba_cache_params=mamba_cache_params,
            mamba2_metadata=mamba2_metadata,
        )
        hidden_states = self.post_mixer_norm(hidden_states)
        # Fully Connected
        hidden_states, residual = self.pre_mlp_norm(hidden_states, residual)
@@ -507,14 +584,18 @@ class Plamo2Decoder(torch.nn.Module):

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
        config = vllm_config.model_config.hf_config
        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}

        self.layers = nn.ModuleList([
            Plamo2DecoderLayer(vllm_config=vllm_config,
                               layer_idx=i,
                               prefix=f"{prefix}.layers.{i}")
            for i in range(num_hidden_layers)
        ])
        def get_layer(prefix: str):
            layer_idx = int(prefix.rsplit(".", 1)[1])
            return Plamo2DecoderLayer(vllm_config=vllm_config,
                                      layer_idx=layer_idx,
                                      prefix=prefix,
                                      **extra_kwargs)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")

    def forward(
        self,
@@ -522,9 +603,10 @@ class Plamo2Decoder(torch.nn.Module):
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
    ) -> torch.Tensor:
        mamba_cache_index = 0
        for layer in self.layers:
        for layer in self.layers[self.start_layer:self.end_layer]:
            layer_mamba_cache_params = None
            if layer.is_mamba:
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
@@ -535,7 +617,9 @@ class Plamo2Decoder(torch.nn.Module):
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=layer_mamba_cache_params)
                mamba_cache_params=layer_mamba_cache_params,
                mamba2_metadata=mamba2_metadata,
            )
        return hidden_states, residual
@@ -557,10 +641,16 @@ class Plamo2Model(Plamo2PreTrainedModel):
            org_num_embeddings=config.vocab_size,
            prefix=f"{prefix}.embed_tokens",
        )
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
        self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_init()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
@@ -569,21 +659,41 @@ class Plamo2Model(Plamo2PreTrainedModel):
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # TODO(Shinichi): Implement pipeline parallelism.
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
        mamba2_metadata = prepare_mamba2_metadata(
            chunk_size=self.config.mamba_chunk_size,
            attn_metadata=attn_metadata,
        )

        hidden_states, residual = self.layers(
            positions=positions,
            hidden_states=hidden_states,
            residual=residual,
            mamba_cache_params=mamba_cache_params)
            mamba_cache_params=mamba_cache_params,
            mamba2_metadata=mamba2_metadata,
        )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
                        SupportsV0Only):
class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
                        IsHybrid, SupportsV0Only):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
@@ -629,10 +739,15 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.config.vocab_size)

        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
@@ -661,7 +776,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> tuple[tuple[int, int], tuple[int, int]]:
            self) -> tuple[tuple[int, int], tuple[int, int, int]]:
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = (self.config.mamba_num_heads *
                       self.config.hidden_size_per_head)
@@ -670,7 +785,8 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
            self.config.mamba_d_conv - 1,
        )
        temporal_state_shape = (
            hidden_size // world_size,
            divide(self.config.mamba_num_heads, world_size),
            self.config.hidden_size_per_head,
            self.config.mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape
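(Editorial aside, not part of the diff.) The hunk above switches the temporal SSM cache from a flat (intermediate_size / TP, d_state) layout to the per-head (num_heads / TP, head_dim, d_state) layout that the mamba2 kernels expect. A small sketch of the shape arithmetic with made-up, hypothetical config values:

```python
# Hypothetical PLaMo2-like values; only the shape arithmetic mirrors the diff.
mamba_num_heads = 32
hidden_size_per_head = 128
mamba_d_state = 64
mamba_d_conv = 4
tp_size = 2

hidden_size = mamba_num_heads * hidden_size_per_head           # 4096
conv_state_shape = (hidden_size // tp_size, mamba_d_conv - 1)  # (2048, 3)
# Old temporal layout: (hidden_size // tp_size, mamba_d_state) -> (2048, 64)
# New temporal layout: one state block per local head.
temporal_state_shape = (
    mamba_num_heads // tp_size,   # divide(mamba_num_heads, tp_size) -> 16
    hidden_size_per_head,         # 128
    mamba_d_state,                # 64
)
print(conv_state_shape, temporal_state_shape)
```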
@@ -684,6 +800,14 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
@@ -703,23 +827,46 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
                ".B_norm_weight": ".B_norm.weight",
                ".C_norm_weight": ".C_norm.weight",
                ".dt_norm_weight": ".dt_norm.weight",
                ".q_weight": ".q_norm.weight",
                ".k_weight": ".k_norm.weight",
            }
            # Apply replacements based on the defined mappings
            for old, new in replacements.items():
                if old in name:
                    name = name.replace(old, new)

            # Broadcast the loaded weight to match the model's parameter shape.
            if ".A" in name:
                loaded_weight = loaded_weight[:, None, None].expand(
                    -1, self.config.hidden_size_per_head,
                    self.config.mamba_d_state)
            # Reshape the in_proj weights to match the shape expected
            # by MergedColumnParallelLinear.
            # This works both for unquantized weights and
            # for quantized weights.
            # In the quantized case, the weights are already transposed.
            # Also, in addition to the quantized weights,
            # the zero points and scales have to be reshaped as well.
            # Packing should not be affected by this.
            if ".mixer.in_proj.weight" in name \
                    or "mixer.in_proj.qweight" in name \
                    or "mixer.in_proj.scales" in name \
                    or "mixer.in_proj.qzeros" in name:
                if "mixer.in_proj.weight" in name:
                    loaded_weight = loaded_weight.transpose(0, 1)
                # for weight:
                # loaded_weight.shape[0] == self.config.hidden_size
                # for qweight:
                # loaded_weight.shape[0] == self.config.hidden_size // param.pack_factor  # noqa
                # for scales and qzeros:
                # loaded_weight.shape[0] == self.config.hidden_size // self.vllm_config.quant_config.group_size  # noqa
                loaded_weight = loaded_weight.reshape(
                    -1, self.config.mamba_d_state)
            elif ".D" in name:
                loaded_weight = loaded_weight[:, None].expand(
                    -1, self.config.hidden_size_per_head)
                loaded_weight = loaded_weight.reshape(-1)
                    loaded_weight.shape[0], self.config.mamba_num_heads, -1)
                gate_weight, hidden_states_weight = loaded_weight.chunk(2,
                                                                        dim=-1)
                gate_weight = gate_weight.reshape(loaded_weight.shape[0], -1)
                hidden_states_weight = hidden_states_weight.reshape(
                    loaded_weight.shape[0], -1)
                loaded_weight = torch.cat([gate_weight, hidden_states_weight],
                                          dim=-1)
                if "mixer.in_proj.weight" in name:
                    loaded_weight = loaded_weight.transpose(0, 1)

            # Offset parameter with vllm's RMSNorm haven't been supported yet.
            if ".pre_mixer_norm" in name:
                loaded_weight += 1.0
@@ -732,6 +879,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
            elif "model.norm.weight" in name:
                loaded_weight += 1.0

            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)