mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 06:35:41 +08:00
Add Mistral Large 3 and Ministral 3 (#29757)
Signed-off-by: Julien Denize <julien.denize@mistral.ai> Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Signed-off-by: Mickael Seznec <mickael@mistral.ai> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Mickael Seznec <mickael@mistral.ai>
This commit is contained in:
parent
8bbcf8b6e7
commit
d8c6210eea
@ -417,7 +417,8 @@ th {
|
|||||||
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
|
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
|
||||||
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
|
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
|
||||||
| `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
|
| `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
|
||||||
| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
|
| `MistralForCausalLM` | Ministral-3, Mistral, Mistral-Instruct | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
|
||||||
|
| `MistralLarge3ForCausalLM` | Mistral-Large-3-675B-Base-2512, Mistral-Large-3-675B-Instruct-2512 | `mistralai/Mistral-Large-3-675B-Base-2512`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, etc. | ✅︎ | ✅︎ |
|
||||||
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ |
|
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ |
|
||||||
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ |
|
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ |
|
||||||
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ |
|
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ |
|
||||||
@ -711,7 +712,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
|||||||
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
|
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
|
||||||
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
|
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
|
||||||
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
|
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
|
||||||
| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ |
|
| `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512` `mistralai/Pixtral-12B-2409` etc. | | ✅︎ |
|
||||||
| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ |
|
| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ |
|
||||||
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ |
|
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ |
|
||||||
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ |
|
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@ -358,6 +358,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
),
|
),
|
||||||
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
|
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
|
||||||
|
"MistralLarge3ForCausalLM": _HfExamplesInfo(
|
||||||
|
"mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", is_available_online=False
|
||||||
|
),
|
||||||
"MixtralForCausalLM": _HfExamplesInfo(
|
"MixtralForCausalLM": _HfExamplesInfo(
|
||||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
{"tiny": "TitanML/tiny-mixtral"},
|
{"tiny": "TitanML/tiny-mixtral"},
|
||||||
@ -770,7 +773,13 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
),
|
),
|
||||||
"PixtralForConditionalGeneration": _HfExamplesInfo(
|
"PixtralForConditionalGeneration": _HfExamplesInfo(
|
||||||
"mistralai/Pixtral-12B-2409",
|
"mistralai/Pixtral-12B-2409",
|
||||||
|
extras={
|
||||||
|
"mistral-large-3": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4",
|
||||||
|
"ministral-3": "mistralai/Ministral-3-3B-Instruct-2512",
|
||||||
|
},
|
||||||
tokenizer_mode="mistral",
|
tokenizer_mode="mistral",
|
||||||
|
# TODO: revert once Mistral-Large-3 and Ministral-3 are publicly available.
|
||||||
|
is_available_online=False,
|
||||||
),
|
),
|
||||||
"QwenVLForConditionalGeneration": _HfExamplesInfo(
|
"QwenVLForConditionalGeneration": _HfExamplesInfo(
|
||||||
"Qwen/Qwen-VL",
|
"Qwen/Qwen-VL",
|
||||||
@ -870,6 +879,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
|||||||
use_original_num_layers=True,
|
use_original_num_layers=True,
|
||||||
max_model_len=10240,
|
max_model_len=10240,
|
||||||
),
|
),
|
||||||
|
"EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
|
||||||
|
"mistralai/Mistral-Large-3-675B-Instruct-2512",
|
||||||
|
speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
|
||||||
|
is_available_online=False,
|
||||||
|
),
|
||||||
"LlamaForCausalLMEagle3": _HfExamplesInfo(
|
"LlamaForCausalLMEagle3": _HfExamplesInfo(
|
||||||
"Qwen/Qwen3-8B",
|
"Qwen/Qwen3-8B",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
|
|||||||
@ -91,6 +91,118 @@ from vllm.tokenizers.mistral import (
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the current local date and time?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"description": "Fetch the current local date and time.",
|
||||||
|
"unsupported_field": False,
|
||||||
|
"name": "get_current_time",
|
||||||
|
"parameters": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"description": "Fetch the current local date and time.",
|
||||||
|
"unsupported_field2": False,
|
||||||
|
"name": "get_current_time",
|
||||||
|
"parameters": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the current local date and time?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"description": "Fetch the current local date and time.",
|
||||||
|
"name": "get_current_time",
|
||||||
|
"parameters": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"description": "Fetch the current local date and time.",
|
||||||
|
"name": "get_current_time",
|
||||||
|
"parameters": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the current local date and time?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"unsupported_field": False,
|
||||||
|
"function": {
|
||||||
|
"description": "Fetch the current local date and time.",
|
||||||
|
"name": "get_current_time",
|
||||||
|
"parameters": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"unsupported_field2": False,
|
||||||
|
"function": {
|
||||||
|
"description": "Fetch the current local date and time 2.",
|
||||||
|
"name": "get_current_time2",
|
||||||
|
"parameters": {"a": "1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the current local date and time?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"description": "Fetch the current local date and time.",
|
||||||
|
"name": "get_current_time",
|
||||||
|
"parameters": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"description": "Fetch the current local date and time 2.",
|
||||||
|
"name": "get_current_time2",
|
||||||
|
"parameters": {"a": "1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_prepare_apply_chat_template_tools_and_messages(
|
def test_prepare_apply_chat_template_tools_and_messages(
|
||||||
@ -1108,13 +1220,6 @@ class TestMistralTokenizer:
|
|||||||
)
|
)
|
||||||
== expected_tokens[mistral_tokenizer.is_tekken]
|
== expected_tokens[mistral_tokenizer.is_tekken]
|
||||||
)
|
)
|
||||||
assert (
|
|
||||||
mistral_tokenizer.decode(
|
|
||||||
ids[mistral_tokenizer.is_tekken],
|
|
||||||
skip_special_tokens=skip_special_tokens,
|
|
||||||
)
|
|
||||||
== expected_tokens[mistral_tokenizer.is_tekken]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_decode_empty(
|
def test_decode_empty(
|
||||||
self,
|
self,
|
||||||
@ -1140,6 +1245,45 @@ class TestMistralTokenizer:
|
|||||||
== "<s>"
|
== "<s>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"skip_special_tokens,expected_tokens",
|
||||||
|
(
|
||||||
|
(
|
||||||
|
False,
|
||||||
|
(
|
||||||
|
["<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>"],
|
||||||
|
["<s>[INST]Hello world ![/INST]Hello</s>"],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(True, (["Hello world ! Hello"], ["Hello world !Hello"])),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_batch_decode(
|
||||||
|
self,
|
||||||
|
mistral_tokenizer: MistralTokenizer,
|
||||||
|
skip_special_tokens: bool,
|
||||||
|
expected_tokens: tuple[str, str],
|
||||||
|
):
|
||||||
|
ids = (
|
||||||
|
[[1, 3, 23325, 2294, 1686, 4, 23325, 2]],
|
||||||
|
[[1, 3, 22177, 4304, 2662, 4, 22177, 2]],
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
mistral_tokenizer.batch_decode(
|
||||||
|
ids[mistral_tokenizer.is_tekken],
|
||||||
|
skip_special_tokens=skip_special_tokens,
|
||||||
|
)
|
||||||
|
== expected_tokens[mistral_tokenizer.is_tekken]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_batch_decode_empty(
|
||||||
|
self,
|
||||||
|
mistral_tokenizer: MistralTokenizer,
|
||||||
|
):
|
||||||
|
assert mistral_tokenizer.batch_decode(
|
||||||
|
[[]],
|
||||||
|
) == [""]
|
||||||
|
|
||||||
def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
|
def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
|
||||||
tokens = (
|
tokens = (
|
||||||
[
|
[
|
||||||
|
|||||||
@ -167,6 +167,7 @@ class SpeculativeConfig:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||||
|
initial_architecture = hf_config.architectures[0]
|
||||||
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
|
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
|
||||||
hf_config.model_type = "deepseek_mtp"
|
hf_config.model_type = "deepseek_mtp"
|
||||||
if hf_config.model_type == "deepseek_mtp":
|
if hf_config.model_type == "deepseek_mtp":
|
||||||
@ -226,6 +227,9 @@ class SpeculativeConfig:
|
|||||||
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
|
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if initial_architecture == "MistralLarge3ForCausalLM":
|
||||||
|
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
|
||||||
|
|
||||||
return hf_config
|
return hf_config
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|||||||
@ -80,7 +80,7 @@ class MistralToolParser(ToolParser):
|
|||||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||||
if _is_fn_name_regex_support(self.model_tokenizer):
|
if _is_fn_name_regex_support(self.model_tokenizer):
|
||||||
self.fn_name_regex = re.compile(
|
self.fn_name_regex = re.compile(
|
||||||
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)", re.DOTALL
|
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)?", re.DOTALL
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.fn_name_regex = None
|
self.fn_name_regex = None
|
||||||
|
|||||||
@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 2
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -111,6 +111,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
|
|||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
llama_4_scaling: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
q_c = None
|
q_c = None
|
||||||
kv_lora = None
|
kv_lora = None
|
||||||
@ -159,6 +160,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
|
|||||||
hidden_states, q_c, positions, self.indexer_rope_emb
|
hidden_states, q_c, positions, self.indexer_rope_emb
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if llama_4_scaling is not None:
|
||||||
|
q *= llama_4_scaling
|
||||||
|
|
||||||
attn_out = self.mla_attn(
|
attn_out = self.mla_attn(
|
||||||
q,
|
q,
|
||||||
kv_c_normed,
|
kv_c_normed,
|
||||||
|
|||||||
@ -238,7 +238,7 @@ def get_rope(
|
|||||||
dtype,
|
dtype,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
elif scaling_type == "deepseek_yarn":
|
elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
|
||||||
scaling_factor = rope_parameters["factor"]
|
scaling_factor = rope_parameters["factor"]
|
||||||
original_max_position = rope_parameters["original_max_position_embeddings"]
|
original_max_position = rope_parameters["original_max_position_embeddings"]
|
||||||
# assert max_position == original_max_position * scaling_factor
|
# assert max_position == original_max_position * scaling_factor
|
||||||
|
|||||||
@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
|||||||
return 0.1 * mscale * math.log(scale) + 1.0
|
return 0.1 * mscale * math.log(scale) + 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def _get_llama_4_scaling(
|
||||||
|
original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
scaling = 1 + scaling_beta * torch.log(
|
||||||
|
1 + torch.floor(positions / original_max_position_embeddings)
|
||||||
|
)
|
||||||
|
# Broadcast over num_heads and head_dim
|
||||||
|
return scaling[..., None, None]
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2Attention(nn.Module):
|
class DeepseekV2Attention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -481,7 +491,11 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
)
|
)
|
||||||
if config.rope_parameters["rope_type"] != "default":
|
if config.rope_parameters["rope_type"] != "default":
|
||||||
config.rope_parameters["rope_type"] = "deepseek_yarn"
|
config.rope_parameters["rope_type"] = (
|
||||||
|
"deepseek_yarn"
|
||||||
|
if config.rope_parameters.get("apply_yarn_scaling", True)
|
||||||
|
else "deepseek_llama_scaling"
|
||||||
|
)
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
qk_rope_head_dim,
|
qk_rope_head_dim,
|
||||||
@ -491,7 +505,10 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
is_neox_style=False,
|
is_neox_style=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.rope_parameters["rope_type"] != "default":
|
if (
|
||||||
|
config.rope_parameters["rope_type"] != "default"
|
||||||
|
and config.rope_parameters["rope_type"] == "deepseek_yarn"
|
||||||
|
):
|
||||||
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
|
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
|
||||||
scaling_factor = config.rope_parameters["factor"]
|
scaling_factor = config.rope_parameters["factor"]
|
||||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||||
@ -511,6 +528,7 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
llama_4_scaling: torch.Tensor | None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
q = self.q_a_proj(hidden_states)[0]
|
q = self.q_a_proj(hidden_states)[0]
|
||||||
@ -536,6 +554,11 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
k = torch.empty_like(q)
|
k = torch.empty_like(q)
|
||||||
k[..., : self.qk_nope_head_dim] = k_nope
|
k[..., : self.qk_nope_head_dim] = k_nope
|
||||||
k[..., self.qk_nope_head_dim :] = k_pe
|
k[..., self.qk_nope_head_dim :] = k_pe
|
||||||
|
|
||||||
|
# Apply llama 4 scaling if provided
|
||||||
|
if llama_4_scaling is not None:
|
||||||
|
q *= llama_4_scaling
|
||||||
|
|
||||||
# padding value to qk_head_dim for alignment
|
# padding value to qk_head_dim for alignment
|
||||||
v = torch.nn.functional.pad(
|
v = torch.nn.functional.pad(
|
||||||
v, [0, self.qk_head_dim - self.v_head_dim], value=0
|
v, [0, self.qk_head_dim - self.v_head_dim], value=0
|
||||||
@ -987,7 +1010,12 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.rope_parameters["rope_type"] != "default":
|
if config.rope_parameters["rope_type"] != "default":
|
||||||
config.rope_parameters["rope_type"] = "deepseek_yarn"
|
config.rope_parameters["rope_type"] = (
|
||||||
|
"deepseek_yarn"
|
||||||
|
if config.rope_parameters.get("apply_yarn_scaling", True)
|
||||||
|
else "deepseek_llama_scaling"
|
||||||
|
)
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
qk_rope_head_dim,
|
qk_rope_head_dim,
|
||||||
rotary_dim=qk_rope_head_dim,
|
rotary_dim=qk_rope_head_dim,
|
||||||
@ -995,7 +1023,11 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
rope_parameters=config.rope_parameters,
|
rope_parameters=config.rope_parameters,
|
||||||
is_neox_style=False,
|
is_neox_style=False,
|
||||||
)
|
)
|
||||||
if config.rope_parameters["rope_type"] != "default":
|
|
||||||
|
if (
|
||||||
|
config.rope_parameters["rope_type"] != "default"
|
||||||
|
and config.rope_parameters["rope_type"] == "deepseek_yarn"
|
||||||
|
):
|
||||||
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
|
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
|
||||||
scaling_factor = config.rope_parameters["factor"]
|
scaling_factor = config.rope_parameters["factor"]
|
||||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||||
@ -1064,8 +1096,9 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
llama_4_scaling: torch.Tensor | None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.mla_attn(positions, hidden_states)
|
return self.mla_attn(positions, hidden_states, llama_4_scaling)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2DecoderLayer(nn.Module):
|
class DeepseekV2DecoderLayer(nn.Module):
|
||||||
@ -1155,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: torch.Tensor | None,
|
residual: torch.Tensor | None,
|
||||||
|
llama_4_scaling: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
if residual is None:
|
if residual is None:
|
||||||
@ -1165,6 +1199,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
|
llama_4_scaling=llama_4_scaling,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -1266,8 +1301,24 @@ class DeepseekV2Model(nn.Module):
|
|||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
|
# Compute llama 4 scaling once per forward pass if enabled
|
||||||
|
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
|
||||||
|
llama_4_scaling: torch.Tensor | None
|
||||||
|
if llama_4_scaling_config is not None:
|
||||||
|
llama_4_scaling = _get_llama_4_scaling(
|
||||||
|
original_max_position_embeddings=llama_4_scaling_config[
|
||||||
|
"original_max_position_embeddings"
|
||||||
|
],
|
||||||
|
scaling_beta=llama_4_scaling_config["beta"],
|
||||||
|
positions=positions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
llama_4_scaling = None
|
||||||
|
|
||||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
hidden_states, residual = layer(
|
||||||
|
positions, hidden_states, residual, llama_4_scaling
|
||||||
|
)
|
||||||
|
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors(
|
return IntermediateTensors(
|
||||||
@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM(
|
|||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
}
|
}
|
||||||
|
model_cls = DeepseekV2Model
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -1355,7 +1407,7 @@ class DeepseekV2ForCausalLM(
|
|||||||
"kv_a_proj_with_mqa",
|
"kv_a_proj_with_mqa",
|
||||||
]
|
]
|
||||||
|
|
||||||
self.model = DeepseekV2Model(
|
self.model = self.model_cls(
|
||||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||||
)
|
)
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
|
|||||||
63
vllm/model_executor/models/mistral_large_3.py
Normal file
63
vllm/model_executor/models/mistral_large_3.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.models.deepseek_v2 import DeepseekV3ForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM):
|
||||||
|
# fmt: off
|
||||||
|
remapping = {
|
||||||
|
r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3", # noqa: E501
|
||||||
|
r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3", # noqa: E501
|
||||||
|
r"norm\.weight": "model.norm.weight", # noqa: E501
|
||||||
|
r"tok_embeddings\.weight": "model.embed_tokens.weight", # noqa: E501
|
||||||
|
r"output\.weight": "lm_head.weight", # noqa: E501
|
||||||
|
}
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||||
|
return super().load_weights(map(self._remap_mistral_to_ds, weights))
|
||||||
|
|
||||||
|
def _remap_mistral_to_ds(
|
||||||
|
self, weight: tuple[str, torch.Tensor]
|
||||||
|
) -> tuple[str, torch.Tensor]:
|
||||||
|
"""Remap Mistral parameters to DeepseekV2 parameters."""
|
||||||
|
name, loaded_weight = weight
|
||||||
|
|
||||||
|
for k, v in self.remapping.items():
|
||||||
|
match = re.fullmatch(k, name)
|
||||||
|
if match:
|
||||||
|
name = re.sub(k, v, name)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot remap {name}")
|
||||||
|
|
||||||
|
# Remapping scale names. We could do this in the regex above but it
|
||||||
|
# would triple the number of lines for most layers.
|
||||||
|
if name.endswith(".qscale_act"):
|
||||||
|
name = re.sub(r"\.qscale_act$", ".input_scale", name)
|
||||||
|
elif name.endswith(".qscale_weight"):
|
||||||
|
name = re.sub(r"\.qscale_weight$", ".weight_scale", name)
|
||||||
|
|
||||||
|
return name, loaded_weight
|
||||||
165
vllm/model_executor/models/mistral_large_3_eagle.py
Normal file
165
vllm/model_executor/models/mistral_large_3_eagle.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
from vllm.model_executor.models.deepseek_v2 import (
|
||||||
|
DeepseekV2DecoderLayer,
|
||||||
|
DeepseekV2Model,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||||
|
from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
|
||||||
|
from vllm.multimodal.inputs import NestedTensors
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
_merge_multimodal_embeddings,
|
||||||
|
make_empty_intermediate_tensors_factory,
|
||||||
|
maybe_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
|
class EagleMistralLarge3Model(DeepseekV2Model):
|
||||||
|
def __init__(
|
||||||
|
self, *, vllm_config: VllmConfig, prefix: str = "", start_layer_id: int = 0
|
||||||
|
):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
self.config = config
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
assert get_pp_group().world_size == 1
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.embed_tokens",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DeepseekV2DecoderLayer(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||||
|
)
|
||||||
|
for i in range(self.config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.start_layer = 0
|
||||||
|
self.end_layer = self.config.num_hidden_layers
|
||||||
|
|
||||||
|
self.fc = RowParallelLinear(
|
||||||
|
self.config.hidden_size * 2,
|
||||||
|
self.config.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
return_bias=False,
|
||||||
|
)
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||||
|
["hidden_states", "residual"], config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_input_ids(input_ids)
|
||||||
|
inputs_embeds = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
|
||||||
|
output = super().forward(
|
||||||
|
input_ids, positions, intermediate_tensors=None, inputs_embeds=inputs_embeds
|
||||||
|
)
|
||||||
|
assert isinstance(output, torch.Tensor)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
|
||||||
|
remapping = MistralLarge3ForCausalLM.remapping | {
|
||||||
|
r"eagle_linear\.weight": r"model.fc.weight",
|
||||||
|
r"eagle_linear\.qscale_act": r"model.fc.input_scale",
|
||||||
|
r"eagle_linear\.qscale_weight": r"model.fc.weight_scale",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
target_layer_num = vllm_config.model_config.get_num_layers(
|
||||||
|
vllm_config.parallel_config
|
||||||
|
)
|
||||||
|
vllm_config.model_config = vllm_config.speculative_config.draft_model_config
|
||||||
|
# draft model quantization config may differ from target model
|
||||||
|
self.quant_config = VllmConfig.get_quantization_config(
|
||||||
|
vllm_config.speculative_config.draft_model_config, vllm_config.load_config
|
||||||
|
)
|
||||||
|
vllm_config.quant_config = self.quant_config
|
||||||
|
self.model_cls = partial(
|
||||||
|
EagleMistralLarge3Model, start_layer_id=target_layer_num
|
||||||
|
)
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
def get_input_embeddings(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||||
|
*,
|
||||||
|
is_multimodal: torch.Tensor | None = None,
|
||||||
|
handle_oov_mm_token: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
inputs_embeds = super().embed_input_ids(input_ids)
|
||||||
|
|
||||||
|
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
assert is_multimodal is not None
|
||||||
|
|
||||||
|
return _merge_multimodal_embeddings(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
multimodal_embeddings=multimodal_embeddings,
|
||||||
|
is_multimodal=is_multimodal,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
hidden_states = self.model(input_ids, positions, hidden_states, inputs_embeds)
|
||||||
|
return hidden_states, hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||||
|
# Pretend we've loaded the embedding and lm_head weights
|
||||||
|
# (later copied from target model)
|
||||||
|
return super().load_weights(weights) | {
|
||||||
|
"model.embed_tokens.weight",
|
||||||
|
"lm_head.weight",
|
||||||
|
}
|
||||||
|
|
||||||
|
def embed_input_ids(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
multimodal_embeddings: NestedTensors | None = None,
|
||||||
|
is_multimodal: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.model.embed_input_ids(input_ids)
|
||||||
@ -145,6 +145,7 @@ _TEXT_GENERATION_MODELS = {
|
|||||||
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
||||||
"MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
|
"MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
|
||||||
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
|
"MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
|
||||||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
||||||
# transformers's mpt class has lower case
|
# transformers's mpt class has lower case
|
||||||
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
|
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||||
@ -424,6 +425,10 @@ _SPECULATIVE_DECODING_MODELS = {
|
|||||||
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||||
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||||
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||||
|
"EagleMistralLarge3ForCausalLM": (
|
||||||
|
"mistral_large_3_eagle",
|
||||||
|
"EagleMistralLarge3ForCausalLM",
|
||||||
|
),
|
||||||
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
|
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
|
||||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||||
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
|
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
|
||||||
|
|||||||
@ -97,6 +97,8 @@ def _prepare_apply_chat_template_tools_and_messages(
|
|||||||
continue_final_message: bool = False,
|
continue_final_message: bool = False,
|
||||||
add_generation_prompt: bool = False,
|
add_generation_prompt: bool = False,
|
||||||
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
|
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
|
||||||
|
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||||
|
|
||||||
if add_generation_prompt and continue_final_message:
|
if add_generation_prompt and continue_final_message:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot set both `add_generation_prompt` and "
|
"Cannot set both `add_generation_prompt` and "
|
||||||
@ -139,6 +141,33 @@ def _prepare_apply_chat_template_tools_and_messages(
|
|||||||
if function.get("description") is None:
|
if function.get("description") is None:
|
||||||
function["description"] = ""
|
function["description"] = ""
|
||||||
|
|
||||||
|
# We filter not supported arguments to avoid throwing an error.
|
||||||
|
# TODO(juliendenize): remove this once OpenAI API is better supported by
|
||||||
|
# `mistral-common`.
|
||||||
|
tools_fields = set(Tool.model_fields.keys())
|
||||||
|
function_fields = set(Function.model_fields.keys())
|
||||||
|
for tool in tools:
|
||||||
|
tool_keys = list(tool.keys())
|
||||||
|
for tool_key in tool_keys:
|
||||||
|
if tool_key not in tools_fields:
|
||||||
|
tool.pop(tool_key)
|
||||||
|
logger.warning_once(
|
||||||
|
f"'{tool_key}' is not supported by mistral-common for tools. "
|
||||||
|
"It has been poped from the tool definition."
|
||||||
|
)
|
||||||
|
if tool["type"] == "function":
|
||||||
|
function_keys = list(tool["function"].keys())
|
||||||
|
for function_key in function_keys:
|
||||||
|
if function_key not in function_fields:
|
||||||
|
tool["function"].pop(function_key)
|
||||||
|
logger.warning_once(
|
||||||
|
f"'{function_key}' is not supported by mistral-common "
|
||||||
|
"for function tools. It has been poped from the "
|
||||||
|
"function definition."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("mistral-common only supports function tools.")
|
||||||
|
|
||||||
return messages, tools
|
return messages, tools
|
||||||
|
|
||||||
|
|
||||||
@ -410,6 +439,13 @@ class MistralTokenizer(TokenizerLike):
|
|||||||
ids, skip_special_tokens=skip_special_tokens
|
ids, skip_special_tokens=skip_special_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def batch_decode(
|
||||||
|
self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False
|
||||||
|
) -> str:
|
||||||
|
return self.transformers_tokenizer.batch_decode(
|
||||||
|
ids, skip_special_tokens=skip_special_tokens
|
||||||
|
)
|
||||||
|
|
||||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||||
from mistral_common.tokens.tokenizers.base import (
|
from mistral_common.tokens.tokenizers.base import (
|
||||||
SpecialTokenPolicy,
|
SpecialTokenPolicy,
|
||||||
|
|||||||
@ -82,3 +82,9 @@ class EAGLEConfig(PretrainedConfig):
|
|||||||
pretrained_model_name_or_path, **kwargs
|
pretrained_model_name_or_path, **kwargs
|
||||||
)
|
)
|
||||||
return cls.from_dict(config_dict, **kwargs)
|
return cls.from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
def to_json_string(self, use_diff: bool = True) -> str:
|
||||||
|
# we override use_diff to False as initializing
|
||||||
|
# EAGLEConfig with default arguments is not supported
|
||||||
|
del use_diff
|
||||||
|
return super().to_json_string(use_diff=False)
|
||||||
|
|||||||
@ -18,9 +18,31 @@ def adapt_config_dict(
|
|||||||
if bool(config_dict.get("quantization")):
|
if bool(config_dict.get("quantization")):
|
||||||
config_dict = _remap_mistral_quantization_args(config_dict)
|
config_dict = _remap_mistral_quantization_args(config_dict)
|
||||||
|
|
||||||
|
is_moe = bool(config_dict.get("moe"))
|
||||||
|
is_mistral_large_3 = (
|
||||||
|
is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
|
||||||
|
)
|
||||||
if config_dict.get("model_type") == "mamba":
|
if config_dict.get("model_type") == "mamba":
|
||||||
config_dict["architectures"] = ["Mamba2ForCausalLM"]
|
config_dict["architectures"] = ["Mamba2ForCausalLM"]
|
||||||
elif bool(config_dict.get("moe")):
|
elif is_moe and is_mistral_large_3:
|
||||||
|
config_dict = _remap_moe_args(config_dict)
|
||||||
|
config_dict["model_type"] = "deepseek_v3"
|
||||||
|
config_dict["architectures"] = ["MistralLarge3ForCausalLM"]
|
||||||
|
|
||||||
|
assert "llama_4_scaling" in config_dict, (
|
||||||
|
"MistralLarge3 expect llama4 scaling config."
|
||||||
|
)
|
||||||
|
llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
|
||||||
|
assert all(
|
||||||
|
[
|
||||||
|
key in config_dict["llama_4_scaling"]
|
||||||
|
for key in llama_4_scaling_config_keys
|
||||||
|
]
|
||||||
|
), (
|
||||||
|
"llama_4_scaling config should define the keys: "
|
||||||
|
f"{','.join(llama_4_scaling_config_keys)}"
|
||||||
|
)
|
||||||
|
elif is_moe:
|
||||||
config_dict["architectures"] = ["MixtralForCausalLM"]
|
config_dict["architectures"] = ["MixtralForCausalLM"]
|
||||||
else:
|
else:
|
||||||
config_dict["architectures"] = ["MistralForCausalLM"]
|
config_dict["architectures"] = ["MistralForCausalLM"]
|
||||||
@ -140,17 +162,20 @@ def _remap_general_mistral_args(config: dict) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def _remap_mistral_quantization_args(config: dict) -> dict:
|
def _remap_mistral_quantization_args(config: dict) -> dict:
|
||||||
quantization = config.get("quantization", {})
|
if config.get("quantization"):
|
||||||
if quantization.get("qformat_weight") == "fp8_e4m3":
|
quantization = config.pop("quantization", {})
|
||||||
# This maps to the FP8 static per-tensor quantization scheme
|
if quantization.get("qformat_weight") == "fp8_e4m3":
|
||||||
quantization_config = {"quant_method": "fp8", "activation_scheme": "static"}
|
qscheme_act = quantization.get("qscheme_act")
|
||||||
elif quantization.get("quant_method") == "compressed-tensors":
|
assert qscheme_act in ("NO_SCALES", "TENSOR", None), (
|
||||||
# Pass through the quantization config to compressed-tensors
|
"Only NO_SCALES and TENSOR (default) are supported for qscheme_act"
|
||||||
quantization_config = quantization
|
)
|
||||||
else:
|
is_dynamic = qscheme_act == "NO_SCALES"
|
||||||
raise ValueError(f"Found unknown quantization='{quantization}' in config")
|
config["quantization_config"] = {
|
||||||
|
"quant_method": "fp8",
|
||||||
config["quantization_config"] = quantization_config
|
"activation_scheme": "dynamic" if is_dynamic else "static",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Found unknown quantization='{quantization}' in config")
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@ -183,3 +208,28 @@ def _remap_mistral_audio_args(config: dict) -> dict:
|
|||||||
if quant_config:
|
if quant_config:
|
||||||
config["quantization_config"] = quant_config
|
config["quantization_config"] = quant_config
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _remap_moe_args(config: dict) -> dict:
|
||||||
|
moe_config_map = {
|
||||||
|
"route_every_n": "moe_layer_freq",
|
||||||
|
"first_k_dense_replace": "first_k_dense_replace",
|
||||||
|
"num_experts_per_tok": "num_experts_per_tok",
|
||||||
|
"num_experts": "n_routed_experts",
|
||||||
|
"expert_hidden_dim": "moe_intermediate_size",
|
||||||
|
"routed_scale": "routed_scaling_factor",
|
||||||
|
"num_shared_experts": "n_shared_experts",
|
||||||
|
"num_expert_groups": "n_group",
|
||||||
|
"num_expert_groups_per_tok": "topk_group",
|
||||||
|
}
|
||||||
|
moe_config = config.get("moe", {})
|
||||||
|
for old_name, new_name in moe_config_map.items():
|
||||||
|
if old_name in moe_config:
|
||||||
|
value = moe_config.pop(old_name)
|
||||||
|
config[new_name] = value
|
||||||
|
|
||||||
|
config["topk_method"] = None
|
||||||
|
config["norm_topk_prob"] = True
|
||||||
|
config["scoring_func"] = "softmax"
|
||||||
|
|
||||||
|
return config
|
||||||
|
|||||||
@ -1016,6 +1016,10 @@ class EagleProposer:
|
|||||||
"Qwen3VLForConditionalGeneration",
|
"Qwen3VLForConditionalGeneration",
|
||||||
]:
|
]:
|
||||||
self.model.config.image_token_index = target_model.config.image_token_id
|
self.model.config.image_token_index = target_model.config.image_token_id
|
||||||
|
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
|
||||||
|
self.model.config.image_token_index = (
|
||||||
|
target_model.config.vision_config.image_token_id
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.model.config.image_token_index = (
|
self.model.config.image_token_index = (
|
||||||
target_model.config.image_token_index
|
target_model.config.image_token_index
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user