diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 3ffbf63f9a18b..62e42a730e9cb 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -415,6 +415,7 @@ th { | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | | `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | +| `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | | ✅︎ | | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | | `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index fa70e94abd865..82b9303b2a21b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -459,6 +459,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ), "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True), + "MiMoV2FlashForCausalLM": _HfExamplesInfo( + "XiaomiMiMo/MiMo-V2-Flash", trust_remote_code=True + ), "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"), } diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 0e91dd57420a8..3c77fad41d077 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -18,6 +18,7 @@ from vllm.config.lora import LoRAConfig from vllm.config.model import ( ModelConfig, iter_architecture_defaults, + str_dtype_to_torch_dtype, try_match_architecture_defaults, ) from vllm.config.multimodal import MultiModalConfig @@ -72,6 +73,7 @@ __all__ = [ # From vllm.config.model "ModelConfig", "iter_architecture_defaults", + "str_dtype_to_torch_dtype", 
"try_match_architecture_defaults", # From vllm.config.multimodal "MultiModalConfig", diff --git a/vllm/config/model.py b/vllm/config/model.py index 1de9d15cf8c52..db5789b709372 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1849,6 +1849,11 @@ _STR_DTYPE_TO_TORCH_DTYPE = { "bfloat16": torch.bfloat16, } + +def str_dtype_to_torch_dtype(type: str): + return _STR_DTYPE_TO_TORCH_DTYPE.get(type) + + # model_type -> reason _FLOAT16_NOT_SUPPORTED_MODELS = { "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index dfcc601a1c530..4ca4f75711ac7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -277,6 +277,7 @@ class LinearBase(CustomOp): self.params_dtype = params_dtype self.quant_config = quant_config self.prefix = prefix + self.allow_fp8_block_shape_mismatch = False if quant_config is None: self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod() else: @@ -475,6 +476,7 @@ class ColumnParallelLinear(LinearBase): disable_tp=disable_tp, ) + self._maybe_allow_fp8_block_shape_mismatch() self.gather_output = gather_output if output_sizes is None: @@ -509,6 +511,33 @@ class ColumnParallelLinear(LinearBase): self.register_parameter("bias", None) self.update_param_tp_status() + def _maybe_allow_fp8_block_shape_mismatch(self) -> None: + quant_config = getattr(self, "quant_config", None) + weight_block = getattr(quant_config, "weight_block_size", None) + if ( + weight_block is None + or len(weight_block) < 1 + or len(self.output_partition_sizes) <= 1 + ): + return + + try: + block_n = int(weight_block[0]) + except (ValueError, TypeError): + return + + if block_n <= 0: + return + + if any(size % block_n != 0 for size in self.output_partition_sizes): + self.allow_fp8_block_shape_mismatch = True + logger.debug( + "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)", + 
getattr(self, "prefix", ""), + block_n, + self.output_partition_sizes, + ) + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): output_dim = getattr(param, "output_dim", None) @@ -906,9 +935,11 @@ class QKVParallelLinear(ColumnParallelLinear): *, return_bias: bool = True, disable_tp: bool = False, + v_head_size: int | None = None, ): self.hidden_size = hidden_size self.head_size = head_size + self.v_head_size = v_head_size if v_head_size is not None else head_size self.total_num_heads = total_num_heads if total_num_kv_heads is None: total_num_kv_heads = total_num_heads @@ -924,12 +955,14 @@ class QKVParallelLinear(ColumnParallelLinear): self.num_kv_head_replicas = 1 input_size = self.hidden_size output_size = ( - (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - ) + self.num_heads * self.head_size + + self.num_kv_heads * self.head_size + + self.num_kv_heads * self.v_head_size + ) * tp_size self.output_sizes = [ self.num_heads * self.head_size * tp_size, # q_proj self.num_kv_heads * self.head_size * tp_size, # k_proj - self.num_kv_heads * self.head_size * tp_size, # v_proj + self.num_kv_heads * self.v_head_size * tp_size, # v_proj ] super().__init__( @@ -950,7 +983,8 @@ class QKVParallelLinear(ColumnParallelLinear): "q": 0, "k": self.num_heads * self.head_size, "v": (self.num_heads + self.num_kv_heads) * self.head_size, - "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size, + "total": (self.num_heads + self.num_kv_heads) * self.head_size + + self.num_kv_heads * self.v_head_size, } return shard_offset_mapping.get(loaded_shard_id) @@ -958,7 +992,7 @@ class QKVParallelLinear(ColumnParallelLinear): shard_size_mapping = { "q": self.num_heads * self.head_size, "k": self.num_kv_heads * self.head_size, - "v": self.num_kv_heads * self.head_size, + "v": self.num_kv_heads * self.v_head_size, } return shard_size_mapping.get(loaded_shard_id) @@ -985,7 +1019,7 @@ class QKVParallelLinear(ColumnParallelLinear): ( "v", 
(self.total_num_heads + self.total_num_kv_heads) * self.head_size, - self.total_num_kv_heads * self.head_size, + self.total_num_kv_heads * self.v_head_size, ), ] @@ -1110,7 +1144,7 @@ class QKVParallelLinear(ColumnParallelLinear): ( "v", (self.total_num_heads + self.total_num_kv_heads) * self.head_size, - self.total_num_kv_heads * self.head_size, + self.total_num_kv_heads * self.v_head_size, ), ] use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) @@ -1139,11 +1173,12 @@ class QKVParallelLinear(ColumnParallelLinear): "v": ( (self.total_num_heads + self.total_num_kv_heads) * self.head_size, - self.total_num_kv_heads * self.head_size, + self.total_num_kv_heads * self.v_head_size, ), "total": ( - (self.total_num_heads + 2 * self.total_num_kv_heads) - * self.head_size, + (self.total_num_heads + self.total_num_kv_heads) + * self.head_size + + self.total_num_kv_heads * self.v_head_size, 0, ), } @@ -1170,7 +1205,7 @@ class QKVParallelLinear(ColumnParallelLinear): shard_size = self.num_kv_heads * self.head_size elif loaded_shard_id == "v": shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size - shard_size = self.num_kv_heads * self.head_size + shard_size = self.num_kv_heads * self.v_head_size # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. 
@@ -1199,10 +1234,11 @@ class QKVParallelLinear(ColumnParallelLinear): ), "v": ( (self.num_heads + self.num_kv_heads) * self.head_size, - self.num_kv_heads * self.head_size, + self.num_kv_heads * self.v_head_size, ), "total": ( - (self.num_heads + 2 * self.num_kv_heads) * self.head_size, + (self.num_heads + self.num_kv_heads) * self.head_size + + self.num_kv_heads * self.v_head_size, 0, ), } diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index bdc3d1fc7232d..13e813952b30a 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -1252,6 +1252,14 @@ def validate_fp8_block_shape( """Validate block quantization shapes for tensor parallelism.""" from vllm.distributed import get_tensor_model_parallel_world_size + if getattr(layer, "allow_fp8_block_shape_mismatch", False): + logger.debug( + "Skipping FP8 block shape validation for layer %s due to detected" + " mismatch allowance.", + getattr(layer, "prefix", ""), + ) + return + tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size()) block_n, block_k = block_size[0], block_size[1] diff --git a/vllm/model_executor/models/mimo_v2_flash.py b/vllm/model_executor/models/mimo_v2_flash.py new file mode 100644 index 0000000000000..12b486f001e03 --- /dev/null +++ b/vllm/model_executor/models/mimo_v2_flash.py @@ -0,0 +1,720 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from itertools import islice + +import torch +from torch import nn + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import Attention +from vllm.config import ( + CacheConfig, + VllmConfig, + get_current_vllm_config, + str_dtype_to_torch_dtype, +) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + 
get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.sequence import IntermediateTensors + +from .interfaces import MixtureOfExperts, SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class MiMoV2MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. 
Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MiMoV2MoE(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + is_nextn: bool = False, + ): + super().__init__() + + config = vllm_config.model_config.hf_text_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.n_routed_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}." + ) + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) + + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + dtype = getattr(config, "moe_router_dtype", "float32") + self.gate_dtype = str_dtype_to_torch_dtype(dtype) + self.gate = nn.Linear( + config.hidden_size, + config.n_routed_experts, + bias=False, + dtype=self.gate_dtype, + ) + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=self.gate_dtype) + ) + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=True, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + scoring_func="sigmoid", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert hidden_states.dim() <= 2, "MiMoV2MoE only supports 1D or 2D inputs" + is_input_1d = hidden_states.dim() == 1 + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + if self.gate_dtype is not None: + gate_input = 
hidden_states.to(self.gate_dtype) + else: + gate_input = hidden_states + router_logits = self.gate(gate_input) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + + return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states + + +class MiMoV2Attention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + v_head_dim: int | None = None, + sliding_window_size: int = -1, + attention_bias: bool = False, + add_swa_attention_sink_bias: bool = False, + layer_id: int = 0, + rope_theta: float = 1000000, + max_position_embeddings: int = 32768, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + partial_rotary_factor: float = 1.0, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.layer_id = layer_id + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_heads + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.head_dim = head_dim + + self.v_head_dim = v_head_dim if v_head_dim is not None else head_dim + + self.q_size = self.num_heads * self.head_dim + self.k_size = self.num_kv_heads * self.head_dim + self.v_size = self.num_kv_heads * self.v_head_dim + + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + v_head_size=self.v_head_dim, + ) + + self.o_proj = RowParallelLinear( + 
self.total_num_heads * self.v_head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + max_position=max_position_embeddings, + rope_parameters={ + "rope_type": "default", + "rope_theta": rope_theta, + "partial_rotary_factor": partial_rotary_factor, + }, + ) + + self.attention_sink_bias = ( + torch.nn.Parameter(torch.empty(self.num_heads), requires_grad=False) + if add_swa_attention_sink_bias + else None + ) + + sliding_window = sliding_window_size if sliding_window_size > -1 else None + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=AttentionType.DECODER, + prefix=f"{prefix}.attn", + sinks=self.attention_sink_bias, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + v = v.view(-1, self.num_kv_heads, self.v_head_dim) + v = torch.nn.functional.pad(v, [0, self.head_dim - self.v_head_dim], value=0) + v = v.view(-1, self.num_kv_heads * self.head_dim) + + attn_output = self.attn(q, k, v) + + attn_output = attn_output.view(-1, self.num_heads, self.head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_heads * self.v_head_dim) + + output, _ = self.o_proj(attn_output) + return output + + +class MiMoV2FlashDecoderLayer(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_text_config + quant_config = vllm_config.quant_config + layer_id = extract_layer_index(prefix) + + self.hidden_size = config.hidden_size + self.config = config + self.layer_id = layer_id + + rope_theta = 
getattr(config, "rope_theta", 1000000) + max_position_embeddings = getattr(config, "max_position_embeddings", 32768) + + if self.is_compressed_softmax_layer(): + self.self_attn = MiMoV2Attention( + hidden_size=self.hidden_size, + num_heads=config.swa_num_attention_heads, + num_kv_heads=config.swa_num_key_value_heads, + head_dim=config.swa_head_dim, + v_head_dim=getattr(config, "swa_v_head_dim", None), + sliding_window_size=config.sliding_window_size, + attention_bias=config.attention_bias, + add_swa_attention_sink_bias=getattr( + config, "add_swa_attention_sink_bias", False + ), + layer_id=layer_id, + rope_theta=getattr(config, "swa_rope_theta", rope_theta), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0), + prefix=f"{prefix}.self_attn", + ) + else: + self.self_attn = MiMoV2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + v_head_dim=getattr(config, "v_head_dim", None), + sliding_window_size=-1, # normal attention + attention_bias=config.attention_bias, + layer_id=layer_id, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0), + prefix=f"{prefix}.self_attn", + ) + + self.is_layer_sparse = self.is_moe_layer(layer_id) + if self.is_layer_sparse: + self.mlp = MiMoV2MoE( + vllm_config=vllm_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = MiMoV2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.layernorm_epsilon + ) + + def forward( + self, + positions: 
torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + def is_moe_layer(self, layer_idx: int) -> bool: + return ( + hasattr(self.config, "moe_layer_freq") + and layer_idx >= 0 + and not isinstance(self.config.moe_layer_freq, int) + and self.config.moe_layer_freq[layer_idx] + ) + + def is_compressed_softmax_layer(self) -> bool: + return self.config.hybrid_layer_pattern[self.layer_id] == 1 + + +class MiMoV2Model(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config.get_text_config() + quant_config = vllm_config.quant_config + eplb_config = vllm_config.parallel_config.eplb_config + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + self.num_redundant_experts = eplb_config.num_redundant_experts + self.v_scale = getattr(config, "attention_value_scale", None) + + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiMoV2FlashDecoderLayer( + vllm_config=vllm_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = 
make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon) + else: + self.norm = PPMissingLayer() + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + 
("gate_up_proj", "up_proj", 1), + ] + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + continue + if "mtp" in name: + continue + + if self.quant_config is not None: + cache_scale_name = self.quant_config.get_cache_scale(name) + if cache_scale_name is not None and cache_scale_name in params_dict: + param = params_dict[cache_scale_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + + kv_scale = loaded_weight + if kv_scale.dim() > 0 and kv_scale.numel() > 1: + kv_scale = kv_scale.view(-1)[0] + + weight_loader(param, kv_scale) + loaded_params.add(cache_scale_name) + continue + + expert_matched = False + for param_name, weight_name, expert_id, shard_id in expert_params_mapping: + if weight_name not in name: + continue + + name_rewritten = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_rewritten, self): + continue + + if ( + name_rewritten.endswith(".bias") or name_rewritten.endswith("_bias") + ) and name_rewritten not in params_dict: + continue + + if name_rewritten not in params_dict: + continue + + param = params_dict[name_rewritten] + weight_loader = param.weight_loader + + weight_loader( + param, + loaded_weight, + name_rewritten, + shard_id=shard_id, + expert_id=expert_id, + ) + loaded_params.add(name_rewritten) + expert_matched = True + break + + if expert_matched: + continue + + stacked_matched = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name_rewritten = name.replace(weight_name, param_name) + + if ( + name_rewritten.endswith(".bias") + and name_rewritten not in params_dict 
+ ): + continue + + if is_pp_missing_parameter(name_rewritten, self): + continue + + if name_rewritten not in params_dict: + continue + + param = params_dict[name_rewritten] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + + if param_name == "qkv_proj" and shard_id == "v": + v_scale = ( + self.v_scale + if self.v_scale is not None + else getattr(self.config, "attention_value_scale", None) + ) + if v_scale is not None and ( + name.endswith("weight_scale_inv") or name.endswith(".bias") + ): + loaded_weight *= float(v_scale) + + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name_rewritten) + + stacked_matched = True + break + + if stacked_matched: + continue + + if name.endswith(".bias") and name not in params_dict: + continue + + orig_name = name + mapped_name = maybe_remap_kv_scale_name(name, params_dict) + name = mapped_name if mapped_name is not None else orig_name + + if name not in params_dict: + continue + + param = params_dict[name] + + if "attention_sink_bias" in name: + total_heads = loaded_weight.shape[0] + heads_per_rank = total_heads // tp_size + head_start = tp_rank * heads_per_rank + narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank) + + param.data.copy_(narrow_weight) + loaded_params.add(name) + else: + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.quant_config = quant_config + self.model = MiMoV2Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + 
quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index d332f51152484..3ba61b52cfdf1 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -152,6 +152,7 @@ _TEXT_GENERATION_MODELS = { "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"), + "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), 
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),