mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 09:42:14 +08:00
[MODEL] FalconH1 (#18406)
Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae> Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae> Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
This commit is contained in:
parent
61acfc45bc
commit
eca18691d2
@ -392,6 +392,11 @@ Specified using `--task generate`.
|
|||||||
* `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc.
|
* `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc.
|
||||||
* ✅︎
|
* ✅︎
|
||||||
* ✅︎
|
* ✅︎
|
||||||
|
- * `FalconH1ForCausalLM`
|
||||||
|
* Falcon-H1
|
||||||
|
* `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc.
|
||||||
|
* ✅︎
|
||||||
|
* ✅︎
|
||||||
- * `GemmaForCausalLM`
|
- * `GemmaForCausalLM`
|
||||||
* Gemma
|
* Gemma
|
||||||
* `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc.
|
* `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc.
|
||||||
|
|||||||
@ -147,6 +147,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
|
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
|
||||||
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
|
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
|
||||||
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
|
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
|
||||||
|
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
|
||||||
|
is_available_online=False,
|
||||||
|
min_transformers_version="4.52.2"),
|
||||||
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
|
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
|
||||||
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
|
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
|
||||||
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
||||||
|
|||||||
@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
@CustomOp.register("mixer2_gated_rms_norm")
|
@CustomOp.register("mixer2_gated_rms_norm")
|
||||||
class Mixer2RMSNormGated(CustomOp):
|
class Mixer2RMSNormGated(CustomOp):
|
||||||
|
|
||||||
def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
|
def __init__(self,
|
||||||
|
full_hidden_size: int,
|
||||||
|
full_n_groups: int,
|
||||||
|
use_rms_norm: bool = True,
|
||||||
|
eps: float = 1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp):
|
|||||||
self.n_groups = full_hidden_size // self.group_size
|
self.n_groups = full_hidden_size // self.group_size
|
||||||
|
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
self.use_rms_norm = use_rms_norm
|
||||||
set_weight_attrs(self.weight,
|
if self.use_rms_norm:
|
||||||
{"weight_loader": sharded_weight_loader(0)})
|
# Register norm weight only if we're actually applying RMSNorm
|
||||||
assert self.full_hidden_size % self.tp_size== 0,\
|
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
||||||
"Tensor parallel world size must divide hidden size."
|
set_weight_attrs(self.weight,
|
||||||
|
{"weight_loader": sharded_weight_loader(0)})
|
||||||
|
else:
|
||||||
|
# Avoid checkpoint mismatch by skipping unused parameter
|
||||||
|
self.register_parameter("weight", None)
|
||||||
|
assert (self.full_hidden_size % self.tp_size == 0
|
||||||
|
), "Tensor parallel world size must divide hidden size."
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp):
|
|||||||
# the input and then redundantly compute the RMSNorm.
|
# the input and then redundantly compute the RMSNorm.
|
||||||
input_dtype = x.dtype
|
input_dtype = x.dtype
|
||||||
x = x * nn.functional.silu(gate.to(torch.float32))
|
x = x * nn.functional.silu(gate.to(torch.float32))
|
||||||
|
if not self.use_rms_norm:
|
||||||
|
return x
|
||||||
|
|
||||||
if self.n_groups == 1:
|
if self.n_groups == 1:
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp):
|
|||||||
global_sums = tensor_model_parallel_all_reduce(local_sums)
|
global_sums = tensor_model_parallel_all_reduce(local_sums)
|
||||||
# Calculate the variance
|
# Calculate the variance
|
||||||
count = self.tp_size * x.shape[-1]
|
count = self.tp_size * x.shape[-1]
|
||||||
variance = (global_sums / count)
|
variance = global_sums / count
|
||||||
|
|
||||||
else:
|
else:
|
||||||
variance = x.pow(2).mean(-1, keepdim=True)
|
variance = x.pow(2).mean(-1, keepdim=True)
|
||||||
@ -106,6 +118,9 @@ class Mixer2RMSNormGated(CustomOp):
|
|||||||
gate: torch.Tensor,
|
gate: torch.Tensor,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
|
||||||
|
if not self.use_rms_norm:
|
||||||
|
return x * nn.functional.silu(gate.to(torch.float32))
|
||||||
|
|
||||||
if self.tp_size > 1 or self.n_groups != 1:
|
if self.tp_size > 1 or self.n_groups != 1:
|
||||||
return self.forward_native(x, gate)
|
return self.forward_native(x, gate)
|
||||||
|
|
||||||
@ -182,13 +197,15 @@ def mamba_v2_sharded_weight_loader(
|
|||||||
# seem to handle slices well.
|
# seem to handle slices well.
|
||||||
# https://github.com/python/mypy/issues/2410
|
# https://github.com/python/mypy/issues/2410
|
||||||
param.data[
|
param.data[
|
||||||
boundary:(boundary + take), # type: ignore[misc]
|
boundary:(boundary + take),
|
||||||
...] = loaded_weight[loaded_start_idx:( # type: ignore[misc]
|
... # type: ignore[misc]
|
||||||
loaded_start_idx + take)] # type: ignore[misc]
|
] = loaded_weight[loaded_start_idx:(loaded_start_idx +
|
||||||
|
take) # type: ignore[misc]
|
||||||
|
] # type: ignore[misc]
|
||||||
|
|
||||||
# move indexing boundaries
|
# move indexing boundaries
|
||||||
boundary += shard_size
|
boundary += shard_size
|
||||||
loaded_boundary += (full_dim - extra)
|
loaded_boundary += full_dim - extra
|
||||||
|
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
@ -206,19 +223,22 @@ class MambaMixer2(CustomOp):
|
|||||||
**selective** state spaces)
|
**selective** state spaces)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
hidden_size: int,
|
self,
|
||||||
ssm_state_size: int,
|
hidden_size: int,
|
||||||
conv_kernel_size: int,
|
ssm_state_size: int,
|
||||||
intermediate_size: int,
|
conv_kernel_size: int,
|
||||||
use_conv_bias: bool,
|
intermediate_size: int,
|
||||||
use_bias: bool,
|
use_conv_bias: bool,
|
||||||
n_groups: int = 1,
|
use_bias: bool,
|
||||||
num_heads: int = 128,
|
n_groups: int = 1,
|
||||||
head_dim: int = 64,
|
num_heads: int = 128,
|
||||||
rms_norm_eps: float = 1e-5,
|
head_dim: int = 64,
|
||||||
activation="silu",
|
rms_norm_eps: float = 1e-5,
|
||||||
quant_config: Optional[QuantizationConfig] = None):
|
activation: str = "silu",
|
||||||
|
use_rms_norm: bool = True,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# For TP, the sharding plan is as follows:
|
# For TP, the sharding plan is as follows:
|
||||||
@ -238,17 +258,16 @@ class MambaMixer2(CustomOp):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
assert num_heads % self.tp_size == 0, \
|
assert (num_heads % self.tp_size == 0
|
||||||
"Tensor parallel world size must divide num heads."
|
), "Tensor parallel world size must divide num heads."
|
||||||
|
|
||||||
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
|
assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
|
||||||
(
|
"If tensor parallel world size does not divide num_heads, "
|
||||||
"If tensor parallel world size does not divide num_heads, "
|
"then num_groups must equal 1.")
|
||||||
"then num_groups must equal 1."
|
|
||||||
)
|
|
||||||
|
|
||||||
assert self.tp_size == 1 or quant_config is None, \
|
assert (
|
||||||
"Tensor parallel currently not supported for quantized models."
|
self.tp_size == 1 or quant_config is None
|
||||||
|
), "Tensor parallel currently not supported for quantized models."
|
||||||
|
|
||||||
self.ssm_state_size = ssm_state_size
|
self.ssm_state_size = ssm_state_size
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
@ -265,8 +284,7 @@ class MambaMixer2(CustomOp):
|
|||||||
self.n_groups = n_groups + extra_groups_for_head_shards(
|
self.n_groups = n_groups + extra_groups_for_head_shards(
|
||||||
n_groups, self.tp_size)
|
n_groups, self.tp_size)
|
||||||
|
|
||||||
self.conv_dim = (intermediate_size +
|
self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
|
||||||
2 * self.n_groups * ssm_state_size)
|
|
||||||
self.conv1d = ColumnParallelLinear(
|
self.conv1d = ColumnParallelLinear(
|
||||||
input_size=conv_kernel_size,
|
input_size=conv_kernel_size,
|
||||||
output_size=self.conv_dim,
|
output_size=self.conv_dim,
|
||||||
@ -279,11 +297,12 @@ class MambaMixer2(CustomOp):
|
|||||||
# doesn't allow to override it
|
# doesn't allow to override it
|
||||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||||
|
|
||||||
self.in_proj = ColumnParallelLinear(input_size=hidden_size,
|
self.in_proj = ColumnParallelLinear(
|
||||||
output_size=intermediate_size +
|
input_size=hidden_size,
|
||||||
self.conv_dim + self.num_heads,
|
output_size=intermediate_size + self.conv_dim + self.num_heads,
|
||||||
bias=use_bias,
|
bias=use_bias,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
# - because in_proj is a concatenation of 3 weights, we
|
# - because in_proj is a concatenation of 3 weights, we
|
||||||
# need to interleave them before sharding
|
# need to interleave them before sharding
|
||||||
@ -305,7 +324,8 @@ class MambaMixer2(CustomOp):
|
|||||||
# - ditto for the otther two weights below
|
# - ditto for the otther two weights below
|
||||||
delattr(self.conv1d.bias, "weight_loader")
|
delattr(self.conv1d.bias, "weight_loader")
|
||||||
set_weight_attrs(
|
set_weight_attrs(
|
||||||
self.conv1d.bias, {
|
self.conv1d.bias,
|
||||||
|
{
|
||||||
"weight_loader":
|
"weight_loader":
|
||||||
mamba_v2_sharded_weight_loader(
|
mamba_v2_sharded_weight_loader(
|
||||||
[
|
[
|
||||||
@ -316,18 +336,25 @@ class MambaMixer2(CustomOp):
|
|||||||
self.tp_size,
|
self.tp_size,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
)
|
)
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
delattr(self.conv1d.weight, "weight_loader")
|
delattr(self.conv1d.weight, "weight_loader")
|
||||||
set_weight_attrs(
|
set_weight_attrs(
|
||||||
self.conv1d.weight, {
|
self.conv1d.weight,
|
||||||
|
{
|
||||||
"weight_loader":
|
"weight_loader":
|
||||||
mamba_v2_sharded_weight_loader([
|
mamba_v2_sharded_weight_loader(
|
||||||
intermediate_settings,
|
[
|
||||||
group_shard_settings,
|
intermediate_settings,
|
||||||
group_shard_settings,
|
group_shard_settings,
|
||||||
], self.tp_size, tp_rank)
|
group_shard_settings,
|
||||||
})
|
],
|
||||||
|
self.tp_size,
|
||||||
|
tp_rank,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
# - quant layers do not have a weight loader
|
# - quant layers do not have a weight loader
|
||||||
@ -345,8 +372,10 @@ class MambaMixer2(CustomOp):
|
|||||||
head_setings, # for dt
|
head_setings, # for dt
|
||||||
],
|
],
|
||||||
self.tp_size,
|
self.tp_size,
|
||||||
tp_rank)
|
tp_rank,
|
||||||
})
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# - these are TPed by heads to reduce the size of the
|
# - these are TPed by heads to reduce the size of the
|
||||||
# temporal shape
|
# temporal shape
|
||||||
@ -357,6 +386,7 @@ class MambaMixer2(CustomOp):
|
|||||||
))
|
))
|
||||||
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||||
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||||
|
self.use_rms_norm = use_rms_norm
|
||||||
|
|
||||||
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
||||||
a_weight_loader = composed_weight_loader(
|
a_weight_loader = composed_weight_loader(
|
||||||
@ -365,18 +395,25 @@ class MambaMixer2(CustomOp):
|
|||||||
set_weight_attrs(self.dt_bias,
|
set_weight_attrs(self.dt_bias,
|
||||||
{"weight_loader": sharded_weight_loader(0)})
|
{"weight_loader": sharded_weight_loader(0)})
|
||||||
|
|
||||||
self.out_proj = RowParallelLinear(intermediate_size,
|
self.out_proj = RowParallelLinear(
|
||||||
hidden_size,
|
intermediate_size,
|
||||||
bias=use_bias,
|
hidden_size,
|
||||||
input_is_parallel=True,
|
bias=use_bias,
|
||||||
quant_config=quant_config)
|
input_is_parallel=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
self.norm = Mixer2RMSNormGated(intermediate_size,
|
self.norm = Mixer2RMSNormGated(intermediate_size,
|
||||||
n_groups,
|
n_groups,
|
||||||
|
self.use_rms_norm,
|
||||||
eps=rms_norm_eps)
|
eps=rms_norm_eps)
|
||||||
|
|
||||||
def forward_native(self, hidden_states: torch.Tensor,
|
def forward_native(
|
||||||
conv_state: torch.Tensor, ssm_state: torch.Tensor):
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
conv_state: torch.Tensor,
|
||||||
|
ssm_state: torch.Tensor,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
@ -384,6 +421,7 @@ class MambaMixer2(CustomOp):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
mamba_cache_params: MambaCacheParams,
|
mamba_cache_params: MambaCacheParams,
|
||||||
mamba2_metadata: Mamba2Metadata,
|
mamba2_metadata: Mamba2Metadata,
|
||||||
|
mup_vector: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
# mamba2_metadata contains metadata necessary for the mamba2 triton
|
# mamba2_metadata contains metadata necessary for the mamba2 triton
|
||||||
# kernels to operate in continuous batching and in chunked prefill
|
# kernels to operate in continuous batching and in chunked prefill
|
||||||
@ -401,6 +439,10 @@ class MambaMixer2(CustomOp):
|
|||||||
|
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
projected_states, _ = self.in_proj(hidden_states)
|
projected_states, _ = self.in_proj(hidden_states)
|
||||||
|
|
||||||
|
if mup_vector is not None:
|
||||||
|
projected_states = projected_states * mup_vector
|
||||||
|
|
||||||
gate, hidden_states_B_C, dt = torch.split(
|
gate, hidden_states_B_C, dt = torch.split(
|
||||||
projected_states,
|
projected_states,
|
||||||
[
|
[
|
||||||
@ -561,6 +603,9 @@ class MambaMixer2(CustomOp):
|
|||||||
hidden_states = torch.vstack(ssd_output_list)
|
hidden_states = torch.vstack(ssd_output_list)
|
||||||
|
|
||||||
# 4. gated MLP
|
# 4. gated MLP
|
||||||
|
# GatedRMSNorm internally applying SiLU to the gate
|
||||||
|
# SiLU is applied internally before normalization, unlike standard
|
||||||
|
# norm usage
|
||||||
hidden_states = self.norm(hidden_states, gate)
|
hidden_states = self.norm(hidden_states, gate)
|
||||||
|
|
||||||
# 5. Final linear projection
|
# 5. Final linear projection
|
||||||
|
|||||||
685
vllm/model_executor/models/falcon_h1.py
Normal file
685
vllm/model_executor/models/falcon_h1.py
Normal file
@ -0,0 +1,685 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Inference-only FalconH1 model."""
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import FalconH1Config
|
||||||
|
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||||
|
Mamba2Metadata, prepare_mamba2_metadata)
|
||||||
|
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
||||||
|
MambaMixer2, extra_groups_for_head_shards)
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||||
|
MambaCacheParams)
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
|
||||||
|
SupportsV0Only)
|
||||||
|
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||||
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
|
maybe_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
class FalconH1MLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: FalconH1Config,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
bias: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
input_size=config.hidden_size,
|
||||||
|
output_sizes=[config.intermediate_size] * 2,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
input_size=config.intermediate_size,
|
||||||
|
output_size=config.hidden_size,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
self.gate_multiplier, self.down_multiplier = config.mlp_multipliers
|
||||||
|
if config.hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, _ = self.gate_up_proj(x)
|
||||||
|
x[:, :self.intermediate_size // self.tp_size] *= self.gate_multiplier
|
||||||
|
x = self.act_fn(x)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
x = x * self.down_multiplier
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FalconH1SSMDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: FalconH1Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
self.d_ssm = (int(config.mamba_expand * config.hidden_size)
|
||||||
|
if config.mamba_d_ssm is None else config.mamba_d_ssm)
|
||||||
|
|
||||||
|
self.mamba = MambaMixer2(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
ssm_state_size=config.mamba_d_state,
|
||||||
|
conv_kernel_size=config.mamba_d_conv,
|
||||||
|
intermediate_size=self.d_ssm,
|
||||||
|
use_conv_bias=config.mamba_conv_bias,
|
||||||
|
use_bias=config.mamba_proj_bias,
|
||||||
|
n_groups=config.mamba_n_groups,
|
||||||
|
num_heads=config.mamba_n_heads,
|
||||||
|
head_dim=config.mamba_d_head,
|
||||||
|
rms_norm_eps=config.rms_norm_eps,
|
||||||
|
activation=config.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
|
use_rms_norm=config.mamba_rms_norm,
|
||||||
|
)
|
||||||
|
# n_groups is overridden later by `MambaMixer2`
|
||||||
|
self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
|
||||||
|
self.zxbcdt_multipliers = config.ssm_multipliers
|
||||||
|
self._init_mup_vector()
|
||||||
|
|
||||||
|
def _init_mup_vector(self):
|
||||||
|
"""
|
||||||
|
Non learnable per-block scaling vector composed of element-wise
|
||||||
|
multipliersapplied to each separate contiguous block of the output
|
||||||
|
of the linear projection (in_proj) before further processing
|
||||||
|
(gating, convolution, SSM):
|
||||||
|
|
||||||
|
- Z block: [0 : d_ssm] → zxbcdt_multipliers[0]
|
||||||
|
- X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1]
|
||||||
|
- B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2]
|
||||||
|
- C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S]
|
||||||
|
→ zxbcdt_multipliers[3]
|
||||||
|
- dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4]
|
||||||
|
|
||||||
|
where:
|
||||||
|
- d_ssm: Dimension of state-space model latent
|
||||||
|
- G: Number of groups (n_groups)
|
||||||
|
- S: SSM state size per group
|
||||||
|
- All indices are divided by tp_size to support tensor parallelism
|
||||||
|
"""
|
||||||
|
vector_shape = (2 * self.d_ssm + 2 * self.groups_time_state_size +
|
||||||
|
self.config.mamba_n_heads) // self.tp_size
|
||||||
|
mup_vector = torch.ones(1, vector_shape)
|
||||||
|
# Z vector 0 -> d_ssm
|
||||||
|
mup_vector[:, :self.d_ssm //
|
||||||
|
self.tp_size] *= self.zxbcdt_multipliers[0]
|
||||||
|
# X vector d_ssm -> 2 * d_ssm
|
||||||
|
mup_vector[:,
|
||||||
|
(self.d_ssm //
|
||||||
|
self.tp_size):(2 * self.d_ssm //
|
||||||
|
self.tp_size)] *= self.zxbcdt_multipliers[1]
|
||||||
|
# B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state)
|
||||||
|
mup_vector[
|
||||||
|
:,
|
||||||
|
(2 * self.d_ssm) //
|
||||||
|
self.tp_size:(2 * self.d_ssm + self.groups_time_state_size) //
|
||||||
|
self.tp_size,
|
||||||
|
] *= self.zxbcdt_multipliers[2]
|
||||||
|
# C vector 2 * d_ssm + (n_group * d_state)
|
||||||
|
# -> 2 * d_ssm + 2 * (n_group * d_state)
|
||||||
|
mup_vector[
|
||||||
|
:,
|
||||||
|
(2 * self.d_ssm + self.groups_time_state_size) //
|
||||||
|
self.tp_size:(2 * self.d_ssm + 2 * self.groups_time_state_size) //
|
||||||
|
self.tp_size,
|
||||||
|
] *= self.zxbcdt_multipliers[3]
|
||||||
|
# dt vector 2 * d_ssm + 2 * (n_group * d_state)
|
||||||
|
# -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads
|
||||||
|
mup_vector[
|
||||||
|
:,
|
||||||
|
(2 * self.d_ssm + 2 * self.groups_time_state_size) //
|
||||||
|
self.tp_size:,
|
||||||
|
] *= self.zxbcdt_multipliers[4]
|
||||||
|
|
||||||
|
self.register_buffer("mup_vector", mup_vector, persistent=False)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
mamba_cache_params: MambaCacheParams,
|
||||||
|
mamba2_metadata: Mamba2Metadata,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
hidden_states = self.mamba(
|
||||||
|
hidden_states,
|
||||||
|
mamba_cache_params,
|
||||||
|
mamba2_metadata=mamba2_metadata,
|
||||||
|
mup_vector=self.mup_vector,
|
||||||
|
)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class FalconH1AttentionDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: FalconH1Config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
rope_theta = getattr(config, "rope_theta", 1e11)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = config.num_attention_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = config.num_key_value_heads
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.head_dim = (config.hidden_size // self.total_num_heads if getattr(
|
||||||
|
config, "head_dim", None) is None else config.head_dim)
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
if hasattr(config, "partial_rotary_factor"):
|
||||||
|
rotary_dim = self.head_dim * config.partial_rotary_factor
|
||||||
|
elif hasattr(config, "attn_rotary_emb"):
|
||||||
|
rotary_dim = config.attn_rotary_emb # for backward compatibility
|
||||||
|
else:
|
||||||
|
rotary_dim = self.head_dim # default
|
||||||
|
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
head_size=self.head_dim,
|
||||||
|
rotary_dim=rotary_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
base=rope_theta,
|
||||||
|
is_neox_style=True,
|
||||||
|
dtype=None, # see impl of get_rope
|
||||||
|
)
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn = Attention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
|
)
|
||||||
|
self.key_multiplier = config.key_multiplier
|
||||||
|
|
||||||
|
def self_attention(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
k = k * self.key_multiplier
|
||||||
|
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
attn_output = self.attn(q, k, v)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
hidden_states = self.self_attention(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class FalconH1ParallelHybrid(nn.Module):
|
||||||
|
"""
|
||||||
|
A hybrid decoder layer for FalconH1 where the input is processed
|
||||||
|
in parallel through both the self-attention branch and the SSM (Mamba)
|
||||||
|
branch. Their outputs are then summed to produce the final hidden state.
|
||||||
|
|
||||||
|
This layer uses:
|
||||||
|
- FalconH1AttentionDecoderLayer for the multi-head self-attention branch.
|
||||||
|
- FalconH1SSMDecoderLayer for the state-space (Mamba) branch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: FalconH1Config,
|
||||||
|
layer_idx: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# Instantiate the attention branch
|
||||||
|
self.self_attn = FalconH1AttentionDecoderLayer(
|
||||||
|
config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
)
|
||||||
|
# Instantiate the SSM branch
|
||||||
|
self.mamba = FalconH1SSMDecoderLayer(
|
||||||
|
config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.ssm_out_multiplier = config.ssm_out_multiplier
|
||||||
|
self.ssm_in_multiplier = config.ssm_in_multiplier
|
||||||
|
|
||||||
|
self.attention_in_multiplier = config.attention_in_multiplier
|
||||||
|
self.attn_out_multiplier = config.attention_out_multiplier
|
||||||
|
|
||||||
|
self.feed_forward = FalconH1MLP(config)
|
||||||
|
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
mamba_cache_params: MambaCacheParams,
|
||||||
|
mamba2_metadata: Mamba2Metadata,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
# Process input through the attention branch.
|
||||||
|
# FalconH1AttentionDecoderLayer expects positions, hidden_states,
|
||||||
|
# kv_cache, attn_metadata, and residual.
|
||||||
|
attn_hidden, _ = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states * self.attention_in_multiplier,
|
||||||
|
residual=residual,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process input through the SSM branch.
|
||||||
|
# FalconH1SSMDecoderLayer expects hidden_states, attn_metadata,
|
||||||
|
# residual, mamba_cache_params, and sequence_idx.
|
||||||
|
ssm_hidden, _ = self.mamba(
|
||||||
|
hidden_states=hidden_states * self.ssm_in_multiplier,
|
||||||
|
residual=residual,
|
||||||
|
mamba_cache_params=mamba_cache_params,
|
||||||
|
mamba2_metadata=mamba2_metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# Sum the outputs from both branches.
|
||||||
|
# We assume both branches produce outputs of the same
|
||||||
|
# dimensionality (config.hidden_size).
|
||||||
|
hidden_states = (attn_hidden * self.attn_out_multiplier) + (
|
||||||
|
ssm_hidden * self.ssm_out_multiplier)
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
# feed-forward
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.pre_ff_layernorm(hidden_states)
|
||||||
|
hidden_states = self.feed_forward(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FalconH1Model(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
config: FalconH1Config = vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
lora_config = vllm_config.lora_config
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||||
|
(lora_config.max_loras or 1)) if lora_config else 0)
|
||||||
|
self.vocab_size = config.vocab_size + lora_vocab
|
||||||
|
self.org_vocab_size = config.vocab_size
|
||||||
|
if get_pp_group().is_first_rank:
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
self.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
)
|
||||||
|
self.embedding_multiplier = config.embedding_multiplier
|
||||||
|
else:
|
||||||
|
self.embed_tokens = PPMissingLayer()
|
||||||
|
self.embedding_multiplier = 1.0
|
||||||
|
|
||||||
|
def get_layer(prefix: str):
|
||||||
|
layer_idx = int(prefix.rsplit(".", 1)[1])
|
||||||
|
layer_class = FalconH1ParallelHybrid
|
||||||
|
return layer_class(
|
||||||
|
config,
|
||||||
|
layer_idx,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
|
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
make_empty_intermediate_tensors_factory(
|
||||||
|
["hidden_states", "residual"], config.hidden_size))
|
||||||
|
if get_pp_group().is_last_rank:
|
||||||
|
self.final_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
else:
|
||||||
|
self.final_layernorm = PPMissingLayer()
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
mamba_cache_params: MambaCacheParams,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# pass a sequence index tensor, that is required for
|
||||||
|
# proper continuous batching computation including
|
||||||
|
# chunked prefill
|
||||||
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
|
mamba2_metadata = prepare_mamba2_metadata(
|
||||||
|
chunk_size=self.config.mamba_chunk_size,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
)
|
||||||
|
if get_pp_group().is_first_rank:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
hidden_states = inputs_embeds * self.embedding_multiplier
|
||||||
|
else:
|
||||||
|
hidden_states = (self.get_input_embeddings(input_ids) *
|
||||||
|
self.embedding_multiplier)
|
||||||
|
else:
|
||||||
|
assert intermediate_tensors is not None
|
||||||
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
|
|
||||||
|
for i in range(self.start_layer, self.end_layer):
|
||||||
|
layer = self.layers[i]
|
||||||
|
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
|
||||||
|
hidden_states = layer(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
mamba_cache_params=layer_mamba_cache_params,
|
||||||
|
mamba2_metadata=mamba2_metadata,
|
||||||
|
)
|
||||||
|
if not get_pp_group().is_last_rank:
|
||||||
|
return IntermediateTensors({
|
||||||
|
"hidden_states": hidden_states,
|
||||||
|
})
|
||||||
|
hidden_states = self.final_layernorm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||||
|
IsHybrid, SupportsV0Only):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding_modules = {
|
||||||
|
"embed_tokens": "input_embeddings",
|
||||||
|
"lm_head": "output_embeddings",
|
||||||
|
}
|
||||||
|
embedding_padding_modules = ["lm_head"]
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.model_config = vllm_config.model_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
lora_config = vllm_config.lora_config
|
||||||
|
scheduler_config = vllm_config.scheduler_config
|
||||||
|
assert (not cache_config.enable_prefix_caching
|
||||||
|
), "FalconH1 currently does not support prefix caching"
|
||||||
|
|
||||||
|
self.quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.model = FalconH1Model(vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
self.tie_word_embeddings = config.tie_word_embeddings
|
||||||
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||||
|
if lora_config:
|
||||||
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
|
if get_pp_group().is_last_rank:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
self.unpadded_vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
padding_size=(
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE
|
||||||
|
# We need bigger padding if using lora for kernel
|
||||||
|
# compatibility
|
||||||
|
if not lora_config else
|
||||||
|
lora_config.lora_vocab_padding_size),
|
||||||
|
)
|
||||||
|
self.lm_head_multiplier = config.lm_head_multiplier
|
||||||
|
if self.tie_word_embeddings:
|
||||||
|
self.lm_head = self.lm_head.tie_weights(
|
||||||
|
self.model.embed_tokens)
|
||||||
|
# Used to track and store by the Mamba cache between steps.
|
||||||
|
|
||||||
|
self.logits_processor = LogitsProcessor(
|
||||||
|
self.unpadded_vocab_size,
|
||||||
|
config.vocab_size,
|
||||||
|
scale=config.lm_head_multiplier,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lm_head = PPMissingLayer()
|
||||||
|
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if self.mamba_cache is None:
|
||||||
|
self.mamba_cache = MambaCacheManager(
|
||||||
|
self.vllm_config,
|
||||||
|
self.lm_head.weight.dtype
|
||||||
|
if hasattr(self.lm_head, 'weight') else torch.bfloat16,
|
||||||
|
self.config.num_hidden_layers,
|
||||||
|
*self._get_mamba_cache_shape(),
|
||||||
|
)
|
||||||
|
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids,
|
||||||
|
positions,
|
||||||
|
mamba_cache_params,
|
||||||
|
intermediate_tensors,
|
||||||
|
inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||||
|
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||||
|
input_buffers, **kwargs)
|
||||||
|
|
||||||
|
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||||
|
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||||
|
|
||||||
|
def _get_mamba_cache_shape(
|
||||||
|
self) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||||
|
world_size = get_tensor_model_parallel_world_size()
|
||||||
|
hidden_size = self.config.hidden_size
|
||||||
|
|
||||||
|
conv_state_shape, temporal_state_shape = None, None
|
||||||
|
|
||||||
|
intermediate_size = (int(self.config.mamba_expand *
|
||||||
|
hidden_size) if self.config.mamba_d_ssm
|
||||||
|
is None else self.config.mamba_d_ssm)
|
||||||
|
|
||||||
|
# if n_groups is not divisible by world_size, need to extend the shards
|
||||||
|
# to ensure all groups needed by a head is sharded along with it
|
||||||
|
n_groups = self.config.mamba_n_groups + extra_groups_for_head_shards(
|
||||||
|
self.config.mamba_n_groups, world_size)
|
||||||
|
|
||||||
|
# - heads and n_groups are TP-ed
|
||||||
|
conv_dim = intermediate_size + 2 * n_groups * self.config.mamba_d_state
|
||||||
|
conv_state_shape = (
|
||||||
|
divide(conv_dim, world_size),
|
||||||
|
self.config.mamba_d_conv - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# These are not TP-ed as they depend on A, dt_bias, D
|
||||||
|
# - they are typically small
|
||||||
|
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
|
||||||
|
temporal_state_shape = (
|
||||||
|
divide(self.config.mamba_n_heads, world_size),
|
||||||
|
self.config.mamba_d_head,
|
||||||
|
self.config.mamba_d_state,
|
||||||
|
)
|
||||||
|
return conv_state_shape, temporal_state_shape
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "A_log" in name:
|
||||||
|
name = name.replace("A_log", "A")
|
||||||
|
|
||||||
|
if "mamba" in name:
|
||||||
|
name = name.replace("mamba", "mamba.mamba")
|
||||||
|
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Skip layers on other devices.
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
if self.tie_word_embeddings and "lm_head" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
|
||||||
|
if self.tie_word_embeddings:
|
||||||
|
loaded_params.add("lm_head.weight")
|
||||||
|
return loaded_params
|
||||||
@ -79,6 +79,7 @@ _TEXT_GENERATION_MODELS = {
|
|||||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||||
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||||
|
"FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
|
||||||
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
|
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
|
||||||
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
|
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
|
||||||
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
|
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user