mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 01:55:21 +08:00
[Model] FalconMamba Support (#9325)
This commit is contained in:
parent
496e991da8
commit
f6b97293aa
@ -87,6 +87,11 @@ Text Generation
|
|||||||
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
|
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
|
||||||
-
|
-
|
||||||
- ✅︎
|
- ✅︎
|
||||||
|
* - :code:`FalconMambaForCausalLM`
|
||||||
|
- FalconMamba
|
||||||
|
- :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
|
||||||
|
- ✅︎
|
||||||
|
-
|
||||||
* - :code:`GemmaForCausalLM`
|
* - :code:`GemmaForCausalLM`
|
||||||
- Gemma
|
- Gemma
|
||||||
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
|
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from vllm.worker.model_runner import _get_graph_batch_size
|
|||||||
|
|
||||||
from ...utils import check_outputs_equal
|
from ...utils import check_outputs_equal
|
||||||
|
|
||||||
MODELS = ["state-spaces/mamba-130m-hf"]
|
MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
|
||||||
|
|
||||||
|
|
||||||
# Use lower-level interfaces to create this greedy generator, as mamba will
|
# Use lower-level interfaces to create this greedy generator, as mamba will
|
||||||
|
|||||||
@ -27,7 +27,6 @@ class RMSNorm(CustomOp):
|
|||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
self.variance_size_override = (None if var_hidden_size == hidden_size
|
self.variance_size_override = (None if var_hidden_size == hidden_size
|
||||||
else var_hidden_size)
|
else var_hidden_size)
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
||||||
from vllm.model_executor.models.interfaces import (HasInnerState,
|
from vllm.model_executor.models.interfaces import (HasInnerState,
|
||||||
@ -59,7 +59,7 @@ class MambaMixer(nn.Module):
|
|||||||
self.conv_kernel_size = config.conv_kernel
|
self.conv_kernel_size = config.conv_kernel
|
||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
self.time_step_rank = int(config.time_step_rank)
|
self.time_step_rank = int(config.time_step_rank)
|
||||||
|
self.is_falcon_mamba = config.model_type == "falcon_mamba"
|
||||||
self.conv1d = ColumnParallelLinear(
|
self.conv1d = ColumnParallelLinear(
|
||||||
input_size=self.conv_kernel_size,
|
input_size=self.conv_kernel_size,
|
||||||
output_size=self.intermediate_size,
|
output_size=self.intermediate_size,
|
||||||
@ -109,6 +109,13 @@ class MambaMixer(nn.Module):
|
|||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
)
|
)
|
||||||
self.activation = config.hidden_act
|
self.activation = config.hidden_act
|
||||||
|
if self.is_falcon_mamba:
|
||||||
|
self.dt_layernorm = RMSNorm(self.time_step_rank,
|
||||||
|
eps=config.mixer_rms_eps)
|
||||||
|
self.b_layernorm = RMSNorm(self.ssm_state_size,
|
||||||
|
eps=config.mixer_rms_eps)
|
||||||
|
self.c_layernorm = RMSNorm(self.ssm_state_size,
|
||||||
|
eps=config.mixer_rms_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor,
|
def forward(self, hidden_states: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
@ -158,8 +165,12 @@ class MambaMixer(nn.Module):
|
|||||||
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
|
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
|
# Note that Jamba and FalconMamba normalizes B, C, and time_step here
|
||||||
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
|
# but Mamba doesn't.
|
||||||
|
if self.is_falcon_mamba:
|
||||||
|
time_step = self.dt_layernorm(time_step.contiguous())
|
||||||
|
B = self.b_layernorm(B.contiguous())
|
||||||
|
C = self.c_layernorm(C.contiguous())
|
||||||
|
|
||||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||||
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
||||||
@ -213,11 +224,9 @@ class MambaDecoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_falcon_mamba = config.model_type == "falcon_mamba"
|
||||||
self.mixer = MambaMixer(config, layer_idx)
|
self.mixer = MambaMixer(config, layer_idx)
|
||||||
|
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
|
|
||||||
eps=config.layer_norm_epsilon)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -319,8 +328,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
|||||||
self.unpadded_vocab_size = config.vocab_size
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
if lora_config:
|
if lora_config:
|
||||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
|
if config.tie_word_embeddings:
|
||||||
self.lm_head = self.backbone.embeddings
|
self.lm_head = self.backbone.embeddings
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
self.unpadded_vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||||
|
# We need bigger padding if using lora for kernel
|
||||||
|
# compatibility
|
||||||
|
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Used to track and store by the Mamba cache between steps.
|
# Used to track and store by the Mamba cache between steps.
|
||||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||||
@ -398,7 +417,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
|||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "A_log" in name:
|
if "A_log" in name:
|
||||||
name = name.replace("A_log", "A")
|
name = name.replace("A_log", "A")
|
||||||
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
|
|||||||
# For decapoda-research/llama-*
|
# For decapoda-research/llama-*
|
||||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||||
|
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||||
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
||||||
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
|
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user