mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-28 20:17:16 +08:00

Add Mistral Large 3 and Ministral 3 (#29757)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Mickael Seznec <mickael@mistral.ai>

This commit is contained in:
parent 8bbcf8b6e7
commit d8c6210eea
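The change set below wires the two new model families through the docs, the test registries, the Mistral config adaptation, a checkpoint weight remapper, and an Eagle draft model. As a rough serving sketch, assuming the checkpoints are published under the repository names used in this diff; the flag combination mirrors vLLM's existing Mistral-format support and is an assumption, not something this commit adds:

```python
# Hypothetical usage sketch: serve Ministral 3 from its Mistral-native
# tokenizer/config/weight formats, as vLLM already does for other
# Mistral-format checkpoints. The model ID is taken from the diff.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Ministral-3-3B-Instruct-2512",
    tokenizer_mode="mistral",   # tokenize with mistral-common
    config_format="mistral",    # read params.json instead of config.json
    load_format="mistral",      # load consolidated safetensors
)
out = llm.generate(["Hello!"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```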
@ -417,7 +417,8 @@ th {
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
| `MiniMaxM2ForCausalLM` | MiniMax-M2 | `MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
- | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
+ | `MistralForCausalLM` | Ministral-3, Mistral, Mistral-Instruct | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
+ | `MistralLarge3ForCausalLM` | Mistral-Large-3-675B-Base-2512, Mistral-Large-3-675B-Instruct-2512 | `mistralai/Mistral-Large-3-675B-Base-2512`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, etc. | ✅︎ | ✅︎ |
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ |
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ |
@ -711,7 +712,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
- | `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ |
+ | `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ |
| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ |
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ |
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ |
@ -358,6 +358,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
        trust_remote_code=True,
    ),
    "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
    "MistralLarge3ForCausalLM": _HfExamplesInfo(
        "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", is_available_online=False
    ),
    "MixtralForCausalLM": _HfExamplesInfo(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        {"tiny": "TitanML/tiny-mixtral"},
@ -770,7 +773,13 @@ _MULTIMODAL_EXAMPLE_MODELS = {
    ),
    "PixtralForConditionalGeneration": _HfExamplesInfo(
        "mistralai/Pixtral-12B-2409",
        extras={
            "mistral-large-3": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4",
            "ministral-3": "mistralai/Ministral-3-3B-Instruct-2512",
        },
        tokenizer_mode="mistral",
        # TODO: revert once Mistral-Large-3 and Ministral-3 are publicly available.
        is_available_online=False,
    ),
    "QwenVLForConditionalGeneration": _HfExamplesInfo(
        "Qwen/Qwen-VL",
@ -870,6 +879,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
        use_original_num_layers=True,
        max_model_len=10240,
    ),
    "EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
        "mistralai/Mistral-Large-3-675B-Instruct-2512",
        speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
        is_available_online=False,
    ),
    "LlamaForCausalLMEagle3": _HfExamplesInfo(
        "Qwen/Qwen3-8B",
        trust_remote_code=True,
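Putting the registry entries together, speculative decoding with the new Eagle draft would presumably be enabled as sketched below. The draft repository name comes from the test registry above; the `speculative_config` keys follow vLLM's existing Eagle support and are an assumption rather than part of this diff:

```python
# Hypothetical sketch: pair the target model with its Eagle draft.
from vllm import LLM

llm = LLM(
    model="mistralai/Mistral-Large-3-675B-Instruct-2512",
    tokenizer_mode="mistral",
    speculative_config={
        "model": "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
        "method": "eagle",
        "num_speculative_tokens": 3,
    },
)
```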
@ -91,6 +91,118 @@ from vllm.tokenizers.mistral import (
                ],
            ),
        ),
        (
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "What is the current local date and time?",
                    }
                ],
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "unsupported_field": False,
                            "name": "get_current_time",
                            "parameters": {},
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "unsupported_field2": False,
                            "name": "get_current_time",
                            "parameters": {},
                        },
                    },
                ],
            },
            (
                [
                    {
                        "role": "user",
                        "content": "What is the current local date and time?",
                    }
                ],
                [
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "name": "get_current_time",
                            "parameters": {},
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "name": "get_current_time",
                            "parameters": {},
                        },
                    },
                ],
            ),
        ),
        (
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "What is the current local date and time?",
                    }
                ],
                "tools": [
                    {
                        "type": "function",
                        "unsupported_field": False,
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "name": "get_current_time",
                            "parameters": {},
                        },
                    },
                    {
                        "type": "function",
                        "unsupported_field2": False,
                        "function": {
                            "description": "Fetch the current local date and time 2.",
                            "name": "get_current_time2",
                            "parameters": {"a": "1"},
                        },
                    },
                ],
            },
            (
                [
                    {
                        "role": "user",
                        "content": "What is the current local date and time?",
                    }
                ],
                [
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "name": "get_current_time",
                            "parameters": {},
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time 2.",
                            "name": "get_current_time2",
                            "parameters": {"a": "1"},
                        },
                    },
                ],
            ),
        ),
    ],
)
def test_prepare_apply_chat_template_tools_and_messages(
@ -1108,13 +1220,6 @@ class TestMistralTokenizer:
            )
            == expected_tokens[mistral_tokenizer.is_tekken]
        )
-       assert (
-           mistral_tokenizer.decode(
-               ids[mistral_tokenizer.is_tekken],
-               skip_special_tokens=skip_special_tokens,
-           )
-           == expected_tokens[mistral_tokenizer.is_tekken]
-       )

    def test_decode_empty(
        self,
@ -1140,6 +1245,45 @@ class TestMistralTokenizer:
            == "<s>"
        )

    @pytest.mark.parametrize(
        "skip_special_tokens,expected_tokens",
        (
            (
                False,
                (
                    ["<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>"],
                    ["<s>[INST]Hello world ![/INST]Hello</s>"],
                ),
            ),
            (True, (["Hello world ! Hello"], ["Hello world !Hello"])),
        ),
    )
    def test_batch_decode(
        self,
        mistral_tokenizer: MistralTokenizer,
        skip_special_tokens: bool,
        expected_tokens: tuple[str, str],
    ):
        ids = (
            [[1, 3, 23325, 2294, 1686, 4, 23325, 2]],
            [[1, 3, 22177, 4304, 2662, 4, 22177, 2]],
        )
        assert (
            mistral_tokenizer.batch_decode(
                ids[mistral_tokenizer.is_tekken],
                skip_special_tokens=skip_special_tokens,
            )
            == expected_tokens[mistral_tokenizer.is_tekken]
        )

    def test_batch_decode_empty(
        self,
        mistral_tokenizer: MistralTokenizer,
    ):
        assert mistral_tokenizer.batch_decode(
            [[]],
        ) == [""]

    def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
        tokens = (
            [
@ -167,6 +167,7 @@ class SpeculativeConfig:

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        initial_architecture = hf_config.architectures[0]
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":

@ -226,6 +227,9 @@ class SpeculativeConfig:
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )

        if initial_architecture == "MistralLarge3ForCausalLM":
            hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

        return hf_config

    def __post_init__(self):
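A toy illustration of the override above: when the draft checkpoint reports the target architecture, it is rewritten so the Eagle wrapper class gets instantiated instead. `SimpleNamespace` stands in for a real `PretrainedConfig` here; the real code uses `hf_config.update(...)`.

```python
from types import SimpleNamespace

hf_config = SimpleNamespace(
    architectures=["MistralLarge3ForCausalLM"], model_type="deepseek_v3"
)
if hf_config.architectures[0] == "MistralLarge3ForCausalLM":
    hf_config.architectures = ["EagleMistralLarge3ForCausalLM"]
print(hf_config.architectures)  # ['EagleMistralLarge3ForCausalLM']
```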
@ -80,7 +80,7 @@ class MistralToolParser(ToolParser):
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
        if _is_fn_name_regex_support(self.model_tokenizer):
            self.fn_name_regex = re.compile(
-               r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)", re.DOTALL
+               r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)?", re.DOTALL
            )
        else:
            self.fn_name_regex = None
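The `fn_name_regex` splits a Mistral-style tool-call payload of the form `name{json-args}` into (name, arguments) pairs; the change makes the trailing boundary lookahead optional. A quick self-contained sketch (the payload string is made up for demonstration):

```python
import re

fn_name_regex = re.compile(
    r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)?", re.DOTALL
)

payload = 'get_weather{"city": "Paris"}, get_current_time{}'
print(fn_name_regex.findall(payload))
# [('get_weather', '{"city": "Paris"}'), ('get_current_time', '{}')]
```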
@ -0,0 +1,146 @@
{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
    "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
    "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
    "96": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2},
    "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2},
    "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2},
    "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3},
    "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
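The JSON above is a Triton tuning table for the fused-MoE kernels: each top-level key is a token count M and the value is the tile configuration to use for it. A rough lookup sketch; the nearest-key rule mirrors vLLM's existing fused-MoE config lookup and is stated here as an assumption:

```python
import json

# Small excerpt of the table above, just for the demonstration.
tuned = json.loads("""
{
  "64":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2}
}
""")

def pick(configs: dict, m: int) -> dict:
    # use the tuned entry whose batch size is closest to the actual M
    return configs[min(configs, key=lambda k: abs(int(k) - m))]

print(pick(tuned, 100))  # -> the "128" entry
```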
@ -111,6 +111,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> torch.Tensor:
        q_c = None
        kv_lora = None

@ -159,6 +160,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
                hidden_states, q_c, positions, self.indexer_rope_emb
            )

        if llama_4_scaling is not None:
            q *= llama_4_scaling

        attn_out = self.mla_attn(
            q,
            kv_c_normed,
@ -238,7 +238,7 @@ def get_rope(
            dtype,
            **extra_kwargs,
        )
-   elif scaling_type == "deepseek_yarn":
+   elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
        scaling_factor = rope_parameters["factor"]
        original_max_position = rope_parameters["original_max_position_embeddings"]
        # assert max_position == original_max_position * scaling_factor
@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    return 0.1 * mscale * math.log(scale) + 1.0


def _get_llama_4_scaling(
    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
    scaling = 1 + scaling_beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    # Broadcast over num_heads and head_dim
    return scaling[..., None, None]


class DeepseekV2Attention(nn.Module):
    def __init__(
        self,
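In effect the helper multiplies the query by scaling(p) = 1 + β · ln(1 + ⌊p / L₀⌋), where L₀ is `original_max_position_embeddings` and β the scaling beta. A small numeric sketch; β and L₀ below are placeholder values, the real ones come from the checkpoint's `llama_4_scaling` config:

```python
import torch

def get_llama_4_scaling(orig_max_pos: int, beta: float, positions: torch.Tensor):
    scaling = 1 + beta * torch.log(1 + torch.floor(positions / orig_max_pos))
    return scaling[..., None, None]  # broadcast over heads and head_dim

positions = torch.tensor([0, 4096, 100_000, 500_000])
print(get_llama_4_scaling(orig_max_pos=8192, beta=0.1, positions=positions).flatten())
# tensor([1.0000, 1.0000, 1.2565, 1.4127])
```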
@ -481,7 +491,11 @@ class DeepseekV2Attention(nn.Module):
            prefix=f"{prefix}.o_proj",
        )
        if config.rope_parameters["rope_type"] != "default":
-           config.rope_parameters["rope_type"] = "deepseek_yarn"
+           config.rope_parameters["rope_type"] = (
+               "deepseek_yarn"
+               if config.rope_parameters.get("apply_yarn_scaling", True)
+               else "deepseek_llama_scaling"
+           )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
@ -491,7 +505,10 @@ class DeepseekV2Attention(nn.Module):
            is_neox_style=False,
        )

-       if config.rope_parameters["rope_type"] != "default":
+       if (
+           config.rope_parameters["rope_type"] != "default"
+           and config.rope_parameters["rope_type"] == "deepseek_yarn"
+       ):
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@ -511,6 +528,7 @@ class DeepseekV2Attention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
@ -536,6 +554,11 @@ class DeepseekV2Attention(nn.Module):
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Apply llama 4 scaling if provided
        if llama_4_scaling is not None:
            q *= llama_4_scaling

        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim], value=0
@ -987,7 +1010,12 @@ class DeepseekV2MLAAttention(nn.Module):
        )

        if config.rope_parameters["rope_type"] != "default":
-           config.rope_parameters["rope_type"] = "deepseek_yarn"
+           config.rope_parameters["rope_type"] = (
+               "deepseek_yarn"
+               if config.rope_parameters.get("apply_yarn_scaling", True)
+               else "deepseek_llama_scaling"
+           )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
@ -995,7 +1023,11 @@ class DeepseekV2MLAAttention(nn.Module):
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )
-       if config.rope_parameters["rope_type"] != "default":
+
+       if (
+           config.rope_parameters["rope_type"] != "default"
+           and config.rope_parameters["rope_type"] == "deepseek_yarn"
+       ):
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@ -1064,8 +1096,9 @@ class DeepseekV2MLAAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
+       llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
-       return self.mla_attn(positions, hidden_states)
+       return self.mla_attn(positions, hidden_states, llama_4_scaling)


class DeepseekV2DecoderLayer(nn.Module):
@ -1155,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module):
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
@ -1165,6 +1199,7 @@ class DeepseekV2DecoderLayer(nn.Module):
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                llama_4_scaling=llama_4_scaling,
            )

        if (
@ -1266,8 +1301,24 @@ class DeepseekV2Model(nn.Module):
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Compute llama 4 scaling once per forward pass if enabled
        llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
        llama_4_scaling: torch.Tensor | None
        if llama_4_scaling_config is not None:
            llama_4_scaling = _get_llama_4_scaling(
                original_max_position_embeddings=llama_4_scaling_config[
                    "original_max_position_embeddings"
                ],
                scaling_beta=llama_4_scaling_config["beta"],
                positions=positions,
            )
        else:
            llama_4_scaling = None

        for layer in islice(self.layers, self.start_layer, self.end_layer):
-           hidden_states, residual = layer(positions, hidden_states, residual)
+           hidden_states, residual = layer(
+               positions, hidden_states, residual, llama_4_scaling
+           )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM(
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    model_cls = DeepseekV2Model

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
@ -1355,7 +1407,7 @@ class DeepseekV2ForCausalLM(
            "kv_a_proj_with_mqa",
        ]

-       self.model = DeepseekV2Model(
+       self.model = self.model_cls(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if get_pp_group().is_last_rank:

63 vllm/model_executor/models/mistral_large_3.py Normal file
@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import regex as re
import torch

from vllm.model_executor.models.deepseek_v2 import DeepseekV3ForCausalLM


class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM):
    # fmt: off
    remapping = {
        r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight",  # noqa: E501
        r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight",  # noqa: E501
        r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2",  # noqa: E501
        r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight",  # noqa: E501
        r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight",  # noqa: E501
        r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight",  # noqa: E501
        r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2",  # noqa: E501
        r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3",  # noqa: E501
        r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3",  # noqa: E501
        r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3",  # noqa: E501
        r"norm\.weight": "model.norm.weight",  # noqa: E501
        r"tok_embeddings\.weight": "model.embed_tokens.weight",  # noqa: E501
        r"output\.weight": "lm_head.weight",  # noqa: E501
    }
    # fmt: on

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        return super().load_weights(map(self._remap_mistral_to_ds, weights))

    def _remap_mistral_to_ds(
        self, weight: tuple[str, torch.Tensor]
    ) -> tuple[str, torch.Tensor]:
        """Remap Mistral parameters to DeepseekV2 parameters."""
        name, loaded_weight = weight

        for k, v in self.remapping.items():
            match = re.fullmatch(k, name)
            if match:
                name = re.sub(k, v, name)
                break
        else:
            raise ValueError(f"Cannot remap {name}")

        # Remapping scale names. We could do this in the regex above but it
        # would triple the number of lines for most layers.
        if name.endswith(".qscale_act"):
            name = re.sub(r"\.qscale_act$", ".input_scale", name)
        elif name.endswith(".qscale_weight"):
            name = re.sub(r"\.qscale_weight$", ".weight_scale", name)

        return name, loaded_weight
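As a sanity check, here is a standalone sketch of how the remapping table translates a couple of illustrative Mistral-format weight names into the DeepseekV2 layout (the sample keys are invented for the example; plain `re` is enough for these patterns, the model file uses the `regex` package):

```python
import re

remapping = {
    r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2",
    r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3",
    r"output\.weight": "lm_head.weight",
}

def remap(name: str) -> str:
    for pattern, repl in remapping.items():
        if re.fullmatch(pattern, name):
            return re.sub(pattern, repl, name)
    raise ValueError(f"Cannot remap {name}")

print(remap("layers.3.attention.wq_a.qscale_weight"))
# model.layers.3.self_attn.q_a_proj.qscale_weight  (".qscale_weight" then becomes ".weight_scale")
print(remap("layers.10.experts.42.w1.weight"))
# model.layers.10.mlp.experts.42.gate_proj.weight
```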
165 vllm/model_executor/models/mistral_large_3_eagle.py Normal file
@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from functools import partial

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.deepseek_v2 import (
    DeepseekV2DecoderLayer,
    DeepseekV2Model,
)
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
from vllm.multimodal.inputs import NestedTensors

from .utils import (
    _merge_multimodal_embeddings,
    make_empty_intermediate_tensors_factory,
    maybe_prefix,
)

logger = init_logger(__name__)


@support_torch_compile
class EagleMistralLarge3Model(DeepseekV2Model):
    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", start_layer_id: int = 0
    ):
        nn.Module.__init__(self)

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.vllm_config = vllm_config

        self.vocab_size = config.vocab_size

        assert get_pp_group().world_size == 1
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )

        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    vllm_config=vllm_config,
                    prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                )
                for i in range(self.config.num_hidden_layers)
            ]
        )
        self.start_layer = 0
        self.end_layer = self.config.num_hidden_layers

        self.fc = RowParallelLinear(
            self.config.hidden_size * 2,
            self.config.hidden_size,
            bias=False,
            input_is_parallel=False,
            quant_config=quant_config,
            return_bias=False,
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.embed_input_ids(input_ids)
        inputs_embeds = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
        output = super().forward(
            input_ids, positions, intermediate_tensors=None, inputs_embeds=inputs_embeds
        )
        assert isinstance(output, torch.Tensor)
        return output


class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
    remapping = MistralLarge3ForCausalLM.remapping | {
        r"eagle_linear\.weight": r"model.fc.weight",
        r"eagle_linear\.qscale_act": r"model.fc.input_scale",
        r"eagle_linear\.qscale_weight": r"model.fc.weight_scale",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        vllm_config.model_config = vllm_config.speculative_config.draft_model_config
        # draft model quantization config may differ from target model
        self.quant_config = VllmConfig.get_quantization_config(
            vllm_config.speculative_config.draft_model_config, vllm_config.load_config
        )
        vllm_config.quant_config = self.quant_config
        self.model_cls = partial(
            EagleMistralLarge3Model, start_layer_id=target_layer_num
        )
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        inputs_embeds = super().embed_input_ids(input_ids)

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        assert is_multimodal is not None

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        hidden_states = self.model(input_ids, positions, hidden_states, inputs_embeds)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Pretend we've loaded the embedding and lm_head weights
        # (later copied from target model)
        return super().load_weights(weights) | {
            "model.embed_tokens.weight",
            "lm_head.weight",
        }

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: NestedTensors | None = None,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)
@ -145,6 +145,7 @@ _TEXT_GENERATION_MODELS = {
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
@ -424,6 +425,10 @@ _SPECULATIVE_DECODING_MODELS = {
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
@ -97,6 +97,8 @@ def _prepare_apply_chat_template_tools_and_messages(
    continue_final_message: bool = False,
    add_generation_prompt: bool = False,
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
    from mistral_common.protocol.instruct.tool_calls import Function, Tool

    if add_generation_prompt and continue_final_message:
        raise ValueError(
            "Cannot set both `add_generation_prompt` and "
@ -139,6 +141,33 @@ def _prepare_apply_chat_template_tools_and_messages(
            if function.get("description") is None:
                function["description"] = ""

        # We filter out unsupported arguments to avoid throwing an error.
        # TODO(juliendenize): remove this once OpenAI API is better supported by
        # `mistral-common`.
        tools_fields = set(Tool.model_fields.keys())
        function_fields = set(Function.model_fields.keys())
        for tool in tools:
            tool_keys = list(tool.keys())
            for tool_key in tool_keys:
                if tool_key not in tools_fields:
                    tool.pop(tool_key)
                    logger.warning_once(
                        f"'{tool_key}' is not supported by mistral-common for tools. "
                        "It has been popped from the tool definition."
                    )
            if tool["type"] == "function":
                function_keys = list(tool["function"].keys())
                for function_key in function_keys:
                    if function_key not in function_fields:
                        tool["function"].pop(function_key)
                        logger.warning_once(
                            f"'{function_key}' is not supported by mistral-common "
                            "for function tools. It has been popped from the "
                            "function definition."
                        )
            else:
                raise ValueError("mistral-common only supports function tools.")

    return messages, tools

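For reference, a pure-Python sketch of the filtering behaviour above (this is also what the new tokenizer test cases exercise). The allowed field sets are hard-coded stand-ins; the real code derives them from mistral-common's `Tool` / `Function` pydantic models:

```python
TOOL_FIELDS = {"type", "function"}
FUNCTION_FIELDS = {"name", "description", "parameters"}

def filter_tool(tool: dict) -> dict:
    # drop keys mistral-common does not know about, at both levels
    tool = {k: v for k, v in tool.items() if k in TOOL_FIELDS}
    tool["function"] = {
        k: v for k, v in tool["function"].items() if k in FUNCTION_FIELDS
    }
    return tool

print(filter_tool({
    "type": "function",
    "unsupported_field": False,
    "function": {"name": "get_current_time", "parameters": {}, "strict": True},
}))
# {'type': 'function', 'function': {'name': 'get_current_time', 'parameters': {}}}
```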
@ -410,6 +439,13 @@ class MistralTokenizer(TokenizerLike):
            ids, skip_special_tokens=skip_special_tokens
        )

    def batch_decode(
        self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        return self.transformers_tokenizer.batch_decode(
            ids, skip_special_tokens=skip_special_tokens
        )

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        from mistral_common.tokens.tokenizers.base import (
            SpecialTokenPolicy,
@ -82,3 +82,9 @@ class EAGLEConfig(PretrainedConfig):
                pretrained_model_name_or_path, **kwargs
            )
        return cls.from_dict(config_dict, **kwargs)

    def to_json_string(self, use_diff: bool = True) -> str:
        # we override use_diff to False as initializing
        # EAGLEConfig with default arguments is not supported
        del use_diff
        return super().to_json_string(use_diff=False)
@ -18,9 +18,31 @@ def adapt_config_dict(
    if bool(config_dict.get("quantization")):
        config_dict = _remap_mistral_quantization_args(config_dict)

    is_moe = bool(config_dict.get("moe"))
    is_mistral_large_3 = (
        is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
    )
    if config_dict.get("model_type") == "mamba":
        config_dict["architectures"] = ["Mamba2ForCausalLM"]
-   elif bool(config_dict.get("moe")):
+   elif is_moe and is_mistral_large_3:
+       config_dict = _remap_moe_args(config_dict)
+       config_dict["model_type"] = "deepseek_v3"
+       config_dict["architectures"] = ["MistralLarge3ForCausalLM"]
+
+       assert "llama_4_scaling" in config_dict, (
+           "MistralLarge3 expects a llama_4_scaling config."
+       )
+       llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
+       assert all(
+           [
+               key in config_dict["llama_4_scaling"]
+               for key in llama_4_scaling_config_keys
+           ]
+       ), (
+           "llama_4_scaling config should define the keys: "
+           f"{','.join(llama_4_scaling_config_keys)}"
+       )
+   elif is_moe:
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]
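A condensed sketch of the new architecture-selection rule: a Mistral-format params.json whose "moe" section declares shared experts is treated as Mistral Large 3 (DeepSeek-V3-style MoE), while other MoE configs keep mapping to Mixtral. The field values below are placeholders, not a real checkpoint config:

```python
def pick_architecture(config_dict: dict) -> str:
    is_moe = bool(config_dict.get("moe"))
    is_mistral_large_3 = (
        is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
    )
    if config_dict.get("model_type") == "mamba":
        return "Mamba2ForCausalLM"
    if is_moe and is_mistral_large_3:
        return "MistralLarge3ForCausalLM"
    if is_moe:
        return "MixtralForCausalLM"
    return "MistralForCausalLM"

print(pick_architecture({"moe": {"num_experts": 8}}))
# MixtralForCausalLM
print(pick_architecture({"moe": {"num_experts": 256, "num_shared_experts": 1}}))
# MistralLarge3ForCausalLM
```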
@ -140,17 +162,20 @@ def _remap_general_mistral_args(config: dict) -> dict:


def _remap_mistral_quantization_args(config: dict) -> dict:
-   quantization = config.get("quantization", {})
-   if quantization.get("qformat_weight") == "fp8_e4m3":
-       # This maps to the FP8 static per-tensor quantization scheme
-       quantization_config = {"quant_method": "fp8", "activation_scheme": "static"}
-   elif quantization.get("quant_method") == "compressed-tensors":
-       # Pass through the quantization config to compressed-tensors
-       quantization_config = quantization
-   else:
-       raise ValueError(f"Found unknown quantization='{quantization}' in config")
-
-   config["quantization_config"] = quantization_config
+   if config.get("quantization"):
+       quantization = config.pop("quantization", {})
+       if quantization.get("qformat_weight") == "fp8_e4m3":
+           qscheme_act = quantization.get("qscheme_act")
+           assert qscheme_act in ("NO_SCALES", "TENSOR", None), (
+               "Only NO_SCALES and TENSOR (default) are supported for qscheme_act"
+           )
+           is_dynamic = qscheme_act == "NO_SCALES"
+           config["quantization_config"] = {
+               "quant_method": "fp8",
+               "activation_scheme": "dynamic" if is_dynamic else "static",
+           }
+       else:
+           raise ValueError(f"Found unknown quantization='{quantization}' in config")

    return config

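And a small input/output sketch of the quantization remapping above, with a made-up params.json fragment: Mistral's `quantization` block becomes an HF-style `quantization_config` for vLLM's FP8 path.

```python
def remap_quantization(config: dict) -> dict:
    if config.get("quantization"):
        quantization = config.pop("quantization", {})
        assert quantization.get("qformat_weight") == "fp8_e4m3"
        is_dynamic = quantization.get("qscheme_act") == "NO_SCALES"
        config["quantization_config"] = {
            "quant_method": "fp8",
            "activation_scheme": "dynamic" if is_dynamic else "static",
        }
    return config

print(remap_quantization(
    {"quantization": {"qformat_weight": "fp8_e4m3", "qscheme_act": "NO_SCALES"}}
))
# {'quantization_config': {'quant_method': 'fp8', 'activation_scheme': 'dynamic'}}
```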
@ -183,3 +208,28 @@ def _remap_mistral_audio_args(config: dict) -> dict:
    if quant_config:
        config["quantization_config"] = quant_config
    return config


def _remap_moe_args(config: dict) -> dict:
    moe_config_map = {
        "route_every_n": "moe_layer_freq",
        "first_k_dense_replace": "first_k_dense_replace",
        "num_experts_per_tok": "num_experts_per_tok",
        "num_experts": "n_routed_experts",
        "expert_hidden_dim": "moe_intermediate_size",
        "routed_scale": "routed_scaling_factor",
        "num_shared_experts": "n_shared_experts",
        "num_expert_groups": "n_group",
        "num_expert_groups_per_tok": "topk_group",
    }
    moe_config = config.get("moe", {})
    for old_name, new_name in moe_config_map.items():
        if old_name in moe_config:
            value = moe_config.pop(old_name)
            config[new_name] = value

    config["topk_method"] = None
    config["norm_topk_prob"] = True
    config["scoring_func"] = "softmax"

    return config
@ -1016,6 +1016,10 @@ class EagleProposer:
            "Qwen3VLForConditionalGeneration",
        ]:
            self.model.config.image_token_index = target_model.config.image_token_id
        elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
            self.model.config.image_token_index = (
                target_model.config.vision_config.image_token_id
            )
        else:
            self.model.config.image_token_index = (
                target_model.config.image_token_index