[Quantization] Enable BNB support for more MoE models (#21100)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
  parent 217937221b
  commit 466e878f2a
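In practical terms, this change lets the MoE families touched below be loaded with bitsandbytes (BNB) in-flight quantization. A minimal usage sketch, assuming a vLLM build that contains this commit; the model name comes from the updated table below, and `quantization="bitsandbytes"` is vLLM's standard BNB switch (exact flags can differ slightly across vLLM versions):

```python
# Illustrative only: load one of the newly supported MoE checkpoints with
# bitsandbytes in-flight quantization.
from vllm import LLM, SamplingParams

llm = LLM(
    model="baidu/ERNIE-4.5-21B-A3B-PT",   # any of the newly enabled MoE models
    quantization="bitsandbytes",          # quantize weights on the fly with BNB
    trust_remote_code=True,               # may be required for some checkpoints
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```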
@@ -316,7 +316,7 @@ Specified using `--task generate`.
 | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
 | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | | ✅︎ | ✅︎ |
+| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
 | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
 | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
@@ -328,8 +328,8 @@ Specified using `--task generate`.
 | `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | | ✅︎ | ✅︎ |
 | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3`, etc. | | ✅︎ | ✅︎ |
 | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ |
-| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | | ✅︎ | ✅︎ |
-| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. | | ✅︎ | ✅︎ |
+| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ |
 | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Fairseq2LlamaForCausalLM` | Llama (fairseq2 format) | `mgleize/fairseq2-dummy-Llama-3.2-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
@@ -351,7 +351,7 @@ Specified using `--task generate`.
 | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
 | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
-| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | | | ✅︎ |
+| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ |
 | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -374,6 +374,14 @@ class BailingMoeModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -381,14 +389,10 @@ class BailingMoeModel(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if self.config.norm_head and "lm_head.weight" in name:
                 loaded_weight = F.normalize(loaded_weight,
@@ -449,7 +453,7 @@ class BailingMoeModel(nn.Module):
         return loaded_params
 
 
-class BailingMoeForCausalLM(nn.Module, SupportsPP):
+class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
 
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
@@ -518,3 +522,6 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP):
                            if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
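The same refactor recurs in every file below: the expert-parameter mapping built by `FusedMoE.make_expert_params_mapping` moves out of `load_weights` into a `get_expert_mapping()` method exposed on both the inner model and the causal-LM wrapper, so weight loaders (including the BNB path) can query it directly. A rough sketch of how a caller might consume the mapping; the tuple layout `(param_name, weight_name, expert_id, shard_id)` follows the comment in the code, while the concrete strings in the example are hypothetical:

```python
# Illustrative only: inspect the expert mapping of an already-constructed
# model instance (here called `model`, a hypothetical handle).
expert_params_mapping = model.get_expert_mapping()
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
    # Each entry says: a checkpoint tensor whose name contains `weight_name`
    # for expert `expert_id` should be loaded into the fused parameter
    # `param_name` at shard `shard_id`, e.g.
    # ("experts.w13_weight", "experts.0.gate_proj.", 0, "w1")  # hypothetical
    print(param_name, weight_name, expert_id, shard_id)
```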
@@ -51,8 +51,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
-from .utils import (PPMissingLayer, extract_layer_index,
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -427,66 +427,15 @@ class Ernie4_5_MoeModel(nn.Module):
 
         return hidden_states
 
-
-class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Ernie4_5_MoeModel(vllm_config=vllm_config,
-                                       prefix=maybe_prefix(prefix, "model"))
-
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.moe_num_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
@@ -499,16 +448,9 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.moe_num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if self.config.tie_word_embeddings and name.endswith(
                     "lm_head.weight"):
@@ -581,3 +523,76 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Ernie4_5_MoeModel(vllm_config=vllm_config,
+                                       prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
@@ -360,6 +360,16 @@ class Grok1Model(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Map Grok1's unique expert parameter names to standard names
+        # Grok1 uses "num_experts" in its config
+        num_experts = getattr(self.config, "num_experts", 8)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="linear",  # Grok1 specific
+            ckpt_down_proj_name="linear_1",  # Grok1 specific
+            ckpt_up_proj_name="linear_v",  # Grok1 specific
+            num_experts=num_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -369,18 +379,9 @@ class Grok1Model(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        # Map Grok1's unique expert parameter names to standard names
-        # Grok1 uses "num_experts" in its config
-        num_experts = getattr(self.config, "num_experts", 8)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="linear",  # Grok1 specific
-            ckpt_down_proj_name="linear_1",  # Grok1 specific
-            ckpt_up_proj_name="linear_v",  # Grok1 specific
-            num_experts=num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if (self.quant_config is not None and
                     (scale_name := self.quant_config.get_cache_scale(name))):
@@ -544,3 +545,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             skip_prefixes=skip_prefixes,
         )
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
@@ -56,7 +56,9 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
+from .interfaces import SupportsLoRA
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
+                    make_layers)
 
 
 def _get_cla_factor(config: PretrainedConfig) -> int:
@@ -617,86 +619,6 @@ class HunYuanModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-
-class HunYuanMoEV1ForCausalLM(nn.Module):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        lora_config = vllm_config.lora_config
-        self.config = config
-        self.quant_config = quant_config
-        self.lora_config = lora_config
-
-        self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
-        if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = config.vocab_size
-            if lora_config:
-                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-            self.lm_head = ParallelLMHead(
-                self.unpadded_vocab_size,
-                config.hidden_size,
-                org_num_embeddings=config.vocab_size,
-                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-                quant_config=quant_config,
-            )
-            if config.tie_word_embeddings:
-                self.lm_head.weight = self.model.embed_tokens.weight
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-        else:
-            self.lm_head = PPMissingLayer()
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
-        return model_output
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
     def _split_qkv_weight(self, qkv: torch.Tensor):
         num_attention_heads = self.config.num_attention_heads
         num_kv_heads = getattr(self.config, "num_key_value_heads",
@@ -719,6 +641,17 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
         v = v.reshape(-1, hidden_size)
         return torch.concat((q, k, v))
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         cla_factor = _get_cla_factor(self.config)
         stacked_params_mapping = [
@@ -745,16 +678,9 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
             ),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts,
-        )
-
         params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -806,7 +732,7 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
                 is_found = True
                 break
             if is_found:
@@ -885,3 +811,93 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        return loaded_params
+
+
+class HunYuanMoEV1ForCausalLM(nn.Module, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+
+        self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings:
+                self.lm_head.weight = self.model.embed_tokens.weight
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+        else:
+            self.lm_head = PPMissingLayer()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+        return model_output
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
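Because each wrapper class now also declares `SupportsLoRA` (mirrored by the updated rows in the supported-models table at the top), LoRA adapters can be attached to these models in the usual way. A minimal sketch, assuming an adapter trained for the base model exists at the placeholder path `path/to/ling-lora`:

```python
# Illustrative only: enable LoRA for one of the newly SupportsLoRA-tagged
# MoE models. The adapter name and path are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="inclusionAI/Ling-lite-1.5",
    enable_lora=True,
    trust_remote_code=True,
)

out = llm.generate(
    ["Summarize bitsandbytes quantization in one sentence."],
    SamplingParams(max_tokens=48),
    lora_request=LoRARequest("ling-adapter", 1, "path/to/ling-lora"),
)
print(out[0].outputs[0].text)
```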