[MODEL] FalconH1 (#18406)

Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae>
Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
This commit is contained in:
Dhia Eddine Rhaiem 2025-05-21 15:59:06 +04:00 committed by GitHub
parent 61acfc45bc
commit eca18691d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 798 additions and 59 deletions

View File

@ -392,6 +392,11 @@ Specified using `--task generate`.
* `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc.
* ✅︎
* ✅︎
- * `FalconH1ForCausalLM`
* Falcon-H1
* `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc.
* ✅︎
* ✅︎
- * `GemmaForCausalLM`
* Gemma
* `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc.

View File

@ -147,6 +147,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
is_available_online=False,
min_transformers_version="4.52.2"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),

View File

@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
def __init__(self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp):
self.n_groups = full_hidden_size // self.group_size
self.variance_epsilon = eps
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)})
assert self.full_hidden_size % self.tp_size== 0,\
"Tensor parallel world size must divide hidden size."
self.use_rms_norm = use_rms_norm
if self.use_rms_norm:
# Register norm weight only if we're actually applying RMSNorm
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)})
else:
# Avoid checkpoint mismatch by skipping unused parameter
self.register_parameter("weight", None)
assert (self.full_hidden_size % self.tp_size == 0
), "Tensor parallel world size must divide hidden size."
def forward_native(
self,
@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp):
# the input and then redundantly compute the RMSNorm.
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x
if self.n_groups == 1:
if self.tp_size > 1:
@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp):
global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance
count = self.tp_size * x.shape[-1]
variance = (global_sums / count)
variance = global_sums / count
else:
variance = x.pow(2).mean(-1, keepdim=True)
@ -106,6 +118,9 @@ class Mixer2RMSNormGated(CustomOp):
gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if not self.use_rms_norm:
return x * nn.functional.silu(gate.to(torch.float32))
if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate)
@ -124,7 +139,7 @@ class Mixer2RMSNormGated(CustomOp):
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
"""Compute the increase in group numbers to account for
replication in order to accompany the head shards."""
# in the case ngroups % tp_size == 0, this will be zero
@ -182,13 +197,15 @@ def mamba_v2_sharded_weight_loader(
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
param.data[
boundary:(boundary + take), # type: ignore[misc]
...] = loaded_weight[loaded_start_idx:( # type: ignore[misc]
loaded_start_idx + take)] # type: ignore[misc]
boundary:(boundary + take),
... # type: ignore[misc]
] = loaded_weight[loaded_start_idx:(loaded_start_idx +
take) # type: ignore[misc]
] # type: ignore[misc]
# move indexing boundaries
boundary += shard_size
loaded_boundary += (full_dim - extra)
loaded_boundary += full_dim - extra
return loader
@ -206,19 +223,22 @@ class MambaMixer2(CustomOp):
**selective** state spaces)
"""
def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation="silu",
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# For TP, the sharding plan is as follows:
@ -238,17 +258,16 @@ class MambaMixer2(CustomOp):
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
assert num_heads % self.tp_size == 0, \
"Tensor parallel world size must divide num heads."
assert (num_heads % self.tp_size == 0
), "Tensor parallel world size must divide num heads."
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
(
"If tensor parallel world size does not divide num_heads, "
"then num_groups must equal 1."
)
assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
"If tensor parallel world size does not divide num_heads, "
"then num_groups must equal 1.")
assert self.tp_size == 1 or quant_config is None, \
"Tensor parallel currently not supported for quantized models."
assert (
self.tp_size == 1 or quant_config is None
), "Tensor parallel currently not supported for quantized models."
self.ssm_state_size = ssm_state_size
self.activation = activation
@ -265,8 +284,7 @@ class MambaMixer2(CustomOp):
self.n_groups = n_groups + extra_groups_for_head_shards(
n_groups, self.tp_size)
self.conv_dim = (intermediate_size +
2 * self.n_groups * ssm_state_size)
self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
@ -279,11 +297,12 @@ class MambaMixer2(CustomOp):
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = ColumnParallelLinear(input_size=hidden_size,
output_size=intermediate_size +
self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config)
self.in_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
)
# - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding
@ -305,7 +324,8 @@ class MambaMixer2(CustomOp):
# - ditto for the other two weights below
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias, {
self.conv1d.bias,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
[
@ -316,18 +336,25 @@ class MambaMixer2(CustomOp):
self.tp_size,
tp_rank,
)
})
},
)
delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight, {
self.conv1d.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader([
intermediate_settings,
group_shard_settings,
group_shard_settings,
], self.tp_size, tp_rank)
})
mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)
if quant_config is None:
# - quant layers do not have a weight loader
@ -345,8 +372,10 @@ class MambaMixer2(CustomOp):
head_setings, # for dt
],
self.tp_size,
tp_rank)
})
tp_rank,
)
},
)
# - these are TPed by heads to reduce the size of the
# temporal shape
@ -357,6 +386,7 @@ class MambaMixer2(CustomOp):
))
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.use_rms_norm = use_rms_norm
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
@ -365,18 +395,25 @@ class MambaMixer2(CustomOp):
set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)})
self.out_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=use_bias,
input_is_parallel=True,
quant_config=quant_config)
self.out_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=use_bias,
input_is_parallel=True,
quant_config=quant_config,
)
self.norm = Mixer2RMSNormGated(intermediate_size,
n_groups,
self.use_rms_norm,
eps=rms_norm_eps)
def forward_native(self, hidden_states: torch.Tensor,
conv_state: torch.Tensor, ssm_state: torch.Tensor):
def forward_native(
self,
hidden_states: torch.Tensor,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
):
pass
def forward_cuda(
@ -384,6 +421,7 @@ class MambaMixer2(CustomOp):
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
@ -401,6 +439,10 @@ class MambaMixer2(CustomOp):
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
if mup_vector is not None:
projected_states = projected_states * mup_vector
gate, hidden_states_B_C, dt = torch.split(
projected_states,
[
@ -561,6 +603,9 @@ class MambaMixer2(CustomOp):
hidden_states = torch.vstack(ssd_output_list)
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(hidden_states, gate)
# 5. Final linear projection

View File

@ -0,0 +1,685 @@
# SPDX-License-Identifier: Apache-2.0
"""Inference-only FalconH1 model."""
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
from transformers import FalconH1Config
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsV0Only)
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class FalconH1MLP(nn.Module):
    """Gated MLP block for FalconH1 with muP-style branch multipliers.

    The gate half of the fused gate/up projection is scaled by
    ``gate_multiplier`` before activation, and the down projection output
    is scaled by ``down_multiplier`` (both from ``config.mlp_multipliers``).
    """

    def __init__(
        self,
        config: FalconH1Config,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        hidden = config.hidden_size
        inter = config.intermediate_size
        # Fused gate/up projection: both halves are column-parallel.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden,
            output_sizes=[inter, inter],
            bias=bias,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            input_size=inter,
            output_size=hidden,
            bias=bias,
            quant_config=quant_config,
        )
        self.tp_size = get_tensor_model_parallel_world_size()
        self.intermediate_size = inter
        # Per-branch scaling factors taken from the config.
        self.gate_multiplier, self.down_multiplier = config.mlp_multipliers
        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        projected, _ = self.gate_up_proj(x)
        # Scale only the gate half of this rank's shard of the fused
        # projection; SiluAndMul then consumes [gate | up].
        gate_width = self.intermediate_size // self.tp_size
        projected[:, :gate_width] *= self.gate_multiplier
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out * self.down_multiplier
class FalconH1SSMDecoderLayer(nn.Module):
    """State-space (Mamba2) branch of a FalconH1 hybrid decoder layer.

    Wraps a ``MambaMixer2`` and precomputes a fixed (non-learnable) muP
    scaling vector that rescales the contiguous blocks of the mixer's
    input-projection output.
    """

    def __init__(
        self,
        config: FalconH1Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        # SSM latent width: an explicit `mamba_d_ssm` overrides the
        # default `mamba_expand * hidden_size`.
        self.d_ssm = (int(config.mamba_expand * config.hidden_size)
                      if config.mamba_d_ssm is None else config.mamba_d_ssm)
        self.mamba = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=self.d_ssm,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            n_groups=config.mamba_n_groups,
            num_heads=config.mamba_n_heads,
            head_dim=config.mamba_d_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            quant_config=quant_config,
            use_rms_norm=config.mamba_rms_norm,
        )
        # n_groups is overridden later by `MambaMixer2` (it may add extra
        # replicated groups so groups divide evenly across TP shards), so
        # read the effective value back from the mixer.
        self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
        self.zxbcdt_multipliers = config.ssm_multipliers
        self._init_mup_vector()

    def _init_mup_vector(self):
        """
        Non learnable per-block scaling vector composed of element-wise
        multipliers applied to each separate contiguous block of the output
        of the linear projection (in_proj) before further processing
        (gating, convolution, SSM):

        - Z block: [0 : d_ssm]                     * zxbcdt_multipliers[0]
        - X block: [d_ssm : 2 * d_ssm]             * zxbcdt_multipliers[1]
        - B block: [2 * d_ssm : 2 * d_ssm + G * S] * zxbcdt_multipliers[2]
        - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S]
                                                   * zxbcdt_multipliers[3]
        - dt block: [2 * d_ssm + 2 * G * S : end]  * zxbcdt_multipliers[4]

        where:
        - d_ssm: Dimension of state-space model latent
        - G: Number of groups (n_groups)
        - S: SSM state size per group
        - All indices are divided by tp_size to support tensor parallelism
        """
        vector_shape = (2 * self.d_ssm + 2 * self.groups_time_state_size +
                        self.config.mamba_n_heads) // self.tp_size
        mup_vector = torch.ones(1, vector_shape)
        # Z vector 0 -> d_ssm
        mup_vector[:, :self.d_ssm //
                   self.tp_size] *= self.zxbcdt_multipliers[0]
        # X vector d_ssm -> 2 * d_ssm
        mup_vector[:,
                   (self.d_ssm //
                    self.tp_size):(2 * self.d_ssm //
                                   self.tp_size)] *= self.zxbcdt_multipliers[1]
        # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state)
        mup_vector[
            :,
            (2 * self.d_ssm) //
            self.tp_size:(2 * self.d_ssm + self.groups_time_state_size) //
            self.tp_size,
        ] *= self.zxbcdt_multipliers[2]
        # C vector 2 * d_ssm + (n_group * d_state)
        # -> 2 * d_ssm + 2 * (n_group * d_state)
        mup_vector[
            :,
            (2 * self.d_ssm + self.groups_time_state_size) //
            self.tp_size:(2 * self.d_ssm + 2 * self.groups_time_state_size) //
            self.tp_size,
        ] *= self.zxbcdt_multipliers[3]
        # dt vector 2 * d_ssm + 2 * (n_group * d_state)
        # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads
        mup_vector[
            :,
            (2 * self.d_ssm + 2 * self.groups_time_state_size) //
            self.tp_size:,
        ] *= self.zxbcdt_multipliers[4]
        # Not persistent: the vector is fully derived from the config, so
        # it is rebuilt at construction time rather than checkpointed.
        self.register_buffer("mup_vector", mup_vector, persistent=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        **kwargs,
    ):
        # The residual is threaded through untouched; the hybrid wrapper
        # combines the branch outputs with it.
        hidden_states = self.mamba(
            hidden_states,
            mamba_cache_params,
            mamba2_metadata=mamba2_metadata,
            mup_vector=self.mup_vector,
        )
        return hidden_states, residual
class FalconH1AttentionDecoderLayer(nn.Module):
    """Multi-head self-attention branch of a FalconH1 hybrid decoder layer.

    Applies the fused QKV projection, a muP key multiplier, rotary
    embeddings, attention, and the output projection. Heads are sharded
    across tensor-parallel ranks; KV heads are partitioned or replicated
    depending on the TP size.
    """

    def __init__(
        self,
        config: FalconH1Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        rope_theta = getattr(config, "rope_theta", 1e11)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # An explicit `head_dim` in the config wins over the derived
        # hidden_size // num_heads value.
        self.head_dim = (config.hidden_size // self.total_num_heads if getattr(
            config, "head_dim", None) is None else config.head_dim)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        if hasattr(config, "partial_rotary_factor"):
            # NOTE(review): this product is a float when the factor is
            # fractional; get_rope presumably expects an int rotary_dim —
            # confirm whether an int() cast is needed here.
            rotary_dim = self.head_dim * config.partial_rotary_factor
        elif hasattr(config, "attn_rotary_emb"):
            rotary_dim = config.attn_rotary_emb  # for backward compatibility
        else:
            rotary_dim = self.head_dim  # default
        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            rope_scaling=rope_scaling,
            base=rope_theta,
            is_neox_style=True,
            dtype=None,  # see impl of get_rope
        )
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        # muP-style scaling applied to keys (before RoPE).
        self.key_multiplier = config.key_multiplier

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run QKV projection, key scaling, RoPE, attention, and o_proj."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k = k * self.key_multiplier
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        # The residual is returned unchanged; the hybrid wrapper sums the
        # branch outputs with it.
        hidden_states = self.self_attention(
            positions=positions,
            hidden_states=hidden_states,
        )
        return hidden_states, residual
class FalconH1ParallelHybrid(nn.Module):
    """
    A hybrid decoder layer for FalconH1 where the input is processed
    in parallel through both the self-attention branch and the SSM (Mamba)
    branch. Their outputs are then summed to produce the final hidden state.
    This layer uses:
    - FalconH1AttentionDecoderLayer for the multi-head self-attention branch.
    - FalconH1SSMDecoderLayer for the state-space (Mamba) branch.
    """

    def __init__(
        self,
        config: FalconH1Config,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Attention branch.
        self.self_attn = FalconH1AttentionDecoderLayer(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )
        # SSM (Mamba) branch.
        self.mamba = FalconH1SSMDecoderLayer(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        # muP multipliers applied on the way into and out of each branch.
        self.ssm_out_multiplier = config.ssm_out_multiplier
        self.ssm_in_multiplier = config.ssm_in_multiplier
        self.attention_in_multiplier = config.attention_in_multiplier
        self.attn_out_multiplier = config.attention_out_multiplier
        self.feed_forward = FalconH1MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        **kwargs,
    ):
        skip = hidden_states
        normed = self.input_layernorm(hidden_states)
        # Attention branch: consumes the normed input scaled by its
        # in-multiplier; the returned residual is ignored here.
        attn_out, _ = self.self_attn(
            positions=positions,
            hidden_states=normed * self.attention_in_multiplier,
            residual=skip,
            **kwargs,
        )
        # SSM branch over the same normed input, with its own scaling.
        ssm_out, _ = self.mamba(
            hidden_states=normed * self.ssm_in_multiplier,
            residual=skip,
            mamba_cache_params=mamba_cache_params,
            mamba2_metadata=mamba2_metadata,
            **kwargs,
        )
        # Both branches emit (.., config.hidden_size); combine them with
        # their out-multipliers and add the residual.
        combined = (attn_out * self.attn_out_multiplier) + (
            ssm_out * self.ssm_out_multiplier)
        hidden_states = combined + skip
        # Feed-forward sub-block with its own residual connection.
        skip = hidden_states
        hidden_states = self.feed_forward(self.pre_ff_layernorm(hidden_states))
        return skip + hidden_states
class FalconH1Model(nn.Module):
    """FalconH1 backbone: embeddings, hybrid decoder layers, final norm.

    Supports pipeline parallelism: embeddings exist only on the first PP
    rank and the final layernorm only on the last.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: FalconH1Config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        # Extra vocab rows reserved for LoRA adapters, if enabled.
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )
            # muP-style scaling applied to the embedding output in forward().
            self.embedding_multiplier = config.embedding_multiplier
        else:
            # Embeddings live on the first pipeline rank only.
            self.embed_tokens = PPMissingLayer()
            self.embedding_multiplier = 1.0

        def get_layer(prefix: str):
            # Layer index is the last dotted component of the prefix,
            # e.g. "model.layers.3" -> 3.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = FalconH1ParallelHybrid
            return layer_class(
                config,
                layer_idx,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
        if get_pp_group().is_last_rank:
            self.final_layernorm = RMSNorm(config.hidden_size,
                                           eps=config.rms_norm_eps)
        else:
            self.final_layernorm = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up raw token embeddings (multiplier is applied by caller)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # pass a sequence index tensor, that is required for
        # proper continuous batching computation including
        # chunked prefill
        attn_metadata = get_forward_context().attn_metadata
        mamba2_metadata = prepare_mamba2_metadata(
            chunk_size=self.config.mamba_chunk_size,
            input_ids=input_ids,
            attn_metadata=attn_metadata,
        )
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds * self.embedding_multiplier
            else:
                hidden_states = (self.get_input_embeddings(input_ids) *
                                 self.embedding_multiplier)
        else:
            # Non-first ranks receive hidden states from the previous rank.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # Per-layer view of the shared Mamba state cache.
            layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,
                mamba_cache_params=layer_mamba_cache_params,
                mamba2_metadata=mamba2_metadata,
            )
        if not get_pp_group().is_last_rank:
            # NOTE(review): the intermediate-tensors factory above declares
            # a "residual" key but only "hidden_states" is forwarded here —
            # confirm downstream ranks never read "residual".
            return IntermediateTensors({
                "hidden_states": hidden_states,
            })
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states
class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                          IsHybrid, SupportsV0Only):
    """FalconH1 causal LM: FalconH1Model backbone + parallel LM head.

    Manages the Mamba state cache between scheduler steps and folds the
    configured lm_head multiplier into the logits processor scale.
    """

    # Maps each fused vLLM parameter to the per-projection names used in
    # the HF checkpoint.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA-specific module metadata.
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert (not cache_config.enable_prefix_caching
                ), "FalconH1 currently does not support prefix caching"
        self.quant_config = vllm_config.quant_config
        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = FalconH1Model(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.tie_word_embeddings = config.tie_word_embeddings
        self.unpadded_vocab_size = config.vocab_size
        # Created lazily on the first forward pass (needs lm_head dtype).
        self.mamba_cache: Optional[MambaCacheManager] = None
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
            )
            self.lm_head_multiplier = config.lm_head_multiplier
            if self.tie_word_embeddings:
                # Share weights with the input embedding table.
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)
            # Used to track and store by the Mamba cache between steps.
            # The lm_head multiplier is applied via the logits scale.
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                config.vocab_size,
                scale=config.lm_head_multiplier,
            )
        else:
            # LM head lives on the last pipeline rank only.
            self.lm_head = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        if self.mamba_cache is None:
            # Lazily allocate the conv/temporal state cache. On ranks where
            # lm_head is a PPMissingLayer (no `weight`), fall back to
            # bfloat16 for the cache dtype.
            self.mamba_cache = MambaCacheManager(
                self.vllm_config,
                self.lm_head.weight.dtype
                if hasattr(self.lm_head, 'weight') else torch.bfloat16,
                self.config.num_hidden_layers,
                *self._get_mamba_cache_shape(),
            )
        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        hidden_states = self.model(
            input_ids,
            positions,
            mamba_cache_params,
            intermediate_tensors,
            inputs_embeds,
        )
        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        # Delegated to the cache manager for CUDA-graph replay support.
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return per-layer (conv_state_shape, temporal_state_shape)."""
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = self.config.hidden_size
        conv_state_shape, temporal_state_shape = None, None
        # Mirrors the intermediate-size derivation in FalconH1SSMDecoderLayer.
        intermediate_size = (int(self.config.mamba_expand *
                                 hidden_size) if self.config.mamba_d_ssm
                             is None else self.config.mamba_d_ssm)
        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head is sharded along with it
        n_groups = self.config.mamba_n_groups + extra_groups_for_head_shards(
            self.config.mamba_n_groups, world_size)
        # - heads and n_groups are TP-ed
        conv_dim = intermediate_size + 2 * n_groups * self.config.mamba_d_state
        conv_state_shape = (
            divide(conv_dim, world_size),
            self.config.mamba_d_conv - 1,
        )
        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
        temporal_state_shape = (
            divide(self.config.mamba_n_heads, world_size),
            self.config.mamba_d_head,
            self.config.mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load HF checkpoint weights, remapping names to vLLM modules."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "A_log" in name:
                # Checkpoint stores A in log space; the module param is `A`
                # (a transform is applied by the param's weight loader —
                # see the composed loader in MambaMixer2).
                name = name.replace("A_log", "A")
            if "mamba" in name:
                # SSM weights are nested one level deeper in vLLM:
                # FalconH1SSMDecoderLayer.mamba wraps MambaMixer2.
                name = name.replace("mamba", "mamba.mamba")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not part of a fused parameter: load as a plain weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                if self.tie_word_embeddings and "lm_head" in name:
                    # lm_head shares its weight with embed_tokens.
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        if self.tie_word_embeddings:
            # The tied weight was never seen under its lm_head name.
            loaded_params.add("lm_head.weight")
        return loaded_params

View File

@ -79,6 +79,7 @@ _TEXT_GENERATION_MODELS = {
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),