diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index da7c5edf66bfb..6ea2285b92bb8 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -417,7 +417,8 @@ th {
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
| `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
-| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
+| `MistralForCausalLM` | Ministral-3, Mistral, Mistral-Instruct | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
+| `MistralLarge3ForCausalLM` | Mistral-Large-3 | `mistralai/Mistral-Large-3-675B-Base-2512`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, etc. | ✅︎ | ✅︎ |
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ |
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ |
@@ -711,7 +712,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
-| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ |
+| `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I+ | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ |
| `QwenVLForConditionalGeneration`^ | Qwen-VL | T + IE+ | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ |
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ |
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ |
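
For reference, a minimal offline-inference sketch for the newly listed Ministral-3 checkpoint (unreleased at the time of this diff). The Mistral-format flags below (`tokenizer_mode`, `config_format`, `load_format`) are the usual ones vLLM documents for Mistral-released weights; treat the snippet as illustrative, not part of this patch.

    from vllm import LLM, SamplingParams

    # Hedged sketch: assumes the gated/unreleased checkpoint above is accessible.
    llm = LLM(
        model="mistralai/Ministral-3-3B-Instruct-2512",
        tokenizer_mode="mistral",   # consolidated Mistral tokenizer files
        config_format="mistral",    # read params.json instead of config.json
        load_format="mistral",      # load consolidated safetensors
    )
    outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)
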
diff --git a/tests/models/registry.py b/tests/models/registry.py
index d90f3a4d4f781..26351089fc464 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -358,6 +358,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True,
),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
+ "MistralLarge3ForCausalLM": _HfExamplesInfo(
+ "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", is_available_online=False
+ ),
"MixtralForCausalLM": _HfExamplesInfo(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
{"tiny": "TitanML/tiny-mixtral"},
@@ -770,7 +773,13 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"PixtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Pixtral-12B-2409",
+ extras={
+ "mistral-large-3": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4",
+ "ministral-3": "mistralai/Ministral-3-3B-Instruct-2512",
+ },
tokenizer_mode="mistral",
+ # TODO: revert once Mistral-Large-3 and Ministral-3 are publicly available.
+ is_available_online=False,
),
"QwenVLForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen-VL",
@@ -870,6 +879,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
use_original_num_layers=True,
max_model_len=10240,
),
+ "EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
+ "mistralai/Mistral-Large-3-675B-Instruct-2512",
+ speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
+ is_available_online=False,
+ ),
"LlamaForCausalLMEagle3": _HfExamplesInfo(
"Qwen/Qwen3-8B",
trust_remote_code=True,
diff --git a/tests/tokenizers_/test_mistral.py b/tests/tokenizers_/test_mistral.py
index 92efac86dff29..faff611502652 100644
--- a/tests/tokenizers_/test_mistral.py
+++ b/tests/tokenizers_/test_mistral.py
@@ -91,6 +91,118 @@ from vllm.tokenizers.mistral import (
],
),
),
+ (
+ {
+ "messages": [
+ {
+ "role": "user",
+ "content": "What is the current local date and time?",
+ }
+ ],
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "description": "Fetch the current local date and time.",
+ "unsupported_field": False,
+ "name": "get_current_time",
+ "parameters": {},
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "description": "Fetch the current local date and time.",
+ "unsupported_field2": False,
+ "name": "get_current_time",
+ "parameters": {},
+ },
+ },
+ ],
+ },
+ (
+ [
+ {
+ "role": "user",
+ "content": "What is the current local date and time?",
+ }
+ ],
+ [
+ {
+ "type": "function",
+ "function": {
+ "description": "Fetch the current local date and time.",
+ "name": "get_current_time",
+ "parameters": {},
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "description": "Fetch the current local date and time.",
+ "name": "get_current_time",
+ "parameters": {},
+ },
+ },
+ ],
+ ),
+ ),
+ (
+ {
+ "messages": [
+ {
+ "role": "user",
+ "content": "What is the current local date and time?",
+ }
+ ],
+ "tools": [
+ {
+ "type": "function",
+ "unsupported_field": False,
+ "function": {
+ "description": "Fetch the current local date and time.",
+ "name": "get_current_time",
+ "parameters": {},
+ },
+ },
+ {
+ "type": "function",
+ "unsupported_field2": False,
+ "function": {
+ "description": "Fetch the current local date and time 2.",
+ "name": "get_current_time2",
+ "parameters": {"a": "1"},
+ },
+ },
+ ],
+ },
+ (
+ [
+ {
+ "role": "user",
+ "content": "What is the current local date and time?",
+ }
+ ],
+ [
+ {
+ "type": "function",
+ "function": {
+ "description": "Fetch the current local date and time.",
+ "name": "get_current_time",
+ "parameters": {},
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "description": "Fetch the current local date and time 2.",
+ "name": "get_current_time2",
+ "parameters": {"a": "1"},
+ },
+ },
+ ],
+ ),
+ ),
],
)
def test_prepare_apply_chat_template_tools_and_messages(
@@ -1108,13 +1220,6 @@ class TestMistralTokenizer:
)
== expected_tokens[mistral_tokenizer.is_tekken]
)
- assert (
- mistral_tokenizer.decode(
- ids[mistral_tokenizer.is_tekken],
- skip_special_tokens=skip_special_tokens,
- )
- == expected_tokens[mistral_tokenizer.is_tekken]
- )
def test_decode_empty(
self,
@@ -1140,6 +1245,45 @@ class TestMistralTokenizer:
== ""
)
+ @pytest.mark.parametrize(
+ "skip_special_tokens,expected_tokens",
+ (
+ (
+ False,
+ (
+ ["[INST]▁Hello▁world▁![/INST]▁Hello"],
+ ["[INST]Hello world ![/INST]Hello"],
+ ),
+ ),
+ (True, (["Hello world ! Hello"], ["Hello world !Hello"])),
+ ),
+ )
+ def test_batch_decode(
+ self,
+ mistral_tokenizer: MistralTokenizer,
+ skip_special_tokens: bool,
+ expected_tokens: tuple[str, str],
+ ):
+ ids = (
+ [[1, 3, 23325, 2294, 1686, 4, 23325, 2]],
+ [[1, 3, 22177, 4304, 2662, 4, 22177, 2]],
+ )
+ assert (
+ mistral_tokenizer.batch_decode(
+ ids[mistral_tokenizer.is_tekken],
+ skip_special_tokens=skip_special_tokens,
+ )
+ == expected_tokens[mistral_tokenizer.is_tekken]
+ )
+
+ def test_batch_decode_empty(
+ self,
+ mistral_tokenizer: MistralTokenizer,
+ ):
+ assert mistral_tokenizer.batch_decode(
+ [[]],
+ ) == [""]
+
def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
tokens = (
[
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 80d53a543f149..c6d6f705f535c 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -167,6 +167,7 @@ class SpeculativeConfig:
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
+ initial_architecture = hf_config.architectures[0]
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
@@ -226,6 +227,9 @@ class SpeculativeConfig:
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
)
+ if initial_architecture == "MistralLarge3ForCausalLM":
+ hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
+
return hf_config
def __post_init__(self):
diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index 7e2d67a1fb659..89b882d6c8475 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -80,7 +80,7 @@ class MistralToolParser(ToolParser):
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
self.fn_name_regex = re.compile(
- r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)", re.DOTALL
+ r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)?", re.DOTALL
)
else:
self.fn_name_regex = None
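
To illustrate what the relaxed `fn_name_regex` extracts, here is a minimal sketch; the tool-call text is made up, and the stdlib `re` module behaves the same as the `regex` package for this pattern.

    import re

    fn_name_regex = re.compile(r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)?", re.DOTALL)

    raw = 'get_current_time{}, get_weather{"city": "Paris"}'
    print(fn_name_regex.findall(raw))
    # [('get_current_time', '{}'), ('get_weather', '{"city": "Paris"}')]
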
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
new file mode 100644
index 0000000000000..a9f24c20a25a2
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 2
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 2
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 2
+ },
+ "512": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+}
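
The keys in this tuned-kernel file are token batch sizes (M) and the values are Triton tile parameters. As a rough sketch, vLLM's fused-MoE path selects the entry whose key is closest to the actual M; the helper below is illustrative, not the real loader.

    import json

    def pick_config(configs: dict[str, dict], num_tokens: int) -> dict:
        # Choose the tuned entry whose batch-size key is closest to the actual M.
        best = min((int(k) for k in configs), key=lambda k: abs(k - num_tokens))
        return configs[str(best)]

    with open("E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json") as f:
        configs = json.load(f)
    print(pick_config(configs, 70))  # closest key is "64"
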
diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index 6ebfa47a9dc3f..dad960160f2ad 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -111,6 +111,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
+ llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
q_c = None
kv_lora = None
@@ -159,6 +160,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
hidden_states, q_c, positions, self.indexer_rope_emb
)
+ if llama_4_scaling is not None:
+ q *= llama_4_scaling
+
attn_out = self.mla_attn(
q,
kv_c_normed,
diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py
index 0f10bff6ac4f5..aa6ece30026d3 100644
--- a/vllm/model_executor/layers/rotary_embedding/__init__.py
+++ b/vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -238,7 +238,7 @@ def get_rope(
dtype,
**extra_kwargs,
)
- elif scaling_type == "deepseek_yarn":
+ elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 73cac2556c55a..d8a081af125c5 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0
+def _get_llama_4_scaling(
+ original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
+) -> torch.Tensor:
+ scaling = 1 + scaling_beta * torch.log(
+ 1 + torch.floor(positions / original_max_position_embeddings)
+ )
+ # Broadcast over num_heads and head_dim
+ return scaling[..., None, None]
+
+
class DeepseekV2Attention(nn.Module):
def __init__(
self,
@@ -481,7 +491,11 @@ class DeepseekV2Attention(nn.Module):
prefix=f"{prefix}.o_proj",
)
if config.rope_parameters["rope_type"] != "default":
- config.rope_parameters["rope_type"] = "deepseek_yarn"
+ config.rope_parameters["rope_type"] = (
+ "deepseek_yarn"
+ if config.rope_parameters.get("apply_yarn_scaling", True)
+ else "deepseek_llama_scaling"
+ )
self.rotary_emb = get_rope(
qk_rope_head_dim,
@@ -491,7 +505,10 @@ class DeepseekV2Attention(nn.Module):
is_neox_style=False,
)
- if config.rope_parameters["rope_type"] != "default":
+ if (
+ config.rope_parameters["rope_type"] != "default"
+ and config.rope_parameters["rope_type"] == "deepseek_yarn"
+ ):
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -511,6 +528,7 @@ class DeepseekV2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
+ llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
@@ -536,6 +554,11 @@ class DeepseekV2Attention(nn.Module):
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
+
+ # Apply llama 4 scaling if provided
+ if llama_4_scaling is not None:
+ q *= llama_4_scaling
+
# padding value to qk_head_dim for alignment
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim], value=0
@@ -987,7 +1010,12 @@ class DeepseekV2MLAAttention(nn.Module):
)
if config.rope_parameters["rope_type"] != "default":
- config.rope_parameters["rope_type"] = "deepseek_yarn"
+ config.rope_parameters["rope_type"] = (
+ "deepseek_yarn"
+ if config.rope_parameters.get("apply_yarn_scaling", True)
+ else "deepseek_llama_scaling"
+ )
+
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
@@ -995,7 +1023,11 @@ class DeepseekV2MLAAttention(nn.Module):
rope_parameters=config.rope_parameters,
is_neox_style=False,
)
- if config.rope_parameters["rope_type"] != "default":
+
+ if (
+ config.rope_parameters["rope_type"] != "default"
+ and config.rope_parameters["rope_type"] == "deepseek_yarn"
+ ):
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -1064,8 +1096,9 @@ class DeepseekV2MLAAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
+ llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
- return self.mla_attn(positions, hidden_states)
+ return self.mla_attn(positions, hidden_states, llama_4_scaling)
class DeepseekV2DecoderLayer(nn.Module):
@@ -1155,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
+ llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
# Self Attention
if residual is None:
@@ -1165,6 +1199,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
+ llama_4_scaling=llama_4_scaling,
)
if (
@@ -1266,8 +1301,24 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
+ # Compute llama 4 scaling once per forward pass if enabled
+ llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
+ llama_4_scaling: torch.Tensor | None
+ if llama_4_scaling_config is not None:
+ llama_4_scaling = _get_llama_4_scaling(
+ original_max_position_embeddings=llama_4_scaling_config[
+ "original_max_position_embeddings"
+ ],
+ scaling_beta=llama_4_scaling_config["beta"],
+ positions=positions,
+ )
+ else:
+ llama_4_scaling = None
+
for layer in islice(self.layers, self.start_layer, self.end_layer):
- hidden_states, residual = layer(positions, hidden_states, residual)
+ hidden_states, residual = layer(
+ positions, hidden_states, residual, llama_4_scaling
+ )
if not get_pp_group().is_last_rank:
return IntermediateTensors(
@@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM(
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
}
+ model_cls = DeepseekV2Model
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -1355,7 +1407,7 @@ class DeepseekV2ForCausalLM(
"kv_a_proj_with_mqa",
]
- self.model = DeepseekV2Model(
+ self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
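
The query scaling threaded through the DeepSeek layers above grows logarithmically with how many original context windows a position has crossed. A quick numeric sketch of the `_get_llama_4_scaling` formula, using placeholder values for `original_max_position_embeddings` and `beta` (the real values come from the checkpoint's `llama_4_scaling` config):

    import torch

    def llama_4_scaling(original_max_position_embeddings: int, beta: float,
                        positions: torch.Tensor) -> torch.Tensor:
        # Same formula as _get_llama_4_scaling above, without the broadcast dims.
        return 1 + beta * torch.log(
            1 + torch.floor(positions / original_max_position_embeddings)
        )

    positions = torch.tensor([0, 4095, 4096, 16384])
    print(llama_4_scaling(4096, 0.1, positions))
    # tensor([1.0000, 1.0000, 1.0693, 1.1609])
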
diff --git a/vllm/model_executor/models/mistral_large_3.py b/vllm/model_executor/models/mistral_large_3.py
new file mode 100644
index 0000000000000..ff7e9b60c1d3c
--- /dev/null
+++ b/vllm/model_executor/models/mistral_large_3.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable
+
+import regex as re
+import torch
+
+from vllm.model_executor.models.deepseek_v2 import DeepseekV3ForCausalLM
+
+
+class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM):
+ # fmt: off
+ remapping = {
+ r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight", # noqa: E501
+ r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight", # noqa: E501
+ r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2", # noqa: E501
+ r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight", # noqa: E501
+ r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight", # noqa: E501
+ r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight", # noqa: E501
+ r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2", # noqa: E501
+ r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3", # noqa: E501
+ r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3", # noqa: E501
+ r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3", # noqa: E501
+ r"norm\.weight": "model.norm.weight", # noqa: E501
+ r"tok_embeddings\.weight": "model.embed_tokens.weight", # noqa: E501
+ r"output\.weight": "lm_head.weight", # noqa: E501
+ }
+ # fmt: on
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ return super().load_weights(map(self._remap_mistral_to_ds, weights))
+
+ def _remap_mistral_to_ds(
+ self, weight: tuple[str, torch.Tensor]
+ ) -> tuple[str, torch.Tensor]:
+ """Remap Mistral parameters to DeepseekV2 parameters."""
+ name, loaded_weight = weight
+
+ for k, v in self.remapping.items():
+ match = re.fullmatch(k, name)
+ if match:
+ name = re.sub(k, v, name)
+ break
+ else:
+ raise ValueError(f"Cannot remap {name}")
+
+ # Remapping scale names. We could do this in the regex above but it
+ # would triple the number of lines for most layers.
+ if name.endswith(".qscale_act"):
+ name = re.sub(r"\.qscale_act$", ".input_scale", name)
+ elif name.endswith(".qscale_weight"):
+ name = re.sub(r"\.qscale_weight$", ".weight_scale", name)
+
+ return name, loaded_weight
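
A small self-contained sketch of how the remapping table is applied to a Mistral-format parameter name; only two of the rules above are reproduced, and the layer index is made up.

    import re

    remapping = {
        r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2",
        r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2",
    }

    def remap(name: str) -> str:
        # Mirrors _remap_mistral_to_ds, minus the scale-name post-processing.
        for pattern, repl in remapping.items():
            if re.fullmatch(pattern, name):
                return re.sub(pattern, repl, name)
        raise ValueError(f"Cannot remap {name}")

    print(remap("layers.12.attention.wq_a.weight"))
    # model.layers.12.self_attn.q_a_proj.weight
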
diff --git a/vllm/model_executor/models/mistral_large_3_eagle.py b/vllm/model_executor/models/mistral_large_3_eagle.py
new file mode 100644
index 0000000000000..e3ca9e4ca82d0
--- /dev/null
+++ b/vllm/model_executor/models/mistral_large_3_eagle.py
@@ -0,0 +1,165 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Iterable
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import RowParallelLinear
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.models.deepseek_v2 import (
+ DeepseekV2DecoderLayer,
+ DeepseekV2Model,
+)
+from vllm.model_executor.models.interfaces import MultiModalEmbeddings
+from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
+from vllm.multimodal.inputs import NestedTensors
+
+from .utils import (
+ _merge_multimodal_embeddings,
+ make_empty_intermediate_tensors_factory,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+
+@support_torch_compile
+class EagleMistralLarge3Model(DeepseekV2Model):
+ def __init__(
+ self, *, vllm_config: VllmConfig, prefix: str = "", start_layer_id: int = 0
+ ):
+ nn.Module.__init__(self)
+
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.vllm_config = vllm_config
+
+ self.vocab_size = config.vocab_size
+
+ assert get_pp_group().world_size == 1
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens",
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ DeepseekV2DecoderLayer(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+ )
+ for i in range(self.config.num_hidden_layers)
+ ]
+ )
+ self.start_layer = 0
+ self.end_layer = self.config.num_hidden_layers
+
+ self.fc = RowParallelLinear(
+ self.config.hidden_size * 2,
+ self.config.hidden_size,
+ bias=False,
+ input_is_parallel=False,
+ quant_config=quant_config,
+ return_bias=False,
+ )
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_input_ids(input_ids)
+ inputs_embeds = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
+ output = super().forward(
+ input_ids, positions, intermediate_tensors=None, inputs_embeds=inputs_embeds
+ )
+ assert isinstance(output, torch.Tensor)
+ return output
+
+
+class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
+ remapping = MistralLarge3ForCausalLM.remapping | {
+ r"eagle_linear\.weight": r"model.fc.weight",
+ r"eagle_linear\.qscale_act": r"model.fc.input_scale",
+ r"eagle_linear\.qscale_weight": r"model.fc.weight_scale",
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ target_layer_num = vllm_config.model_config.get_num_layers(
+ vllm_config.parallel_config
+ )
+ vllm_config.model_config = vllm_config.speculative_config.draft_model_config
+ # draft model quantization config may differ from target model
+ self.quant_config = VllmConfig.get_quantization_config(
+ vllm_config.speculative_config.draft_model_config, vllm_config.load_config
+ )
+ vllm_config.quant_config = self.quant_config
+ self.model_cls = partial(
+ EagleMistralLarge3Model, start_layer_id=target_layer_num
+ )
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: MultiModalEmbeddings | None = None,
+ *,
+ is_multimodal: torch.Tensor | None = None,
+ handle_oov_mm_token: bool = False,
+ ) -> torch.Tensor:
+ inputs_embeds = super().embed_input_ids(input_ids)
+
+ if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+ return inputs_embeds
+
+ assert is_multimodal is not None
+
+ return _merge_multimodal_embeddings(
+ inputs_embeds=inputs_embeds,
+ multimodal_embeddings=multimodal_embeddings,
+ is_multimodal=is_multimodal,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ hidden_states = self.model(input_ids, positions, hidden_states, inputs_embeds)
+ return hidden_states, hidden_states
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ # Pretend we've loaded the embedding and lm_head weights
+ # (later copied from target model)
+ return super().load_weights(weights) | {
+ "model.embed_tokens.weight",
+ "lm_head.weight",
+ }
+
+ def embed_input_ids(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: NestedTensors | None = None,
+ is_multimodal: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ return self.model.embed_input_ids(input_ids)
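
A hedged end-to-end sketch of enabling EAGLE speculative decoding with this draft model, using the (currently unreleased) repo names from the test registry above; the parallelism and speculative-token counts are placeholders, not recommendations.

    from vllm import LLM

    llm = LLM(
        model="mistralai/Mistral-Large-3-675B-Instruct-2512",
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
        tensor_parallel_size=8,  # placeholder; size the cluster for a 675B MoE
        speculative_config={
            "method": "eagle",
            "model": "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
            "num_speculative_tokens": 3,
        },
    )
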
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 73a61f1148b50..d3b6268e7647b 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -145,6 +145,7 @@ _TEXT_GENERATION_MODELS = {
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
+ "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
@@ -424,6 +425,10 @@ _SPECULATIVE_DECODING_MODELS = {
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+ "EagleMistralLarge3ForCausalLM": (
+ "mistral_large_3_eagle",
+ "EagleMistralLarge3ForCausalLM",
+ ),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py
index 7e6745004b01f..96d1e78ce9f17 100644
--- a/vllm/tokenizers/mistral.py
+++ b/vllm/tokenizers/mistral.py
@@ -97,6 +97,8 @@ def _prepare_apply_chat_template_tools_and_messages(
continue_final_message: bool = False,
add_generation_prompt: bool = False,
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
+ from mistral_common.protocol.instruct.tool_calls import Function, Tool
+
if add_generation_prompt and continue_final_message:
raise ValueError(
"Cannot set both `add_generation_prompt` and "
@@ -139,6 +141,33 @@ def _prepare_apply_chat_template_tools_and_messages(
if function.get("description") is None:
function["description"] = ""
+    # We filter out unsupported arguments to avoid raising an error.
+    # TODO(juliendenize): remove this once the OpenAI API is better supported
+    # by `mistral-common`.
+ tools_fields = set(Tool.model_fields.keys())
+ function_fields = set(Function.model_fields.keys())
+    for tool in tools or []:
+        tool_keys = list(tool.keys())
+        for tool_key in tool_keys:
+            if tool_key not in tools_fields:
+                tool.pop(tool_key)
+                logger.warning_once(
+                    f"'{tool_key}' is not supported by mistral-common for tools. "
+                    "It has been removed from the tool definition."
+                )
+        if tool["type"] == "function":
+            function_keys = list(tool["function"].keys())
+            for function_key in function_keys:
+                if function_key not in function_fields:
+                    tool["function"].pop(function_key)
+                    logger.warning_once(
+                        f"'{function_key}' is not supported by mistral-common "
+                        "for function tools. It has been removed from the "
+                        "function definition."
+                    )
+        else:
+            raise ValueError("mistral-common only supports function tools.")
+
return messages, tools
@@ -410,6 +439,13 @@ class MistralTokenizer(TokenizerLike):
ids, skip_special_tokens=skip_special_tokens
)
+ def batch_decode(
+ self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False
+    ) -> list[str]:
+ return self.transformers_tokenizer.batch_decode(
+ ids, skip_special_tokens=skip_special_tokens
+ )
+
def convert_tokens_to_string(self, tokens: list[str]) -> str:
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index f5dc9ddfbc575..ce428e567c844 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -82,3 +82,9 @@ class EAGLEConfig(PretrainedConfig):
pretrained_model_name_or_path, **kwargs
)
return cls.from_dict(config_dict, **kwargs)
+
+ def to_json_string(self, use_diff: bool = True) -> str:
+        # Always serialize the full config (use_diff=False) because an
+        # EAGLEConfig cannot be initialized with default arguments.
+ del use_diff
+ return super().to_json_string(use_diff=False)
diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py
index 966737aad0867..d59169d95f0c9 100644
--- a/vllm/transformers_utils/configs/mistral.py
+++ b/vllm/transformers_utils/configs/mistral.py
@@ -18,9 +18,31 @@ def adapt_config_dict(
if bool(config_dict.get("quantization")):
config_dict = _remap_mistral_quantization_args(config_dict)
+ is_moe = bool(config_dict.get("moe"))
+ is_mistral_large_3 = (
+ is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
+ )
if config_dict.get("model_type") == "mamba":
config_dict["architectures"] = ["Mamba2ForCausalLM"]
- elif bool(config_dict.get("moe")):
+    elif is_mistral_large_3:
+ config_dict = _remap_moe_args(config_dict)
+ config_dict["model_type"] = "deepseek_v3"
+ config_dict["architectures"] = ["MistralLarge3ForCausalLM"]
+
+ assert "llama_4_scaling" in config_dict, (
+ "MistralLarge3 expect llama4 scaling config."
+ )
+ llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
+ assert all(
+ [
+ key in config_dict["llama_4_scaling"]
+ for key in llama_4_scaling_config_keys
+ ]
+ ), (
+ "llama_4_scaling config should define the keys: "
+ f"{','.join(llama_4_scaling_config_keys)}"
+ )
+ elif is_moe:
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
@@ -140,17 +162,20 @@ def _remap_general_mistral_args(config: dict) -> dict:
def _remap_mistral_quantization_args(config: dict) -> dict:
- quantization = config.get("quantization", {})
- if quantization.get("qformat_weight") == "fp8_e4m3":
- # This maps to the FP8 static per-tensor quantization scheme
- quantization_config = {"quant_method": "fp8", "activation_scheme": "static"}
- elif quantization.get("quant_method") == "compressed-tensors":
- # Pass through the quantization config to compressed-tensors
- quantization_config = quantization
- else:
- raise ValueError(f"Found unknown quantization='{quantization}' in config")
-
- config["quantization_config"] = quantization_config
+ if config.get("quantization"):
+ quantization = config.pop("quantization", {})
+ if quantization.get("qformat_weight") == "fp8_e4m3":
+ qscheme_act = quantization.get("qscheme_act")
+ assert qscheme_act in ("NO_SCALES", "TENSOR", None), (
+ "Only NO_SCALES and TENSOR (default) are supported for qscheme_act"
+ )
+ is_dynamic = qscheme_act == "NO_SCALES"
+ config["quantization_config"] = {
+ "quant_method": "fp8",
+ "activation_scheme": "dynamic" if is_dynamic else "static",
+ }
+ else:
+ raise ValueError(f"Found unknown quantization='{quantization}' in config")
return config
@@ -183,3 +208,28 @@ def _remap_mistral_audio_args(config: dict) -> dict:
if quant_config:
config["quantization_config"] = quant_config
return config
+
+
+def _remap_moe_args(config: dict) -> dict:
+ moe_config_map = {
+ "route_every_n": "moe_layer_freq",
+ "first_k_dense_replace": "first_k_dense_replace",
+ "num_experts_per_tok": "num_experts_per_tok",
+ "num_experts": "n_routed_experts",
+ "expert_hidden_dim": "moe_intermediate_size",
+ "routed_scale": "routed_scaling_factor",
+ "num_shared_experts": "n_shared_experts",
+ "num_expert_groups": "n_group",
+ "num_expert_groups_per_tok": "topk_group",
+ }
+ moe_config = config.get("moe", {})
+ for old_name, new_name in moe_config_map.items():
+ if old_name in moe_config:
+ value = moe_config.pop(old_name)
+ config[new_name] = value
+
+ config["topk_method"] = None
+ config["norm_topk_prob"] = True
+ config["scoring_func"] = "softmax"
+
+ return config
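
For illustration, a self-contained miniature of what `_remap_moe_args` does with a subset of the key map; the numbers are placeholders, not the real Mistral-Large-3 values.

    moe_config_map = {
        "route_every_n": "moe_layer_freq",
        "num_experts": "n_routed_experts",
        "expert_hidden_dim": "moe_intermediate_size",
        "num_shared_experts": "n_shared_experts",
    }

    config = {"moe": {"num_experts": 256, "expert_hidden_dim": 2048,
                      "num_shared_experts": 1, "route_every_n": 1}}
    moe = config["moe"]
    for old_name, new_name in moe_config_map.items():
        if old_name in moe:
            config[new_name] = moe.pop(old_name)
    print({k: v for k, v in config.items() if k != "moe"})
    # {'moe_layer_freq': 1, 'n_routed_experts': 256,
    #  'moe_intermediate_size': 2048, 'n_shared_experts': 1}
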
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index d7111d52dd8a1..1c7845a14b742 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1016,6 +1016,10 @@ class EagleProposer:
"Qwen3VLForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
+ elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
+ self.model.config.image_token_index = (
+ target_model.config.vision_config.image_token_id
+ )
else:
self.model.config.image_token_index = (
target_model.config.image_token_index