From 320feae6f506097c47b6b41a634a6197512cffc1 Mon Sep 17 00:00:00 2001 From: Paul Pak <52512091+paulpak58@users.noreply.github.com> Date: Wed, 8 Oct 2025 01:03:05 +0900 Subject: [PATCH] [Model] Lfm2Moe (#26344) Signed-off-by: Paul Pak --- docs/models/supported_models.md | 1 + tests/models/registry.py | 3 + vllm/model_executor/models/lfm2.py | 9 +- vllm/model_executor/models/lfm2_moe.py | 797 ++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/lfm2_moe.py | 160 ++++ 8 files changed, 967 insertions(+), 7 deletions(-) create mode 100644 vllm/model_executor/models/lfm2_moe.py create mode 100644 vllm/transformers_utils/configs/lfm2_moe.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 10ccd73d8f30..5ac8f2121f97 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -390,6 +390,7 @@ th { | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Lfm2MoeForCausalLM` | LFM2MoE | `LiquidAI/LFM2-8B-A1B-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. 
| | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 297e8854c5bd..e7affb41565c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -321,6 +321,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Lfm2ForCausalLM": _HfExamplesInfo( "LiquidAI/LFM2-1.2B", min_transformers_version="4.54" ), + "Lfm2MoeForCausalLM": _HfExamplesInfo( + "LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58" + ), "LlamaForCausalLM": _HfExamplesInfo( "meta-llama/Llama-3.2-1B-Instruct", extras={ diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index ae5c97426ee7..425c93687760 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -71,14 +71,14 @@ class Lfm2MLP(nn.Module): output_sizes=[ff_dim] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", + prefix=f"{prefix}.w1", ) self.w2 = RowParallelLinear( input_size=ff_dim, output_size=dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.down_proj", + prefix=f"{prefix}.w2", ) self.act_fn = SiluAndMul() @@ -484,17 +484,12 @@ class Lfm2ForCausalLM( quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config - scheduler_config = vllm_config.scheduler_config assert not cache_config.enable_prefix_caching, ( "Lfm2 currently does not support prefix caching" ) super().__init__() self.config = config - self.vllm_config = vllm_config - self.scheduler_config = scheduler_config - self.model_config = vllm_config.model_config - self.model = Lfm2Model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py new file mode 100644 index 000000000000..728bd90be117 --- /dev/null +++ b/vllm/model_executor/models/lfm2_moe.py @@ -0,0 +1,797 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Any, Optional + +import torch +import torch.nn as nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Lfm2MoeConfig + +from .interfaces import ( + HasInnerState, + IsHybrid, + MixtureOfExperts, + SupportsLoRA, + SupportsPP, 
+ SupportsQuant, +) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + + +class Lfm2MoeMlp(nn.Module): + def __init__( + self, + dim: int, + ff_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.w1 = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[ff_dim] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.w1", + ) + self.w2 = RowParallelLinear( + input_size=ff_dim, + output_size=dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.w2", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.w1(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class Lfm2MoeSparseMoeBlock(nn.Module): + def __init__( + self, + config: Lfm2MoeConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}." + ) + + # Load balancing settings. + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) + if config.use_expert_bias: + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(self.n_routed_experts, dtype=torch.float32) + ) + else: + self.gate.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, # needed for softmax score func + num_expert_group=1, + topk_group=1, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = ( + self.experts(hidden_states=hidden_states, router_logits=router_logits) + * self.routed_scaling_factor + ) + + if self.tp_size > 1: + final_hidden_states = 
self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states + ) + + return final_hidden_states.view(orig_shape) + + +class Lfm2MoeAttention(nn.Module): + def __init__( + self, + config: Lfm2MoeConfig, + layer_idx: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + n_tokens, _ = hidden_states.shape + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous() + k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous() + q = self.q_layernorm(q) + k = self.k_layernorm(k) + q, k = self.rotary_emb(positions, q, k) + q = q.view(n_tokens, self.num_heads * self.head_dim) + k = k.view(n_tokens, self.num_kv_heads * self.head_dim) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class Lfm2MoeAttentionDecoderLayer(nn.Module): + def __init__( + self, + config: Lfm2MoeConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = 
"", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.prefix = prefix + self.config = config + self.layer_idx = layer_idx + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + + self.self_attn = Lfm2MoeAttention( + config=config, + layer_idx=layer_idx, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + if layer_idx < config.num_dense_layers: + self.feed_forward = Lfm2MoeMlp( + dim=config.hidden_size, + ff_dim=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + else: + self.feed_forward = Lfm2MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + enable_eplb=enable_eplb, + ) + + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + return self.feed_forward(hidden_states), residual + + +class Lfm2MoeShortConvDecoderLayer(nn.Module): + def __init__( + self, + config: Lfm2MoeConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.conv = ShortConv( + config=config, + dim=config.hidden_size, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.conv", + ) + + if layer_idx < config.num_dense_layers: + self.feed_forward = Lfm2MoeMlp( + dim=config.hidden_size, + ff_dim=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + else: + self.feed_forward = Lfm2MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + enable_eplb=enable_eplb, + ) + + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm(hidden_states, residual) + output = torch.empty_like(hidden_states) + self.conv( + hidden_states, + output, + ) + hidden_states, residual = self.ffn_norm(output, residual) + hidden_states = self.feed_forward(hidden_states) + 
return hidden_states, residual + + +@support_torch_compile +class Lfm2MoeModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config + enable_eplb = parallel_config.enable_eplb + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size + ) + + def get_layer(prefix: str): + layer_idx = extract_layer_index(prefix) + is_attn = self.config.layer_types[layer_idx] == "full_attention" + layer_class = ( + Lfm2MoeAttentionDecoderLayer + if is_attn + else Lfm2MoeShortConvDecoderLayer + ) + return layer_class( + config, + layer_idx, + model_config, + cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=enable_eplb, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + if get_pp_group().is_last_rank: + self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + else: + self.embedding_norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer : self.end_layer]: + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.embedding_norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".w1", ".w1", 0), + (".w1", ".w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in 
weights: + if "expert_bias" in name: + name = name.replace("expert_bias", "gate.e_score_correction_bias") + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + + if ("feed_forward.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + param = params_dict[name] + + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Lfm2MoeForCausalLM( + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, + IsHybrid, + SupportsQuant, + MixtureOfExperts, +): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "w1": [ + "w1", + "w3", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, ...]: + return MambaStateDtypeCalculator.short_conv_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[tuple[int, int]]: + """Calculate shapes for LFM2's convolutional cache. 
+ + Args: + vllm_config: vLLM config + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.hidden_size, + conv_kernel=hf_config.conv_L_cache, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + assert not cache_config.enable_prefix_caching, ( + "Lfm2Moe currently does not support prefix caching" + ) + + super().__init__() + self.config = config + self.model = Lfm2MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config + else lora_config.lora_vocab_padding_size + ), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance( + layer, (Lfm2MoeAttentionDecoderLayer, Lfm2MoeShortConvDecoderLayer) + ) + if isinstance(layer.feed_forward, Lfm2MoeSparseMoeBlock): + example_layer = layer.feed_forward + self.moe_layers.append(layer.feed_forward.experts) + + if example_layer is None: + raise RuntimeError( + "No Lfm2MoeSparseMoeBlock layer found in the model.layers." + ) + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.feed_forward, Lfm2MoeSparseMoeBlock): + moe = layer.feed_forward + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c680d29923f8..191c5c0594dc 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -119,6 +119,7 @@ _TEXT_GENERATION_MODELS = { "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), + "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # For decapoda-research/llama-* diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index ab3eb6de4780..b9592f08b9f2 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -91,6 +91,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( step3_vl="Step3VLConfig", step3_text="Step3TextConfig", qwen3_next="Qwen3NextConfig", + lfm2_moe="Lfm2MoeConfig", ) _CONFIG_ATTRS_MAPPING: dict[str, str] = { diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 72c90e073131..6917123ce662 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -19,6 +19,7 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig +from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.midashenglm 
import MiDashengLMConfig
 from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
@@ -46,6 +47,7 @@ __all__ = [
     "EAGLEConfig",
     "RWConfig",
     "JAISConfig",
+    "Lfm2MoeConfig",
     "MedusaConfig",
     "MiDashengLMConfig",
     "MLPSpeculatorConfig",
diff --git a/vllm/transformers_utils/configs/lfm2_moe.py b/vllm/transformers_utils/configs/lfm2_moe.py
new file mode 100644
index 000000000000..7d17c2b4f74c
--- /dev/null
+++ b/vllm/transformers_utils/configs/lfm2_moe.py
@@ -0,0 +1,160 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class Lfm2MoeConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Lfm2MoeModel`]. It is used to instantiate an
+    LFM2 MoE model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a configuration similar to that of the LFM2-8B-A1B model,
+    e.g. [LiquidAI/LFM2-8B-A1B](https://huggingface.co/LiquidAI/LFM2-8B-A1B).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 65536):
+            Vocabulary size of the Lfm2Moe model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`Lfm2MoeModel`].
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 7168):
+            Dimension of the MLP representations.
+        moe_intermediate_size (`int`, *optional*, defaults to 1792):
+            Intermediate size of each routed expert.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie weight embeddings.
+        rope_theta (`float`, *optional*, defaults to 1000000.0):
+            The base period of the RoPE embeddings.
+        max_position_embeddings (`int`, *optional*, defaults to 128000):
+            The maximum sequence length that this model might ever be used with.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
+            constructed by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If not specified, it will default to
+            `num_attention_heads`.
+        conv_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use bias in the conv layers.
+        conv_L_cache (`int`, *optional*, defaults to 3):
+            L_cache dim in the conv layers.
+        num_dense_layers (`int`, *optional*, defaults to 2):
+            Number of dense Lfm2MoeMLP layers in the shallow layers
+            (embed->dense->dense->...->dense->moe->moe->...->lm_head).
+        num_experts_per_tok (`int`, *optional*, defaults to 4):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 32):
+            Number of routed experts.
+        use_expert_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use the expert bias on the routing weights.
+        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for routed experts in MoE models.
+        norm_topk_prob (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the topk probabilities.
+        layer_types (`Optional`, *optional*):
+            Type of each layer.
+
+    ```python
+    >>> from transformers import Lfm2MoeModel, Lfm2MoeConfig
+
+    >>> # Initializing an LFM2 MoE model
+    >>> configuration = Lfm2MoeConfig()
+
+    >>> # Initializing a model from the LFM2-8B-A1B style configuration
+    >>> model = Lfm2MoeModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""  # noqa: E501
+
+    model_type = "lfm2_moe"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size: int = 65536,
+        hidden_size: int = 2048,
+        intermediate_size: int = 7168,
+        moe_intermediate_size: int = 1792,
+        num_hidden_layers: int = 32,
+        pad_token_id: int = 0,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = True,
+        rope_theta: float = 1000000.0,
+        max_position_embeddings: int = 128_000,
+        use_cache: bool = True,
+        norm_eps: float = 0.00001,
+        num_attention_heads: int = 32,
+        num_key_value_heads: int = 8,
+        conv_bias: bool = False,
+        conv_L_cache: int = 3,
+        num_dense_layers: int = 2,
+        num_experts_per_tok: int = 4,
+        num_experts: int = 32,
+        use_expert_bias: bool = True,
+        routed_scaling_factor: float = 1.0,
+        norm_topk_prob: bool = True,
+        layer_types: Optional[list[str]] = None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.use_cache = use_cache
+        self.norm_eps = norm_eps
+
+        # attn operator config
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+
+        # custom operator config
+        self.conv_bias = conv_bias
+        self.conv_L_cache = conv_L_cache
+
+        # moe config
+        self.num_dense_layers = num_dense_layers
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_experts = num_experts
+        self.use_expert_bias = use_expert_bias
+        self.routed_scaling_factor = routed_scaling_factor
+        self.norm_topk_prob = norm_topk_prob
+        self.layer_types = layer_types
+
+        tie_word_embeddings = kwargs.get(
+            "tie_embedding", tie_word_embeddings
+        )  # to fit original config keys
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["Lfm2MoeConfig"]
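Two illustrative notes on the code added by this patch follow.

The sparse block `Lfm2MoeSparseMoeBlock` configures `FusedMoE` with `scoring_func="sigmoid"`, `use_grouped_topk=True` (with a single group) and an optional `e_score_correction_bias`. As a reading aid, here is a simplified sketch of how this style of bias-corrected sigmoid routing is commonly implemented (DeepSeek-V3-style); it is not the exact code path inside vLLM's fused kernels, and the helper name is hypothetical.

```python
import torch


def sigmoid_biased_topk(
    router_logits: torch.Tensor,    # (num_tokens, num_experts)
    correction_bias: torch.Tensor,  # (num_experts,), cf. gate.e_score_correction_bias
    top_k: int = 4,
    norm_topk_prob: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Sigmoid scoring instead of softmax, matching scoring_func="sigmoid".
    scores = router_logits.sigmoid()
    # The bias only influences *which* experts win the top-k selection...
    topk_ids = (scores + correction_bias).topk(top_k, dim=-1).indices
    # ...while the combination weights come from the unbiased scores.
    topk_weights = scores.gather(-1, topk_ids)
    if norm_topk_prob:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_ids, topk_weights
```

In the block's `forward`, the combined expert output is additionally multiplied by `config.routed_scaling_factor`, which is equivalent to scaling these per-token weights.

Finally, a minimal offline-inference sketch for exercising the newly registered `Lfm2MoeForCausalLM` architecture. The checkpoint name is the one referenced in `tests/models/registry.py`; prefix caching is disabled explicitly because the model asserts it is unsupported, and the prompt and sampling settings are arbitrary.

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="LiquidAI/LFM2-8B-A1B",
    # Lfm2MoeForCausalLM currently asserts that prefix caching is not supported.
    enable_prefix_caching=False,
)

outputs = llm.generate(
    ["Liquid foundation models are"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```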