From d8c6210eeaa7f3b474e50cf74926f77a8dc79adf Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 2 Dec 2025 11:29:00 +0100 Subject: [PATCH] Add Mistral Large 3 and Ministral 3 (#29757) Signed-off-by: Julien Denize Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Signed-off-by: Mickael Seznec Signed-off-by: Roger Wang Co-authored-by: Roger Wang Co-authored-by: Mickael Seznec --- docs/models/supported_models.md | 5 +- tests/models/registry.py | 14 ++ tests/tokenizers_/test_mistral.py | 158 ++++++++++++++++- vllm/config/speculative.py | 4 + .../tool_parsers/mistral_tool_parser.py | 2 +- ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++++++++++++++++ vllm/model_executor/layers/mla.py | 4 + .../layers/rotary_embedding/__init__.py | 2 +- vllm/model_executor/models/deepseek_v2.py | 66 ++++++- vllm/model_executor/models/mistral_large_3.py | 63 +++++++ .../models/mistral_large_3_eagle.py | 165 ++++++++++++++++++ vllm/model_executor/models/registry.py | 5 + vllm/tokenizers/mistral.py | 36 ++++ vllm/transformers_utils/configs/eagle.py | 6 + vllm/transformers_utils/configs/mistral.py | 74 ++++++-- vllm/v1/spec_decode/eagle.py | 4 + 16 files changed, 724 insertions(+), 30 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/models/mistral_large_3.py create mode 100644 vllm/model_executor/models/mistral_large_3_eagle.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index da7c5edf66bfb..6ea2285b92bb8 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -417,7 +417,8 @@ th { | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | | `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ | -| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | +| `MistralForCausalLM` | Ministral-3, Mistral, Mistral-Instruct | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | +| `MistralLarge3ForCausalLM` | Mistral-Large-3-675B-Base-2512, Mistral-Large-3-675B-Instruct-2512 | `mistralai/Mistral-Large-3-675B-Base-2512`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, etc. | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | @@ -711,7 +712,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. 
| ✅︎ | ✅︎ | | `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | -| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | +| `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I+ | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512` `mistralai/Pixtral-12B-2409` etc. | | ✅︎ | | `QwenVLForConditionalGeneration`^ | Qwen-VL | T + IE+ | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | | `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index d90f3a4d4f781..26351089fc464 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -358,6 +358,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True, ), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), + "MistralLarge3ForCausalLM": _HfExamplesInfo( + "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", is_available_online=False + ), "MixtralForCausalLM": _HfExamplesInfo( "mistralai/Mixtral-8x7B-Instruct-v0.1", {"tiny": "TitanML/tiny-mixtral"}, @@ -770,7 +773,13 @@ _MULTIMODAL_EXAMPLE_MODELS = { ), "PixtralForConditionalGeneration": _HfExamplesInfo( "mistralai/Pixtral-12B-2409", + extras={ + "mistral-large-3": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", + "ministral-3": "mistralai/Ministral-3-3B-Instruct-2512", + }, tokenizer_mode="mistral", + # TODO: revert once Mistral-Large-3 and Ministral-3 are publicly available. 
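Once the checkpoints are public, they are served through vLLM's usual recipe for Mistral-format models; a minimal offline sketch (model name taken from the entries above, exact flags may vary by release):

```python
# Minimal sketch: offline generation with one of the new Mistral-format checkpoints.
# Assumes the weights are downloadable and fit on the local GPUs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Ministral-3-3B-Instruct-2512",
    tokenizer_mode="mistral",   # tokenize with mistral-common
    config_format="mistral",    # read params.json instead of config.json
    load_format="mistral",      # load the consolidated Mistral weights
)
outputs = llm.chat(
    [{"role": "user", "content": "What is the capital of France?"}],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```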
+ is_available_online=False, ), "QwenVLForConditionalGeneration": _HfExamplesInfo( "Qwen/Qwen-VL", @@ -870,6 +879,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { use_original_num_layers=True, max_model_len=10240, ), + "EagleMistralLarge3ForCausalLM": _HfExamplesInfo( + "mistralai/Mistral-Large-3-675B-Instruct-2512", + speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle", + is_available_online=False, + ), "LlamaForCausalLMEagle3": _HfExamplesInfo( "Qwen/Qwen3-8B", trust_remote_code=True, diff --git a/tests/tokenizers_/test_mistral.py b/tests/tokenizers_/test_mistral.py index 92efac86dff29..faff611502652 100644 --- a/tests/tokenizers_/test_mistral.py +++ b/tests/tokenizers_/test_mistral.py @@ -91,6 +91,118 @@ from vllm.tokenizers.mistral import ( ], ), ), + ( + { + "messages": [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + "tools": [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "unsupported_field": False, + "name": "get_current_time", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "unsupported_field2": False, + "name": "get_current_time", + "parameters": {}, + }, + }, + ], + }, + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + }, + ], + ), + ), + ( + { + "messages": [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + "tools": [ + { + "type": "function", + "unsupported_field": False, + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + }, + { + "type": "function", + "unsupported_field2": False, + "function": { + "description": "Fetch the current local date and time 2.", + "name": "get_current_time2", + "parameters": {"a": "1"}, + }, + }, + ], + }, + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "description": "Fetch the current local date and time 2.", + "name": "get_current_time2", + "parameters": {"a": "1"}, + }, + }, + ], + ), + ), ], ) def test_prepare_apply_chat_template_tools_and_messages( @@ -1108,13 +1220,6 @@ class TestMistralTokenizer: ) == expected_tokens[mistral_tokenizer.is_tekken] ) - assert ( - mistral_tokenizer.decode( - ids[mistral_tokenizer.is_tekken], - skip_special_tokens=skip_special_tokens, - ) - == expected_tokens[mistral_tokenizer.is_tekken] - ) def test_decode_empty( self, @@ -1140,6 +1245,45 @@ class TestMistralTokenizer: == "" ) + @pytest.mark.parametrize( + "skip_special_tokens,expected_tokens", + ( + ( + False, + ( + ["[INST]▁Hello▁world▁![/INST]▁Hello"], + ["[INST]Hello world ![/INST]Hello"], + ), + ), + (True, (["Hello world ! 
Hello"], ["Hello world !Hello"])), + ), + ) + def test_batch_decode( + self, + mistral_tokenizer: MistralTokenizer, + skip_special_tokens: bool, + expected_tokens: tuple[str, str], + ): + ids = ( + [[1, 3, 23325, 2294, 1686, 4, 23325, 2]], + [[1, 3, 22177, 4304, 2662, 4, 22177, 2]], + ) + assert ( + mistral_tokenizer.batch_decode( + ids[mistral_tokenizer.is_tekken], + skip_special_tokens=skip_special_tokens, + ) + == expected_tokens[mistral_tokenizer.is_tekken] + ) + + def test_batch_decode_empty( + self, + mistral_tokenizer: MistralTokenizer, + ): + assert mistral_tokenizer.batch_decode( + [[]], + ) == [""] + def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer): tokens = ( [ diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 80d53a543f149..c6d6f705f535c 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -167,6 +167,7 @@ class SpeculativeConfig: @staticmethod def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + initial_architecture = hf_config.architectures[0] if hf_config.model_type in ("deepseek_v3", "deepseek_v32"): hf_config.model_type = "deepseek_mtp" if hf_config.model_type == "deepseek_mtp": @@ -226,6 +227,9 @@ class SpeculativeConfig: {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} ) + if initial_architecture == "MistralLarge3ForCausalLM": + hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]}) + return hf_config def __post_init__(self): diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 7e2d67a1fb659..89b882d6c8475 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -80,7 +80,7 @@ class MistralToolParser(ToolParser): self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) if _is_fn_name_regex_support(self.model_tokenizer): self.fn_name_regex = re.compile( - r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)", re.DOTALL + r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)?", re.DOTALL ) else: self.fn_name_regex = None diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..a9f24c20a25a2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + 
"48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 6ebfa47a9dc3f..dad960160f2ad 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -111,6 +111,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp): self, positions: torch.Tensor, hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: q_c = None kv_lora = None @@ -159,6 +160,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp): hidden_states, q_c, positions, self.indexer_rope_emb ) + if llama_4_scaling is not None: + q *= llama_4_scaling + attn_out = self.mla_attn( q, kv_c_normed, diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 0f10bff6ac4f5..aa6ece30026d3 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -238,7 +238,7 @@ def get_rope( dtype, **extra_kwargs, ) - elif scaling_type == "deepseek_yarn": + elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]: scaling_factor = rope_parameters["factor"] original_max_position = rope_parameters["original_max_position_embeddings"] # assert max_position == original_max_position * scaling_factor diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 73cac2556c55a..d8a081af125c5 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: return 0.1 * mscale * math.log(scale) + 1.0 +def _get_llama_4_scaling( + original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor +) -> torch.Tensor: + scaling = 1 + scaling_beta * torch.log( + 1 + torch.floor(positions / original_max_position_embeddings) + ) + # Broadcast over num_heads and head_dim + return scaling[..., None, None] + + class 
DeepseekV2Attention(nn.Module): def __init__( self, @@ -481,7 +491,11 @@ class DeepseekV2Attention(nn.Module): prefix=f"{prefix}.o_proj", ) if config.rope_parameters["rope_type"] != "default": - config.rope_parameters["rope_type"] = "deepseek_yarn" + config.rope_parameters["rope_type"] = ( + "deepseek_yarn" + if config.rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) self.rotary_emb = get_rope( qk_rope_head_dim, @@ -491,7 +505,10 @@ class DeepseekV2Attention(nn.Module): is_neox_style=False, ) - if config.rope_parameters["rope_type"] != "default": + if ( + config.rope_parameters["rope_type"] != "default" + and config.rope_parameters["rope_type"] == "deepseek_yarn" + ): mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False) scaling_factor = config.rope_parameters["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) @@ -511,6 +528,7 @@ class DeepseekV2Attention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None, ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] @@ -536,6 +554,11 @@ class DeepseekV2Attention(nn.Module): k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope k[..., self.qk_nope_head_dim :] = k_pe + + # Apply llama 4 scaling if provided + if llama_4_scaling is not None: + q *= llama_4_scaling + # padding value to qk_head_dim for alignment v = torch.nn.functional.pad( v, [0, self.qk_head_dim - self.v_head_dim], value=0 @@ -987,7 +1010,12 @@ class DeepseekV2MLAAttention(nn.Module): ) if config.rope_parameters["rope_type"] != "default": - config.rope_parameters["rope_type"] = "deepseek_yarn" + config.rope_parameters["rope_type"] = ( + "deepseek_yarn" + if config.rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) + self.rotary_emb = get_rope( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, @@ -995,7 +1023,11 @@ class DeepseekV2MLAAttention(nn.Module): rope_parameters=config.rope_parameters, is_neox_style=False, ) - if config.rope_parameters["rope_type"] != "default": + + if ( + config.rope_parameters["rope_type"] != "default" + and config.rope_parameters["rope_type"] == "deepseek_yarn" + ): mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False) scaling_factor = config.rope_parameters["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) @@ -1064,8 +1096,9 @@ class DeepseekV2MLAAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None, ) -> torch.Tensor: - return self.mla_attn(positions, hidden_states) + return self.mla_attn(positions, hidden_states, llama_4_scaling) class DeepseekV2DecoderLayer(nn.Module): @@ -1155,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: # Self Attention if residual is None: @@ -1165,6 +1199,7 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, + llama_4_scaling=llama_4_scaling, ) if ( @@ -1266,8 +1301,24 @@ class DeepseekV2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + # Compute llama 4 scaling once per forward pass if enabled + llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None) + llama_4_scaling: torch.Tensor | None + if 
llama_4_scaling_config is not None: + llama_4_scaling = _get_llama_4_scaling( + original_max_position_embeddings=llama_4_scaling_config[ + "original_max_position_embeddings" + ], + scaling_beta=llama_4_scaling_config["beta"], + positions=positions, + ) + else: + llama_4_scaling = None + for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer( + positions, hidden_states, residual, llama_4_scaling + ) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM( packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } + model_cls = DeepseekV2Model def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -1355,7 +1407,7 @@ class DeepseekV2ForCausalLM( "kv_a_proj_with_mqa", ] - self.model = DeepseekV2Model( + self.model = self.model_cls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/mistral_large_3.py b/vllm/model_executor/models/mistral_large_3.py new file mode 100644 index 0000000000000..ff7e9b60c1d3c --- /dev/null +++ b/vllm/model_executor/models/mistral_large_3.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable + +import regex as re +import torch + +from vllm.model_executor.models.deepseek_v2 import DeepseekV3ForCausalLM + + +class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM): + # fmt: off + remapping = { + r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2", # noqa: E501 + r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2", # noqa: E501 + r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": 
r"model.layers.\1.mlp.experts.\2.up_proj.\3", # noqa: E501 + r"norm\.weight": "model.norm.weight", # noqa: E501 + r"tok_embeddings\.weight": "model.embed_tokens.weight", # noqa: E501 + r"output\.weight": "lm_head.weight", # noqa: E501 + } + # fmt: on + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + return super().load_weights(map(self._remap_mistral_to_ds, weights)) + + def _remap_mistral_to_ds( + self, weight: tuple[str, torch.Tensor] + ) -> tuple[str, torch.Tensor]: + """Remap Mistral parameters to DeepseekV2 parameters.""" + name, loaded_weight = weight + + for k, v in self.remapping.items(): + match = re.fullmatch(k, name) + if match: + name = re.sub(k, v, name) + break + else: + raise ValueError(f"Cannot remap {name}") + + # Remapping scale names. We could do this in the regex above but it + # would triple the number of lines for most layers. + if name.endswith(".qscale_act"): + name = re.sub(r"\.qscale_act$", ".input_scale", name) + elif name.endswith(".qscale_weight"): + name = re.sub(r"\.qscale_weight$", ".weight_scale", name) + + return name, loaded_weight diff --git a/vllm/model_executor/models/mistral_large_3_eagle.py b/vllm/model_executor/models/mistral_large_3_eagle.py new file mode 100644 index 0000000000000..e3ca9e4ca82d0 --- /dev/null +++ b/vllm/model_executor/models/mistral_large_3_eagle.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from functools import partial + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2DecoderLayer, + DeepseekV2Model, +) +from vllm.model_executor.models.interfaces import MultiModalEmbeddings +from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM +from vllm.multimodal.inputs import NestedTensors + +from .utils import ( + _merge_multimodal_embeddings, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) + +logger = init_logger(__name__) + + +@support_torch_compile +class EagleMistralLarge3Model(DeepseekV2Model): + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", start_layer_id: int = 0 + ): + nn.Module.__init__(self) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.vllm_config = vllm_config + + self.vocab_size = config.vocab_size + + assert get_pp_group().world_size == 1 + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) + for i in range(self.config.num_hidden_layers) + ] + ) + self.start_layer = 0 + self.end_layer = self.config.num_hidden_layers + + self.fc = RowParallelLinear( + self.config.hidden_size * 2, + self.config.hidden_size, + bias=False, + input_is_parallel=False, + quant_config=quant_config, + return_bias=False, + ) + self.norm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_input_ids(input_ids) + inputs_embeds = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) + output = super().forward( + input_ids, positions, intermediate_tensors=None, inputs_embeds=inputs_embeds + ) + assert isinstance(output, torch.Tensor) + return output + + +class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM): + remapping = MistralLarge3ForCausalLM.remapping | { + r"eagle_linear\.weight": r"model.fc.weight", + r"eagle_linear\.qscale_act": r"model.fc.input_scale", + r"eagle_linear\.qscale_weight": r"model.fc.weight_scale", + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config + ) + vllm_config.model_config = vllm_config.speculative_config.draft_model_config + # draft model quantization config may differ from target model + self.quant_config = VllmConfig.get_quantization_config( + vllm_config.speculative_config.draft_model_config, vllm_config.load_config + ) + vllm_config.quant_config = self.quant_config + self.model_cls = partial( + EagleMistralLarge3Model, start_layer_id=target_layer_num + ) + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = super().embed_input_ids(input_ids) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + assert is_multimodal is not None + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.model(input_ids, positions, hidden_states, inputs_embeds) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Pretend we've loaded the embedding and lm_head weights + # (later copied from target model) + return super().load_weights(weights) | { + "model.embed_tokens.weight", + "lm_head.weight", + } + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: NestedTensors | None = None, + is_multimodal: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 73a61f1148b50..d3b6268e7647b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -145,6 +145,7 @@ _TEXT_GENERATION_MODELS = { "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), + "MistralLarge3ForCausalLM": ("mistral_large_3", 
"MistralLarge3ForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), @@ -424,6 +425,10 @@ _SPECULATIVE_DECODING_MODELS = { "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "EagleMistralLarge3ForCausalLM": ( + "mistral_large_3_eagle", + "EagleMistralLarge3ForCausalLM", + ), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index 7e6745004b01f..96d1e78ce9f17 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -97,6 +97,8 @@ def _prepare_apply_chat_template_tools_and_messages( continue_final_message: bool = False, add_generation_prompt: bool = False, ) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]: + from mistral_common.protocol.instruct.tool_calls import Function, Tool + if add_generation_prompt and continue_final_message: raise ValueError( "Cannot set both `add_generation_prompt` and " @@ -139,6 +141,33 @@ def _prepare_apply_chat_template_tools_and_messages( if function.get("description") is None: function["description"] = "" + # We filter not supported arguments to avoid throwing an error. + # TODO(juliendenize): remove this once OpenAI API is better supported by + # `mistral-common`. + tools_fields = set(Tool.model_fields.keys()) + function_fields = set(Function.model_fields.keys()) + for tool in tools: + tool_keys = list(tool.keys()) + for tool_key in tool_keys: + if tool_key not in tools_fields: + tool.pop(tool_key) + logger.warning_once( + f"'{tool_key}' is not supported by mistral-common for tools. " + "It has been poped from the tool definition." + ) + if tool["type"] == "function": + function_keys = list(tool["function"].keys()) + for function_key in function_keys: + if function_key not in function_fields: + tool["function"].pop(function_key) + logger.warning_once( + f"'{function_key}' is not supported by mistral-common " + "for function tools. It has been poped from the " + "function definition." 
+ ) + else: + raise ValueError("mistral-common only supports function tools.") + return messages, tools @@ -410,6 +439,13 @@ class MistralTokenizer(TokenizerLike): ids, skip_special_tokens=skip_special_tokens ) + def batch_decode( + self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False + ) -> str: + return self.transformers_tokenizer.batch_decode( + ids, skip_special_tokens=skip_special_tokens + ) + def convert_tokens_to_string(self, tokens: list[str]) -> str: from mistral_common.tokens.tokenizers.base import ( SpecialTokenPolicy, diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index f5dc9ddfbc575..ce428e567c844 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -82,3 +82,9 @@ class EAGLEConfig(PretrainedConfig): pretrained_model_name_or_path, **kwargs ) return cls.from_dict(config_dict, **kwargs) + + def to_json_string(self, use_diff: bool = True) -> str: + # we override use_diff to False as initializing + # EAGLEConfig with default arguments is not supported + del use_diff + return super().to_json_string(use_diff=False) diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index 966737aad0867..d59169d95f0c9 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -18,9 +18,31 @@ def adapt_config_dict( if bool(config_dict.get("quantization")): config_dict = _remap_mistral_quantization_args(config_dict) + is_moe = bool(config_dict.get("moe")) + is_mistral_large_3 = ( + is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0 + ) if config_dict.get("model_type") == "mamba": config_dict["architectures"] = ["Mamba2ForCausalLM"] - elif bool(config_dict.get("moe")): + elif is_moe and is_mistral_large_3: + config_dict = _remap_moe_args(config_dict) + config_dict["model_type"] = "deepseek_v3" + config_dict["architectures"] = ["MistralLarge3ForCausalLM"] + + assert "llama_4_scaling" in config_dict, ( + "MistralLarge3 expect llama4 scaling config." 
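These two keys feed `_get_llama_4_scaling` in `deepseek_v2.py`, which multiplies the query by `1 + beta * log(1 + floor(pos / original_max_position_embeddings))`. A small numeric check of that formula, with illustrative values for `beta` and the context window:

```python
# Numeric sketch of the llama-4-style query scaling wired up in deepseek_v2.py:
# positions inside the original window keep scale 1.0, longer contexts pick up
# a slowly growing log factor. beta=0.1 and the 4096 window are illustrative.
import torch

def llama_4_scaling(positions, original_max_position_embeddings, beta):
    return 1 + beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )

pos = torch.tensor([0.0, 4095.0, 4096.0, 100_000.0])
print(llama_4_scaling(pos, original_max_position_embeddings=4096, beta=0.1))
# tensor([1.0000, 1.0000, 1.0693, 1.3219])
```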
+ ) + llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"] + assert all( + [ + key in config_dict["llama_4_scaling"] + for key in llama_4_scaling_config_keys + ] + ), ( + "llama_4_scaling config should define the keys: " + f"{','.join(llama_4_scaling_config_keys)}" + ) + elif is_moe: config_dict["architectures"] = ["MixtralForCausalLM"] else: config_dict["architectures"] = ["MistralForCausalLM"] @@ -140,17 +162,20 @@ def _remap_general_mistral_args(config: dict) -> dict: def _remap_mistral_quantization_args(config: dict) -> dict: - quantization = config.get("quantization", {}) - if quantization.get("qformat_weight") == "fp8_e4m3": - # This maps to the FP8 static per-tensor quantization scheme - quantization_config = {"quant_method": "fp8", "activation_scheme": "static"} - elif quantization.get("quant_method") == "compressed-tensors": - # Pass through the quantization config to compressed-tensors - quantization_config = quantization - else: - raise ValueError(f"Found unknown quantization='{quantization}' in config") - - config["quantization_config"] = quantization_config + if config.get("quantization"): + quantization = config.pop("quantization", {}) + if quantization.get("qformat_weight") == "fp8_e4m3": + qscheme_act = quantization.get("qscheme_act") + assert qscheme_act in ("NO_SCALES", "TENSOR", None), ( + "Only NO_SCALES and TENSOR (default) are supported for qscheme_act" + ) + is_dynamic = qscheme_act == "NO_SCALES" + config["quantization_config"] = { + "quant_method": "fp8", + "activation_scheme": "dynamic" if is_dynamic else "static", + } + else: + raise ValueError(f"Found unknown quantization='{quantization}' in config") return config @@ -183,3 +208,28 @@ def _remap_mistral_audio_args(config: dict) -> dict: if quant_config: config["quantization_config"] = quant_config return config + + +def _remap_moe_args(config: dict) -> dict: + moe_config_map = { + "route_every_n": "moe_layer_freq", + "first_k_dense_replace": "first_k_dense_replace", + "num_experts_per_tok": "num_experts_per_tok", + "num_experts": "n_routed_experts", + "expert_hidden_dim": "moe_intermediate_size", + "routed_scale": "routed_scaling_factor", + "num_shared_experts": "n_shared_experts", + "num_expert_groups": "n_group", + "num_expert_groups_per_tok": "topk_group", + } + moe_config = config.get("moe", {}) + for old_name, new_name in moe_config_map.items(): + if old_name in moe_config: + value = moe_config.pop(old_name) + config[new_name] = value + + config["topk_method"] = None + config["norm_topk_prob"] = True + config["scoring_func"] = "softmax" + + return config diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d7111d52dd8a1..1c7845a14b742 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1016,6 +1016,10 @@ class EagleProposer: "Qwen3VLForConditionalGeneration", ]: self.model.config.image_token_index = target_model.config.image_token_id + elif self.get_model_name(target_model) == "PixtralForConditionalGeneration": + self.model.config.image_token_index = ( + target_model.config.vision_config.image_token_id + ) else: self.model.config.image_token_index = ( target_model.config.image_token_index
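Putting the pieces together: `hf_config_override` above rewrites the draft architecture to `EagleMistralLarge3ForCausalLM`, the registry maps it to the new Eagle module, and this Pixtral branch propagates the image token id to the drafter. A hedged usage sketch (checkpoint names mirror the test registry entries in this patch; option names follow vLLM's `speculative_config` dict and may need adjusting per release):

```python
# Sketch: enabling the Mistral Large 3 Eagle draft via speculative decoding.
# A real deployment would also set tensor_parallel_size and related options.
from vllm import LLM

llm = LLM(
    model="mistralai/Mistral-Large-3-675B-Instruct-2512",
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
    speculative_config={
        "method": "eagle",
        "model": "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
        "num_speculative_tokens": 3,
    },
)
```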