mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 19:12:01 +08:00
parent
1e4ecca1d0
commit
320feae6f5
@ -390,6 +390,7 @@ th {
|
||||
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Lfm2MoeForCausalLM` | LFM2MoE | `LiquidAI/LFM2-8B-A1B-preview`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
|
||||
|
||||
@ -321,6 +321,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"Lfm2ForCausalLM": _HfExamplesInfo(
|
||||
"LiquidAI/LFM2-1.2B", min_transformers_version="4.54"
|
||||
),
|
||||
"Lfm2MoeForCausalLM": _HfExamplesInfo(
|
||||
"LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58"
|
||||
),
|
||||
"LlamaForCausalLM": _HfExamplesInfo(
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
extras={
|
||||
|
||||
@ -71,14 +71,14 @@ class Lfm2MLP(nn.Module):
|
||||
output_sizes=[ff_dim] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
prefix=f"{prefix}.w1",
|
||||
)
|
||||
self.w2 = RowParallelLinear(
|
||||
input_size=ff_dim,
|
||||
output_size=dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
prefix=f"{prefix}.w2",
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
@ -484,17 +484,12 @@ class Lfm2ForCausalLM(
|
||||
quant_config = vllm_config.quant_config
|
||||
cache_config = vllm_config.cache_config
|
||||
lora_config = vllm_config.lora_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
assert not cache_config.enable_prefix_caching, (
|
||||
"Lfm2 currently does not support prefix caching"
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model_config = vllm_config.model_config
|
||||
|
||||
self.model = Lfm2Model(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
797
vllm/model_executor/models/lfm2_moe.py
Normal file
797
vllm/model_executor/models/lfm2_moe.py
Normal file
@ -0,0 +1,797 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.short_conv import ShortConv
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Lfm2MoeConfig
|
||||
|
||||
from .interfaces import (
|
||||
HasInnerState,
|
||||
IsHybrid,
|
||||
MixtureOfExperts,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
class Lfm2MoeMlp(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ff_dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.w1 = MergedColumnParallelLinear(
|
||||
input_size=dim,
|
||||
output_sizes=[ff_dim] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.w1",
|
||||
)
|
||||
self.w2 = RowParallelLinear(
|
||||
input_size=ff_dim,
|
||||
output_size=dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.w2",
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gate_up, _ = self.w1(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.w2(x)
|
||||
return x
|
||||
|
||||
|
||||
class Lfm2MoeSparseMoeBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Lfm2MoeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
self.ep_rank = self.ep_group.rank()
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts = config.num_experts
|
||||
|
||||
if self.tp_size > self.n_routed_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {self.n_routed_experts}."
|
||||
)
|
||||
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
eplb_config = vllm_config.parallel_config.eplb_config
|
||||
self.enable_eplb = enable_eplb
|
||||
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
|
||||
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
|
||||
self.physical_expert_end = (
|
||||
self.physical_expert_start + self.n_local_physical_experts
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_experts,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
if config.use_expert_bias:
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(self.n_routed_experts, dtype=torch.float32)
|
||||
)
|
||||
else:
|
||||
self.gate.e_score_correction_bias = None
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=self.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True, # needed for softmax score func
|
||||
num_expert_group=1,
|
||||
topk_group=1,
|
||||
prefix=f"{prefix}.experts",
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
scoring_func="sigmoid",
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_dim = hidden_states.shape[-1]
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = (
|
||||
self.experts(hidden_states=hidden_states, router_logits=router_logits)
|
||||
* self.routed_scaling_factor
|
||||
)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
|
||||
final_hidden_states
|
||||
)
|
||||
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
class Lfm2MoeAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Lfm2MoeConfig,
|
||||
layer_idx: int,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = hidden_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.hidden_size,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=True,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)
|
||||
self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
n_tokens, _ = hidden_states.shape
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous()
|
||||
k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous()
|
||||
q = self.q_layernorm(q)
|
||||
k = self.k_layernorm(k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
q = q.view(n_tokens, self.num_heads * self.head_dim)
|
||||
k = k.view(n_tokens, self.num_kv_heads * self.head_dim)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Lfm2MoeAttentionDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Lfm2MoeConfig,
|
||||
layer_idx: int,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None
|
||||
):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings
|
||||
)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
|
||||
self.self_attn = Lfm2MoeAttention(
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
hidden_size=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
if layer_idx < config.num_dense_layers:
|
||||
self.feed_forward = Lfm2MoeMlp(
|
||||
dim=config.hidden_size,
|
||||
ff_dim=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
else:
|
||||
self.feed_forward = Lfm2MoeSparseMoeBlock(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
|
||||
self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.operator_norm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.operator_norm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
|
||||
hidden_states, residual = self.ffn_norm(hidden_states, residual)
|
||||
return self.feed_forward(hidden_states), residual
|
||||
|
||||
|
||||
class Lfm2MoeShortConvDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Lfm2MoeConfig,
|
||||
layer_idx: int,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.conv = ShortConv(
|
||||
config=config,
|
||||
dim=config.hidden_size,
|
||||
layer_idx=layer_idx,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.conv",
|
||||
)
|
||||
|
||||
if layer_idx < config.num_dense_layers:
|
||||
self.feed_forward = Lfm2MoeMlp(
|
||||
dim=config.hidden_size,
|
||||
ff_dim=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
else:
|
||||
self.feed_forward = Lfm2MoeSparseMoeBlock(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
|
||||
self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.operator_norm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.operator_norm(hidden_states, residual)
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.conv(
|
||||
hidden_states,
|
||||
output,
|
||||
)
|
||||
hidden_states, residual = self.ffn_norm(output, residual)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Lfm2MoeModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
enable_eplb = parallel_config.enable_eplb
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
|
||||
self.config = config
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size
|
||||
)
|
||||
|
||||
def get_layer(prefix: str):
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
is_attn = self.config.layer_types[layer_idx] == "full_attention"
|
||||
layer_class = (
|
||||
Lfm2MoeAttentionDecoderLayer
|
||||
if is_attn
|
||||
else Lfm2MoeShortConvDecoderLayer
|
||||
)
|
||||
return layer_class(
|
||||
config,
|
||||
layer_idx,
|
||||
model_config,
|
||||
cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
|
||||
)
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
else:
|
||||
self.embedding_norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in self.layers[self.start_layer : self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
hidden_states, _ = self.embedding_norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="w1",
|
||||
ckpt_down_proj_name="w2",
|
||||
ckpt_up_proj_name="w3",
|
||||
num_experts=self.config.num_experts,
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".w1", ".w1", 0),
|
||||
(".w1", ".w3", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if "expert_bias" in name:
|
||||
name = name.replace("expert_bias", "gate.e_score_correction_bias")
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if ("feed_forward.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Lfm2MoeForCausalLM(
|
||||
nn.Module,
|
||||
HasInnerState,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
IsHybrid,
|
||||
SupportsQuant,
|
||||
MixtureOfExperts,
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"w1": [
|
||||
"w1",
|
||||
"w3",
|
||||
],
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
return MambaStateDtypeCalculator.short_conv_state_dtype(
|
||||
vllm_config.model_config.dtype,
|
||||
vllm_config.cache_config.mamba_cache_dtype,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[tuple[int, int]]:
|
||||
"""Calculate shapes for LFM2's convolutional cache.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- conv_state_shape: Shape for convolutional state cache
|
||||
"""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
|
||||
return MambaStateShapeCalculator.short_conv_state_shape(
|
||||
tp_world_size=parallel_config.tensor_parallel_size,
|
||||
intermediate_size=hf_config.hidden_size,
|
||||
conv_kernel=hf_config.conv_L_cache,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
cache_config = vllm_config.cache_config
|
||||
lora_config = vllm_config.lora_config
|
||||
assert not cache_config.enable_prefix_caching, (
|
||||
"Lfm2Moe currently does not support prefix caching"
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Lfm2MoeModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = self.config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=(
|
||||
DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size
|
||||
),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.expert_weights = []
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
example_layer = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(
|
||||
layer, (Lfm2MoeAttentionDecoderLayer, Lfm2MoeShortConvDecoderLayer)
|
||||
)
|
||||
if isinstance(layer.feed_forward, Lfm2MoeSparseMoeBlock):
|
||||
example_layer = layer.feed_forward
|
||||
self.moe_layers.append(layer.feed_forward.experts)
|
||||
|
||||
if example_layer is None:
|
||||
raise RuntimeError(
|
||||
"No Lfm2MoeSparseMoeBlock layer found in the model.layers."
|
||||
)
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
self.num_logical_experts = example_layer.n_logical_experts
|
||||
self.num_physical_experts = example_layer.n_physical_experts
|
||||
self.num_local_physical_experts = example_layer.n_local_physical_experts
|
||||
self.num_routed_experts = example_layer.n_routed_experts
|
||||
self.num_redundant_experts = example_layer.n_redundant_experts
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def set_eplb_state(
|
||||
self,
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
) -> None:
|
||||
for layer_idx, layer in enumerate(self.moe_layers):
|
||||
# Register the expert weights.
|
||||
self.expert_weights.append(layer.get_expert_weights())
|
||||
layer.set_eplb_state(
|
||||
moe_layer_idx=layer_idx,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def update_physical_experts_metadata(
|
||||
self,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
) -> None:
|
||||
assert self.num_local_physical_experts == num_local_physical_experts
|
||||
self.num_physical_experts = num_physical_experts
|
||||
self.num_local_physical_experts = num_local_physical_experts
|
||||
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer.feed_forward, Lfm2MoeSparseMoeBlock):
|
||||
moe = layer.feed_forward
|
||||
moe.n_local_physical_experts = num_local_physical_experts
|
||||
moe.n_physical_experts = num_physical_experts
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
@ -119,6 +119,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
||||
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
|
||||
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
|
||||
"Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
|
||||
# For decapoda-research/llama-*
|
||||
|
||||
@ -91,6 +91,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
||||
step3_vl="Step3VLConfig",
|
||||
step3_text="Step3TextConfig",
|
||||
qwen3_next="Qwen3NextConfig",
|
||||
lfm2_moe="Lfm2MoeConfig",
|
||||
)
|
||||
|
||||
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
|
||||
|
||||
@ -19,6 +19,7 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||
from vllm.transformers_utils.configs.falcon import RWConfig
|
||||
from vllm.transformers_utils.configs.jais import JAISConfig
|
||||
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
|
||||
from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig
|
||||
from vllm.transformers_utils.configs.medusa import MedusaConfig
|
||||
from vllm.transformers_utils.configs.midashenglm import MiDashengLMConfig
|
||||
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
|
||||
@ -46,6 +47,7 @@ __all__ = [
|
||||
"EAGLEConfig",
|
||||
"RWConfig",
|
||||
"JAISConfig",
|
||||
"Lfm2MoeConfig",
|
||||
"MedusaConfig",
|
||||
"MiDashengLMConfig",
|
||||
"MLPSpeculatorConfig",
|
||||
|
||||
160
vllm/transformers_utils/configs/lfm2_moe.py
Normal file
160
vllm/transformers_utils/configs/lfm2_moe.py
Normal file
@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class Lfm2MoeConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Lfm2MoeModel`]. It is used to instantiate a LFM2 Moe
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the LFM2-8B-A1B model.
|
||||
e.g. [LiquidAI/LFM2-8B-A1B](https://huggingface.co/LiquidAI/LFM2-8B-A1B)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 65536):
|
||||
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Lfm2Model`]
|
||||
hidden_size (`int`, *optional*, defaults to 2048):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 7168):
|
||||
Dimension of the MLP representations.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 1792):
|
||||
Intermediate size of the routed expert.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 128000):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
conv_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use bias in the conv layers.
|
||||
conv_L_cache (`int`, *optional*, defaults to 3):
|
||||
L_cache dim in the conv layers.
|
||||
num_dense_layers (`int`, *optional*, defaults to 2):
|
||||
Number of dense Lfm2MoeMLP layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
|
||||
num_experts_per_tok (`int`, *optional*, defaults to 4):
|
||||
Number of selected experts.
|
||||
num_experts (`int`, *optional*, defaults to 32):
|
||||
Number of routed experts.
|
||||
use_expert_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the expert bias on the routing weights.
|
||||
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for routed experts in MoE models.
|
||||
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the topk probabilities.
|
||||
layer_types (`Optional`, *optional*):
|
||||
Type of each layers.
|
||||
|
||||
```python
|
||||
>>> from transformers import Lfm2MoeModel, Lfm2MoeConfig
|
||||
|
||||
>>> # Initializing a LFM2 Moe model
|
||||
>>> configuration = Lfm2MoeConfig()
|
||||
|
||||
>>> # Initializing a model from the LFM2-8B-A1B style configuration
|
||||
>>> model = Lfm2MoeModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```""" # noqa: E501
|
||||
|
||||
model_type = "lfm2_moe"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 65536,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 7168,
|
||||
moe_intermediate_size: int = 1792,
|
||||
num_hidden_layers: int = 32,
|
||||
pad_token_id: int = 0,
|
||||
bos_token_id: int = 1,
|
||||
eos_token_id: int = 2,
|
||||
tie_word_embeddings: bool = True,
|
||||
rope_theta: float = 1000000.0,
|
||||
max_position_embeddings: int = 128_000,
|
||||
use_cache: bool = True,
|
||||
norm_eps: float = 0.00001,
|
||||
num_attention_heads: int = 32,
|
||||
num_key_value_heads: int = 8,
|
||||
conv_bias: bool = False,
|
||||
conv_L_cache: int = 3,
|
||||
num_dense_layers: int = 2,
|
||||
num_experts_per_tok: int = 4,
|
||||
num_experts: int = 32,
|
||||
use_expert_bias: bool = True,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
norm_topk_prob: bool = True,
|
||||
layer_types: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.use_cache = use_cache
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
# attn operator config
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
|
||||
# custom operator config
|
||||
self.conv_bias = conv_bias
|
||||
self.conv_L_cache = conv_L_cache
|
||||
|
||||
# moe config
|
||||
self.num_dense_layers = num_dense_layers
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_experts = num_experts
|
||||
self.use_expert_bias = use_expert_bias
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.layer_types = layer_types
|
||||
|
||||
tie_word_embeddings = kwargs.get(
|
||||
"tie_embedding", tie_word_embeddings
|
||||
) # to fit original config keys
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Lfm2MoeConfig"]
|
||||
Loading…
x
Reference in New Issue
Block a user