mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 10:09:08 +08:00
[Model] Support TP/PP/mamba2 kernel for PLaMo2 (#19674)
Signed-off-by: Shinichi Hemmi <shemmi@preferred.jp>
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
Co-authored-by: Sixue Wang <cecilwang@preferred.jp>
This commit is contained in:
parent
15a72ac478
commit
c7ffe93d9c
@ -389,7 +389,7 @@ th {
|
|||||||
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-flash-reasoning`, etc. | | | |
|
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-flash-reasoning`, etc. | | | |
|
||||||
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
|
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
|
||||||
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
|
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | |
|
||||||
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@ -175,6 +175,7 @@ TEXT_GENERATION_MODELS = {
|
|||||||
"internlm/internlm2-chat-7b": PPTestSettings.fast(),
|
"internlm/internlm2-chat-7b": PPTestSettings.fast(),
|
||||||
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
|
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
|
||||||
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
|
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
|
||||||
|
"pfnet/plamo-2-1b": PPTestSettings.fast(),
|
||||||
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
|
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
|
||||||
# Tests TransformersForCausalLM
|
# Tests TransformersForCausalLM
|
||||||
"hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
|
"hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import pytest
|
|||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
|
|
||||||
MODELS = ["ai21labs/Jamba-tiny-random"]
|
MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_quant_method_supported("experts_int8"),
|
@pytest.mark.skipif(not is_quant_method_supported("experts_int8"),
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Inference-only PLaMo2 model."""
|
"""Inference-only PLaMo2 model."""
|
||||||
import math
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -11,30 +10,40 @@ from transformers import PretrainedConfig, PreTrainedModel
|
|||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||||
|
Mamba2Metadata, prepare_mamba2_metadata)
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||||
causal_conv1d_fn, causal_conv1d_update)
|
causal_conv1d_fn, causal_conv1d_update)
|
||||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||||
selective_scan_fn, selective_state_update)
|
selective_state_update)
|
||||||
|
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||||
|
mamba_chunk_scan_combined)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
||||||
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
||||||
SupportsV0Only)
|
SupportsPP, SupportsV0Only)
|
||||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||||
MambaCacheParams)
|
MambaCacheParams)
|
||||||
from vllm.model_executor.models.utils import maybe_prefix
|
from vllm.model_executor.models.utils import (
|
||||||
|
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
|
||||||
|
make_layers, maybe_prefix)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -77,17 +86,6 @@ class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore
|
|||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
|
|
||||||
def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
|
|
||||||
dt_min = 0.001
|
|
||||||
dt_max = 0.1
|
|
||||||
dt = torch.exp(
|
|
||||||
torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) +
|
|
||||||
math.log(dt_min))
|
|
||||||
dt = torch.clamp(dt, 1e-4)
|
|
||||||
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
|
||||||
return inv_dt
|
|
||||||
|
|
||||||
|
|
||||||
def is_mamba(config: Plamo2Config, i: int) -> bool:
|
def is_mamba(config: Plamo2Config, i: int) -> bool:
|
||||||
assert config.mamba_step > 1
|
assert config.mamba_step > 1
|
||||||
|
|
||||||
@ -97,52 +95,36 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
|
|||||||
return (i % config.mamba_step) != (config.mamba_step // 2)
|
return (i % config.mamba_step) != (config.mamba_step // 2)
|
||||||
|
|
||||||
|
|
||||||
# TODO(Shinichi): Replace this with RMSNorm.
|
# Adapted from:
|
||||||
def _rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor,
|
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
|
||||||
eps: float) -> torch.Tensor:
|
# transformers.models.mamba.modeling_mamba.MambaMixer
|
||||||
input_shape = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.reshape(input_shape[:-1] + weight.shape)
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + eps)
|
|
||||||
hidden_states = hidden_states.to(input_dtype)
|
|
||||||
hidden_states = weight * hidden_states
|
|
||||||
return hidden_states.reshape(input_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def _swiglu(h: torch.Tensor) -> torch.Tensor:
|
|
||||||
h0, h1 = h.chunk(2, dim=-1)
|
|
||||||
return torch.nn.functional.silu(h0) * h1
|
|
||||||
|
|
||||||
|
|
||||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
|
||||||
class Plamo2MambaMixer(nn.Module):
|
class Plamo2MambaMixer(nn.Module):
|
||||||
# TODO(Shinichi): Rebase on Mamba2 implementation.
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Plamo2Config,
|
vllm_config: VllmConfig,
|
||||||
cache_config: CacheConfig,
|
*,
|
||||||
quant_config: QuantizationConfig,
|
|
||||||
max_model_len: int,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = vllm_config.model_config.hf_config
|
||||||
self.hidden_size = config.hidden_size
|
self.quant_config = vllm_config.quant_config
|
||||||
self.ssm_state_size = config.mamba_d_state
|
self.hidden_size = self.config.hidden_size
|
||||||
self.conv_kernel_size = config.mamba_d_conv
|
self.ssm_state_size = self.config.mamba_d_state
|
||||||
self.intermediate_size = (config.mamba_num_heads *
|
self.conv_kernel_size = self.config.mamba_d_conv
|
||||||
config.hidden_size_per_head)
|
self.intermediate_size = (self.config.mamba_num_heads *
|
||||||
self.hidden_size_per_head = config.hidden_size_per_head
|
self.config.hidden_size_per_head)
|
||||||
self.num_heads = config.mamba_num_heads
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.intermediate_size_per_tp_worker = \
|
||||||
|
self.intermediate_size // self.tp_size
|
||||||
|
self.head_dim = self.config.hidden_size_per_head
|
||||||
|
self.num_heads = self.config.mamba_num_heads
|
||||||
self.time_step_rank = max(64, self.hidden_size // 16)
|
self.time_step_rank = max(64, self.hidden_size // 16)
|
||||||
self.use_conv_bias = False
|
|
||||||
self.use_bias = False
|
|
||||||
self.conv1d = ColumnParallelLinear(
|
self.conv1d = ColumnParallelLinear(
|
||||||
input_size=self.conv_kernel_size,
|
input_size=self.conv_kernel_size,
|
||||||
output_size=self.intermediate_size,
|
output_size=self.intermediate_size,
|
||||||
bias=self.use_conv_bias,
|
bias=False,
|
||||||
|
prefix=f"{prefix}.conv1d",
|
||||||
|
return_bias=False,
|
||||||
)
|
)
|
||||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||||
# Can't do this in `weight_loader` since it already exists in
|
# Can't do this in `weight_loader` since it already exists in
|
||||||
@ -153,15 +135,19 @@ class Plamo2MambaMixer(nn.Module):
|
|||||||
self.in_proj = MergedColumnParallelLinear(
|
self.in_proj = MergedColumnParallelLinear(
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
[self.intermediate_size] * 2,
|
[self.intermediate_size] * 2,
|
||||||
bias=self.use_bias,
|
bias=False,
|
||||||
|
quant_config=self.quant_config,
|
||||||
prefix=f"{prefix}.in_proj",
|
prefix=f"{prefix}.in_proj",
|
||||||
|
return_bias=False,
|
||||||
)
|
)
|
||||||
# selective projection used to make dt, B and C input dependent
|
# selective projection used to make dt, B and C input dependent
|
||||||
self.bcdt_proj = RowParallelLinear(
|
self.bcdt_proj = RowParallelLinear(
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.time_step_rank + self.ssm_state_size * 2,
|
self.time_step_rank + self.ssm_state_size * 2,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
quant_config=self.quant_config,
|
||||||
prefix=f"{prefix}.bcdt_proj",
|
prefix=f"{prefix}.bcdt_proj",
|
||||||
|
return_bias=False,
|
||||||
)
|
)
|
||||||
# time step projection (discretization) -
|
# time step projection (discretization) -
|
||||||
# In the forward we need to apply dt_proj without the bias,
|
# In the forward we need to apply dt_proj without the bias,
|
||||||
@ -170,154 +156,224 @@ class Plamo2MambaMixer(nn.Module):
|
|||||||
self.time_step_rank,
|
self.time_step_rank,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
quant_config=self.quant_config,
|
||||||
prefix=f"{prefix}.dt_proj",
|
prefix=f"{prefix}.dt_proj",
|
||||||
|
return_bias=False,
|
||||||
)
|
)
|
||||||
self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))
|
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.A = nn.Parameter(
|
self.A = nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
self.intermediate_size // tp_size,
|
divide(self.num_heads, self.tp_size),
|
||||||
self.ssm_state_size,
|
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
))
|
))
|
||||||
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
|
self.D = nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size)))
|
||||||
|
self.dt_bias = nn.Parameter(
|
||||||
|
torch.ones(divide(self.num_heads, self.tp_size)))
|
||||||
|
|
||||||
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
||||||
a_weight_loader = composed_weight_loader(
|
a_weight_loader = composed_weight_loader(
|
||||||
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
|
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
|
||||||
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
|
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
|
||||||
|
set_weight_attrs(self.dt_bias,
|
||||||
|
{"weight_loader": sharded_weight_loader(0)})
|
||||||
|
|
||||||
self.out_proj = RowParallelLinear(
|
self.out_proj = RowParallelLinear(
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=self.use_bias,
|
bias=False,
|
||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
|
quant_config=self.quant_config,
|
||||||
prefix=f"{prefix}.out_proj",
|
prefix=f"{prefix}.out_proj",
|
||||||
|
return_bias=False,
|
||||||
)
|
)
|
||||||
# The activation function is fixed to SiLU.
|
# The activation function is fixed to SiLU.
|
||||||
self.activation = "silu"
|
self.activation = "silu"
|
||||||
|
|
||||||
self.dt_norm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
|
self.dt_norm = RMSNorm(self.time_step_rank,
|
||||||
self.B_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
|
eps=self.config.rms_norm_eps)
|
||||||
self.C_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
|
self.B_norm = RMSNorm(self.ssm_state_size,
|
||||||
|
eps=self.config.rms_norm_eps)
|
||||||
|
self.C_norm = RMSNorm(self.ssm_state_size,
|
||||||
|
eps=self.config.rms_norm_eps)
|
||||||
|
|
||||||
def forward(
|
def _project_ssm_parameters(self, hidden_states):
|
||||||
self,
|
ssm_parameters = self.bcdt_proj(hidden_states)
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
mamba_cache_params: MambaCacheParams,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
|
||||||
|
|
||||||
# 1. Gated MLP's linear projection
|
|
||||||
projected_states = self.in_proj(hidden_states)[0]
|
|
||||||
# Reshaping the projected states as in modeling_plamo.py.
|
|
||||||
length = len(hidden_states)
|
|
||||||
projected_states = projected_states.reshape(length, self.num_heads, -1)
|
|
||||||
gate, hidden_states = torch.split(
|
|
||||||
projected_states,
|
|
||||||
[self.hidden_size_per_head, self.hidden_size_per_head],
|
|
||||||
dim=-1)
|
|
||||||
hidden_states = hidden_states.reshape(length, -1).transpose(0, 1)
|
|
||||||
gate = gate.reshape(length, -1).transpose(0, 1)
|
|
||||||
|
|
||||||
# 2. Convolution sequence transformation
|
|
||||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
|
||||||
self.conv1d.weight.size(2))
|
|
||||||
|
|
||||||
if attn_metadata.query_start_loc is not None \
|
|
||||||
and attn_metadata.context_lens_tensor is not None:
|
|
||||||
# |---------- N-1 iteration --------|
|
|
||||||
# |---------------- N iteration ---------------------|
|
|
||||||
# |- tokenA -|......................|-- newTokens ---|
|
|
||||||
# |---------- context_len ----------|
|
|
||||||
# |-------------------- seq_len ---------------------|
|
|
||||||
# |-- query_len ---|
|
|
||||||
hidden_states = causal_conv1d_fn(
|
|
||||||
hidden_states,
|
|
||||||
conv_weights,
|
|
||||||
self.conv1d.bias,
|
|
||||||
activation=self.activation,
|
|
||||||
conv_states=mamba_cache_params.conv_state,
|
|
||||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
|
||||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
|
||||||
query_start_loc=attn_metadata.query_start_loc)
|
|
||||||
else:
|
|
||||||
hidden_states = causal_conv1d_update(
|
|
||||||
hidden_states.transpose(0, 1),
|
|
||||||
mamba_cache_params.conv_state,
|
|
||||||
conv_weights,
|
|
||||||
self.conv1d.bias,
|
|
||||||
self.activation,
|
|
||||||
conv_state_indices=mamba_cache_params.state_indices_tensor)
|
|
||||||
hidden_states = hidden_states.transpose(0, 1)
|
|
||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
|
||||||
# 3.a. input varying initialization of time_step, B and C
|
|
||||||
ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1))[0]
|
|
||||||
|
|
||||||
# Splitting the ssm_parameters as in modeling_plamo.py.
|
|
||||||
B, C, time_step = torch.split(
|
B, C, time_step = torch.split(
|
||||||
ssm_parameters,
|
ssm_parameters,
|
||||||
[self.ssm_state_size, self.ssm_state_size, self.time_step_rank],
|
[self.ssm_state_size, self.ssm_state_size, self.time_step_rank],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# vllm._custom_ops.rms_norm requires contiguous input tensors.
|
||||||
time_step = self.dt_norm(time_step.contiguous())
|
time_step = self.dt_norm(time_step.contiguous())
|
||||||
B = self.B_norm(B.contiguous())
|
B = self.B_norm(B.contiguous())
|
||||||
C = self.C_norm(C.contiguous())
|
C = self.C_norm(C.contiguous())
|
||||||
|
dt = self.dt_proj(time_step)
|
||||||
|
return B, C, dt
|
||||||
|
|
||||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
def forward(
|
||||||
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
self,
|
||||||
time_proj_bias = (self.dt_bias.float() if hasattr(
|
hidden_states: torch.Tensor,
|
||||||
self.dt_proj, "bias") else None)
|
mamba_cache_params: MambaCacheParams,
|
||||||
|
mamba2_metadata: Mamba2Metadata,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
# Broadcasting as in modeling_plamo.py.
|
# mamba2_metadata contains metadata necessary for the mamba2 triton
|
||||||
discrete_time_step = discrete_time_step.transpose(
|
# kernels to operate in continuous batching and in chunked prefill
|
||||||
0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head)
|
# modes; they are computed at top-level model forward since they
|
||||||
discrete_time_step = discrete_time_step.reshape(
|
# stay the same and are reused for all mamba layers in the same iteration.
|
||||||
-1, self.intermediate_size).transpose(0, 1)
|
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||||
time_proj_bias = time_proj_bias[...,
|
|
||||||
None].expand(-1,
|
|
||||||
self.hidden_size_per_head)
|
|
||||||
time_proj_bias = time_proj_bias.reshape(self.intermediate_size)
|
|
||||||
|
|
||||||
if attn_metadata.query_start_loc is not None \
|
num_prefills = attn_metadata.num_prefills # request count
|
||||||
and attn_metadata.context_lens_tensor is not None:
|
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||||
scan_outputs = selective_scan_fn(
|
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||||
hidden_states,
|
has_prefill = num_prefills > 0
|
||||||
mamba_cache_params.ssm_state,
|
has_decode = num_decodes > 0
|
||||||
discrete_time_step,
|
|
||||||
|
# 1. Gated MLP's linear projection
|
||||||
|
projected_states = self.in_proj(hidden_states)
|
||||||
|
gate, hidden_states = projected_states.chunk(2, dim=-1)
|
||||||
|
|
||||||
|
# 2. Convolution sequence transformation
|
||||||
|
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||||
|
self.conv1d.weight.size(2))
|
||||||
|
|
||||||
|
# Separate prefill and decode by splitting varlen input
|
||||||
|
# Split along token dimension
|
||||||
|
hidden_states_p, hidden_states_d = torch.split(
|
||||||
|
hidden_states,
|
||||||
|
[num_prefill_tokens, num_decodes],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes],
|
||||||
|
dim=0)
|
||||||
|
# Split along batch dimension
|
||||||
|
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||||
|
mamba_cache_params.state_indices_tensor,
|
||||||
|
[num_prefills, num_decodes],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
|
||||||
|
if has_prefill else None)
|
||||||
|
|
||||||
|
ssd_output_list = []
|
||||||
|
|
||||||
|
# Process prefill requests
|
||||||
|
if has_prefill:
|
||||||
|
# 2. Convolution sequence transformation
|
||||||
|
# - "cache_indices" updates the conv_state cache in positions
|
||||||
|
# pointed to by "mamba_cache_params.state_indices_tensor"
|
||||||
|
hidden_states_p = causal_conv1d_fn(
|
||||||
|
hidden_states_p.transpose(0, 1),
|
||||||
|
conv_weights,
|
||||||
|
self.conv1d.bias,
|
||||||
|
activation=self.activation,
|
||||||
|
conv_states=mamba_cache_params.conv_state,
|
||||||
|
has_initial_state=mamba2_metadata.has_initial_states,
|
||||||
|
cache_indices=state_indices_tensor_p,
|
||||||
|
query_start_loc=query_start_loc_p)
|
||||||
|
hidden_states_p = hidden_states_p.transpose(0, 1)
|
||||||
|
hidden_states_p = hidden_states_p[:num_prefill_tokens]
|
||||||
|
# In some instances, the following `bcdt_proj` op
|
||||||
|
# requires contiguous inputs
|
||||||
|
# (e.g. if the Marlin kernel is used).
|
||||||
|
hidden_states_p = hidden_states_p.contiguous()
|
||||||
|
|
||||||
|
B, C, dt = self._project_ssm_parameters(hidden_states_p)
|
||||||
|
|
||||||
|
# 3. State Space Model sequence transformation
|
||||||
|
initial_states = None
|
||||||
|
if (mamba2_metadata.has_initial_states is not None
|
||||||
|
and mamba2_metadata.prep_initial_states):
|
||||||
|
# making a copy of the states
|
||||||
|
initial_states = torch.where(
|
||||||
|
mamba2_metadata.has_initial_states[:, None, None, None],
|
||||||
|
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
|
||||||
|
scan_output, varlen_state = mamba_chunk_scan_combined(
|
||||||
|
hidden_states_p.view(1, num_prefill_tokens,
|
||||||
|
self.num_heads // self.tp_size,
|
||||||
|
self.head_dim),
|
||||||
|
dt.unsqueeze(0),
|
||||||
self.A,
|
self.A,
|
||||||
B.transpose(-2, -1),
|
B.view(1, num_prefill_tokens, 1, -1),
|
||||||
C.transpose(-2, -1),
|
C.view(1, num_prefill_tokens, 1, -1),
|
||||||
self.D.float(),
|
chunk_size=mamba2_metadata.chunk_size,
|
||||||
gate,
|
D=self.D,
|
||||||
time_proj_bias,
|
z=gate_p.view(1, num_prefill_tokens,
|
||||||
delta_softplus=True,
|
self.num_heads // self.tp_size, self.head_dim),
|
||||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
dt_bias=self.dt_bias,
|
||||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
seq_idx=mamba2_metadata.seq_idx,
|
||||||
query_start_loc=attn_metadata.query_start_loc)
|
chunk_indices=mamba2_metadata.chunk_indices,
|
||||||
else:
|
chunk_offsets=mamba2_metadata.chunk_offsets,
|
||||||
scan_outputs = selective_state_update(
|
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
|
||||||
|
initial_states=initial_states,
|
||||||
|
return_varlen_states=True,
|
||||||
|
return_final_states=False,
|
||||||
|
dt_softplus=True,
|
||||||
|
dt_limit=(0.0, float("inf")),
|
||||||
|
)
|
||||||
|
|
||||||
|
# update ssm states
|
||||||
|
# - varlen state is a (batch, nheads, headdim, dstate) tensor
|
||||||
|
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
|
||||||
|
|
||||||
|
# - reshape
|
||||||
|
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
|
||||||
|
|
||||||
|
# Process decode requests
|
||||||
|
if has_decode:
|
||||||
|
# 2. Convolution sequence transformation
|
||||||
|
hidden_states_d = causal_conv1d_update(
|
||||||
|
hidden_states_d,
|
||||||
|
mamba_cache_params.conv_state,
|
||||||
|
conv_weights,
|
||||||
|
self.conv1d.bias,
|
||||||
|
self.activation,
|
||||||
|
conv_state_indices=state_indices_tensor_d)
|
||||||
|
|
||||||
|
B, C, dt = self._project_ssm_parameters(hidden_states_d)
|
||||||
|
|
||||||
|
# 3. State Space Model sequence transformation
|
||||||
|
A = self.A[:, None, ...][:, :,
|
||||||
|
None].expand(-1, self.head_dim,
|
||||||
|
self.config.mamba_d_state)
|
||||||
|
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
|
||||||
|
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
|
||||||
|
D = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||||
|
B = B.unsqueeze(1)
|
||||||
|
C = C.unsqueeze(1)
|
||||||
|
hidden_states_d = hidden_states_d.view(
|
||||||
|
-1, self.num_heads // self.tp_size, self.head_dim)
|
||||||
|
|
||||||
|
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||||
|
# - mamba_cache_params.ssm_state's slots will be selected
|
||||||
|
# using state_indices_tensor_d
|
||||||
|
|
||||||
|
hidden_states_d = selective_state_update(
|
||||||
mamba_cache_params.ssm_state,
|
mamba_cache_params.ssm_state,
|
||||||
hidden_states.transpose(0, 1),
|
hidden_states_d,
|
||||||
discrete_time_step.transpose(0, 1),
|
dt,
|
||||||
self.A,
|
A,
|
||||||
B,
|
B,
|
||||||
C,
|
C,
|
||||||
self.D,
|
D,
|
||||||
gate.transpose(0, 1),
|
z=gate_d.reshape(num_decodes, -1, self.head_dim),
|
||||||
time_proj_bias,
|
dt_bias=dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=mamba_cache_params.state_indices_tensor)
|
state_batch_indices=state_indices_tensor_d,
|
||||||
scan_outputs = scan_outputs.transpose(0, 1)
|
)
|
||||||
|
assert self.num_heads % self.tp_size == 0
|
||||||
|
ssd_output_list.append(
|
||||||
|
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
|
||||||
|
self.head_dim))
|
||||||
|
|
||||||
|
# Merge prefill and decode outputs before passing to MLP
|
||||||
|
hidden_states = torch.vstack(ssd_output_list)
|
||||||
|
|
||||||
# 4. Final linear projection
|
# 4. Final linear projection
|
||||||
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
|
out = self.out_proj(hidden_states)
|
||||||
-1))[0]
|
return out
|
||||||
return contextualized_states
|
|
||||||
|
|
||||||
|
|
||||||
class DenseMLP(nn.Module):
|
class DenseMLP(nn.Module):
|
||||||
@ -332,33 +388,39 @@ class DenseMLP(nn.Module):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
self.hidden_size, [self.intermediate_size] * 2,
|
self.hidden_size,
|
||||||
|
[self.intermediate_size] * 2,
|
||||||
bias=False,
|
bias=False,
|
||||||
prefix=f"{prefix}.gate_up_proj",
|
prefix=f"{prefix}.gate_up_proj",
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
return_bias=False,
|
||||||
|
)
|
||||||
|
self.act = SiluAndMul()
|
||||||
self.down_proj = RowParallelLinear(self.intermediate_size,
|
self.down_proj = RowParallelLinear(self.intermediate_size,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
prefix=f"{prefix}.down_proj",
|
prefix=f"{prefix}.down_proj",
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
return_bias=False)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
h = self.gate_up_proj(hidden_states)[0]
|
h = self.gate_up_proj(hidden_states)
|
||||||
h = _swiglu(h)
|
h = self.act(h)
|
||||||
output, _ = self.down_proj(h)
|
return self.down_proj(h)
|
||||||
return output # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
class Plamo2AttentionMixer(nn.Module):
|
class Plamo2AttentionMixer(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: Plamo2Config,
|
*,
|
||||||
cache_config: CacheConfig,
|
vllm_config: VllmConfig,
|
||||||
quant_config: QuantizationConfig,
|
|
||||||
max_model_len: int | None = None,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.total_num_heads = config.num_attention_heads
|
self.total_num_heads = config.num_attention_heads
|
||||||
@ -396,19 +458,35 @@ class Plamo2AttentionMixer(nn.Module):
|
|||||||
"rope_theta") else 10000
|
"rope_theta") else 10000
|
||||||
self.rope_scaling = config.rope_scaling if hasattr(
|
self.rope_scaling = config.rope_scaling if hasattr(
|
||||||
config, "rope_scaling") else None
|
config, "rope_scaling") else None
|
||||||
|
max_position = config.max_position_embeddings
|
||||||
|
if hasattr(vllm_config.model_config, "max_model_len") and isinstance(
|
||||||
|
vllm_config.model_config.max_model_len, int):
|
||||||
|
max_position = min(max_position,
|
||||||
|
vllm_config.model_config.max_model_len)
|
||||||
|
|
||||||
assert max_model_len is not None, "max_model_len must be provided"
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
max_position=max_model_len,
|
max_position=max_position,
|
||||||
base=self.rope_theta,
|
base=self.rope_theta,
|
||||||
rope_scaling=self.rope_scaling,
|
rope_scaling=self.rope_scaling,
|
||||||
)
|
)
|
||||||
self.q_weight = torch.nn.Parameter(
|
self.q_norm = RMSNorm(config.hidden_size_per_head,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.q_norm.weight = torch.nn.Parameter(
|
||||||
torch.ones((self.num_heads, config.hidden_size_per_head)))
|
torch.ones((self.num_heads, config.hidden_size_per_head)))
|
||||||
self.k_weight = torch.nn.Parameter(
|
set_weight_attrs(self.q_norm.weight,
|
||||||
|
{"weight_loader": sharded_weight_loader(0)})
|
||||||
|
self.k_norm = RMSNorm(config.hidden_size_per_head,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.k_norm.weight = torch.nn.Parameter(
|
||||||
torch.ones((self.num_kv_heads, config.hidden_size_per_head)))
|
torch.ones((self.num_kv_heads, config.hidden_size_per_head)))
|
||||||
|
# Tensor-parallelism shards the K norm weights to the tp ranks
|
||||||
|
# in a head-wise manner. This approach does not work if there is only
|
||||||
|
# a single KV head, as is the case for PLaMo 2-1B.
|
||||||
|
if self.total_num_kv_heads != 1:
|
||||||
|
set_weight_attrs(self.k_norm.weight,
|
||||||
|
{"weight_loader": sharded_weight_loader(0)})
|
||||||
|
|
||||||
self.attn = Attention(
|
self.attn = Attention(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
@ -423,13 +501,18 @@ class Plamo2AttentionMixer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor],
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
q = _rms_norm(q, self.q_weight, 1e-6)
|
|
||||||
k = _rms_norm(k, self.k_weight, 1e-6)
|
q_shape = q.shape
|
||||||
|
q = q.reshape(q_shape[:-1] + self.q_norm.weight.shape)
|
||||||
|
q = self.q_norm.forward_native(q).reshape(q_shape)
|
||||||
|
k_shape = k.shape
|
||||||
|
k = k.reshape(k_shape[:-1] + self.k_norm.weight.shape)
|
||||||
|
k = self.k_norm.forward_native(k).reshape(k_shape)
|
||||||
|
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
@ -441,27 +524,18 @@ class Plamo2DecoderLayer(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
max_model_len: int | None = None,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
max_model_len = vllm_config.scheduler_config.max_model_len
|
|
||||||
|
|
||||||
self.is_mamba = is_mamba(config, layer_idx)
|
self.is_mamba = is_mamba(config, layer_idx)
|
||||||
if self.is_mamba:
|
if self.is_mamba:
|
||||||
self.mixer = Plamo2MambaMixer(config=config,
|
self.mixer = Plamo2MambaMixer(vllm_config=vllm_config,
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
max_model_len=max_model_len,
|
|
||||||
prefix=f"{prefix}.mixer")
|
prefix=f"{prefix}.mixer")
|
||||||
else:
|
else:
|
||||||
self.mixer = Plamo2AttentionMixer(config=config,
|
self.mixer = Plamo2AttentionMixer(vllm_config=vllm_config,
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
max_model_len=max_model_len,
|
|
||||||
prefix=f"{prefix}.mixer")
|
prefix=f"{prefix}.mixer")
|
||||||
|
|
||||||
self.mlp = DenseMLP(config=config,
|
self.mlp = DenseMLP(config=config,
|
||||||
@ -482,6 +556,7 @@ class Plamo2DecoderLayer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
mamba_cache_params: MambaCacheParams,
|
mamba_cache_params: MambaCacheParams,
|
||||||
|
mamba2_metadata: Mamba2Metadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if residual is None:
|
if residual is None:
|
||||||
@ -491,10 +566,12 @@ class Plamo2DecoderLayer(nn.Module):
|
|||||||
hidden_states, residual = self.pre_mixer_norm(
|
hidden_states, residual = self.pre_mixer_norm(
|
||||||
hidden_states, residual)
|
hidden_states, residual)
|
||||||
|
|
||||||
hidden_states = self.mixer(positions=positions,
|
hidden_states = self.mixer(
|
||||||
hidden_states=hidden_states,
|
positions=positions,
|
||||||
residual=residual,
|
hidden_states=hidden_states,
|
||||||
mamba_cache_params=mamba_cache_params)
|
mamba_cache_params=mamba_cache_params,
|
||||||
|
mamba2_metadata=mamba2_metadata,
|
||||||
|
)
|
||||||
hidden_states = self.post_mixer_norm(hidden_states)
|
hidden_states = self.post_mixer_norm(hidden_states)
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.pre_mlp_norm(hidden_states, residual)
|
hidden_states, residual = self.pre_mlp_norm(hidden_states, residual)
|
||||||
@ -507,14 +584,18 @@ class Plamo2Decoder(torch.nn.Module):
|
|||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
config = vllm_config.model_config.hf_config
|
||||||
|
extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
|
||||||
|
|
||||||
self.layers = nn.ModuleList([
|
def get_layer(prefix: str):
|
||||||
Plamo2DecoderLayer(vllm_config=vllm_config,
|
layer_idx = int(prefix.rsplit(".", 1)[1])
|
||||||
layer_idx=i,
|
return Plamo2DecoderLayer(vllm_config=vllm_config,
|
||||||
prefix=f"{prefix}.layers.{i}")
|
layer_idx=layer_idx,
|
||||||
for i in range(num_hidden_layers)
|
prefix=prefix,
|
||||||
])
|
**extra_kwargs)
|
||||||
|
|
||||||
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
|
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -522,9 +603,10 @@ class Plamo2Decoder(torch.nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
mamba_cache_params: MambaCacheParams,
|
mamba_cache_params: MambaCacheParams,
|
||||||
|
mamba2_metadata: Mamba2Metadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
mamba_cache_index = 0
|
mamba_cache_index = 0
|
||||||
for layer in self.layers:
|
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||||
layer_mamba_cache_params = None
|
layer_mamba_cache_params = None
|
||||||
if layer.is_mamba:
|
if layer.is_mamba:
|
||||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
||||||
@ -535,7 +617,9 @@ class Plamo2Decoder(torch.nn.Module):
|
|||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
mamba_cache_params=layer_mamba_cache_params)
|
mamba_cache_params=layer_mamba_cache_params,
|
||||||
|
mamba2_metadata=mamba2_metadata,
|
||||||
|
)
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
@ -557,10 +641,16 @@ class Plamo2Model(Plamo2PreTrainedModel):
|
|||||||
org_num_embeddings=config.vocab_size,
|
org_num_embeddings=config.vocab_size,
|
||||||
prefix=f"{prefix}.embed_tokens",
|
prefix=f"{prefix}.embed_tokens",
|
||||||
)
|
)
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
make_empty_intermediate_tensors_factory(
|
||||||
|
["hidden_states", "residual"], config.hidden_size))
|
||||||
self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers")
|
self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers")
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -569,21 +659,41 @@ class Plamo2Model(Plamo2PreTrainedModel):
|
|||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# TODO(Shinichi): Implement pipeline parallelism.
|
if get_pp_group().is_first_rank:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
if inputs_embeds is not None:
|
||||||
residual = None
|
hidden_states = inputs_embeds
|
||||||
|
else:
|
||||||
|
hidden_states = self.get_input_embeddings(input_ids)
|
||||||
|
residual = None
|
||||||
|
else:
|
||||||
|
assert intermediate_tensors is not None
|
||||||
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
|
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||||
|
mamba2_metadata = prepare_mamba2_metadata(
|
||||||
|
chunk_size=self.config.mamba_chunk_size,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states, residual = self.layers(
|
hidden_states, residual = self.layers(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
mamba_cache_params=mamba_cache_params)
|
mamba_cache_params=mamba_cache_params,
|
||||||
|
mamba2_metadata=mamba2_metadata,
|
||||||
|
)
|
||||||
|
if not get_pp_group().is_last_rank:
|
||||||
|
return IntermediateTensors({
|
||||||
|
"hidden_states": hidden_states,
|
||||||
|
"residual": residual
|
||||||
|
})
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
|
class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
|
||||||
SupportsV0Only):
|
IsHybrid, SupportsV0Only):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
"q_proj",
|
"q_proj",
|
||||||
@ -629,10 +739,15 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
|
|||||||
|
|
||||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
self.config.vocab_size)
|
self.config.vocab_size)
|
||||||
|
self.sampler = get_sampler()
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
self.model.make_empty_intermediate_tensors)
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@ -661,7 +776,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
|
|||||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||||
|
|
||||||
def _get_mamba_cache_shape(
|
def _get_mamba_cache_shape(
|
||||||
self) -> tuple[tuple[int, int], tuple[int, int]]:
|
self) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
world_size = get_tensor_model_parallel_world_size()
|
||||||
hidden_size = (self.config.mamba_num_heads *
|
hidden_size = (self.config.mamba_num_heads *
|
||||||
self.config.hidden_size_per_head)
|
self.config.hidden_size_per_head)
|
||||||
@ -670,7 +785,8 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
|
|||||||
self.config.mamba_d_conv - 1,
|
self.config.mamba_d_conv - 1,
|
||||||
)
|
)
|
||||||
temporal_state_shape = (
|
temporal_state_shape = (
|
||||||
hidden_size // world_size,
|
divide(self.config.mamba_num_heads, world_size),
|
||||||
|
self.config.hidden_size_per_head,
|
||||||
self.config.mamba_d_state,
|
self.config.mamba_d_state,
|
||||||
)
|
)
|
||||||
return conv_state_shape, temporal_state_shape
|
return conv_state_shape, temporal_state_shape
|
||||||
@ -684,6 +800,14 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
|
|||||||
sampling_metadata)
|
sampling_metadata)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: Optional[torch.Tensor],
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
@ -703,23 +827,46 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
|
|||||||
".B_norm_weight": ".B_norm.weight",
|
".B_norm_weight": ".B_norm.weight",
|
||||||
".C_norm_weight": ".C_norm.weight",
|
".C_norm_weight": ".C_norm.weight",
|
||||||
".dt_norm_weight": ".dt_norm.weight",
|
".dt_norm_weight": ".dt_norm.weight",
|
||||||
|
".q_weight": ".q_norm.weight",
|
||||||
|
".k_weight": ".k_norm.weight",
|
||||||
}
|
}
|
||||||
# Apply replacements based on the defined mappings
|
# Apply replacements based on the defined mappings
|
||||||
for old, new in replacements.items():
|
for old, new in replacements.items():
|
||||||
if old in name:
|
if old in name:
|
||||||
name = name.replace(old, new)
|
name = name.replace(old, new)
|
||||||
|
|
||||||
# Broadcast the loaded weight to match the model's parameter shape.
|
# Reshape the in_proj weights to match the shape expected
|
||||||
if ".A" in name:
|
# by MergedColumnParallelLinear.
|
||||||
loaded_weight = loaded_weight[:, None, None].expand(
|
# This works both for unquantized weights and
|
||||||
-1, self.config.hidden_size_per_head,
|
# for quantized weights.
|
||||||
self.config.mamba_d_state)
|
# In the quantized case, the weights are already transposed.
|
||||||
|
# Also, in addition to the quantized weights,
|
||||||
|
# the zero points and scales have to be reshaped as well.
|
||||||
|
# Packing should not be affected by this.
|
||||||
|
if ".mixer.in_proj.weight" in name \
|
||||||
|
or "mixer.in_proj.qweight" in name \
|
||||||
|
or "mixer.in_proj.scales" in name \
|
||||||
|
or "mixer.in_proj.qzeros" in name:
|
||||||
|
if "mixer.in_proj.weight" in name:
|
||||||
|
loaded_weight = loaded_weight.transpose(0, 1)
|
||||||
|
# for weight:
|
||||||
|
# loaded_weight.shape[0] == self.config.hidden_size
|
||||||
|
# for qweight:
|
||||||
|
# loaded_weight.shape[0] == self.config.hidden_size // param.pack_factor # noqa
|
||||||
|
# for scales and qzeros:
|
||||||
|
# loaded_weight.shape[0] == self.config.hidden_size // self.vllm_config.quant_config.group_size # noqa
|
||||||
loaded_weight = loaded_weight.reshape(
|
loaded_weight = loaded_weight.reshape(
|
||||||
-1, self.config.mamba_d_state)
|
loaded_weight.shape[0], self.config.mamba_num_heads, -1)
|
||||||
elif ".D" in name:
|
gate_weight, hidden_states_weight = loaded_weight.chunk(2,
|
||||||
loaded_weight = loaded_weight[:, None].expand(
|
dim=-1)
|
||||||
-1, self.config.hidden_size_per_head)
|
gate_weight = gate_weight.reshape(loaded_weight.shape[0], -1)
|
||||||
loaded_weight = loaded_weight.reshape(-1)
|
hidden_states_weight = hidden_states_weight.reshape(
|
||||||
|
loaded_weight.shape[0], -1)
|
||||||
|
loaded_weight = torch.cat([gate_weight, hidden_states_weight],
|
||||||
|
dim=-1)
|
||||||
|
if "mixer.in_proj.weight" in name:
|
||||||
|
loaded_weight = loaded_weight.transpose(0, 1)
|
||||||
|
|
||||||
# Offset parameter with vllm's RMSNorm haven't been supported yet.
|
# Offset parameter with vllm's RMSNorm haven't been supported yet.
|
||||||
if ".pre_mixer_norm" in name:
|
if ".pre_mixer_norm" in name:
|
||||||
loaded_weight += 1.0
|
loaded_weight += 1.0
|
||||||
@ -732,6 +879,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
|
|||||||
elif "model.norm.weight" in name:
|
elif "model.norm.weight" in name:
|
||||||
loaded_weight += 1.0
|
loaded_weight += 1.0
|
||||||
|
|
||||||
|
# Skip layers on other devices.
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user