From 466e878f2ad5e36cba4861db1cac7cd0d92055fb Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Sat, 19 Jul 2025 08:52:02 +0800
Subject: [PATCH] [Quantization] Enable BNB support for more MoE models
 (#21100)

Signed-off-by: Jee Jee Li

---
 docs/models/supported_models.md              |   8 +-
 vllm/model_executor/models/bailing_moe.py    |  21 +-
 vllm/model_executor/models/ernie45_moe.py    | 153 +++++++-------
 vllm/model_executor/models/grok1.py          |  24 ++-
 vllm/model_executor/models/hunyuan_v1_moe.py | 198 ++++++++++---------
 5 files changed, 223 insertions(+), 181 deletions(-)

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 8fd8b8220cf7..cfd525ab9314 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -316,7 +316,7 @@ Specified using `--task generate`.
 | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
 | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | | ✅︎ | ✅︎ |
+| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
 | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
 | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
@@ -328,8 +328,8 @@ Specified using `--task generate`.
 | `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | | ✅︎ | ✅︎ |
 | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3`, etc. | | ✅︎ | ✅︎ |
 | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ |
-| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | | ✅︎ | ✅︎ |
-| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. | | ✅︎ | ✅︎ |
+| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Fairseq2LlamaForCausalLM` | Llama (fairseq2 format) | `mgleize/fairseq2-dummy-Llama-3.2-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
@@ -351,7 +351,7 @@ Specified using `--task generate`.
 | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
 | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
-| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | | | ✅︎ |
+| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ |
 | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
index ccfc3997e45c..853c13b135ea 100644
--- a/vllm/model_executor/models/bailing_moe.py
+++ b/vllm/model_executor/models/bailing_moe.py
@@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -374,6 +374,14 @@ class BailingMoeModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -381,14 +389,10 @@ class BailingMoeModel(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if self.config.norm_head and "lm_head.weight" in name:
                 loaded_weight = F.normalize(loaded_weight,
@@ -449,7 +453,7 @@ class BailingMoeModel(nn.Module):
         return loaded_params
 
 
-class BailingMoeForCausalLM(nn.Module, SupportsPP):
+class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
 
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
@@ -518,3 +522,6 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP):
                            if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py
index e7a50ff7a1c9..984003e62d11 100644
--- a/vllm/model_executor/models/ernie45_moe.py
+++ b/vllm/model_executor/models/ernie45_moe.py
@@ -51,8 +51,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
-from .utils import (PPMissingLayer, extract_layer_index,
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -427,66 +427,15 @@ class Ernie4_5_MoeModel(nn.Module):
 
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Ernie4_5_MoeModel(vllm_config=vllm_config,
-                                       prefix=maybe_prefix(prefix, "model"))
-
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.moe_num_experts)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
@@ -499,16 +448,9 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.moe_num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if self.config.tie_word_embeddings and name.endswith(
                     "lm_head.weight"):
@@ -581,3 +523,76 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Ernie4_5_MoeModel(vllm_config=vllm_config,
+                                       prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index 2d930527b2be..3659249cd8bd 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -360,6 +360,16 @@ class Grok1Model(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Map Grok1's unique expert parameter names to standard names
+        # Grok1 uses "num_experts" in its config
+        num_experts = getattr(self.config, "num_experts", 8)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="linear",  # Grok1 specific
+            ckpt_down_proj_name="linear_1",  # Grok1 specific
+            ckpt_up_proj_name="linear_v",  # Grok1 specific
+            num_experts=num_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -369,18 +379,9 @@ class Grok1Model(nn.Module):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        # Map Grok1's unique expert parameter names to standard names
-        # Grok1 uses "num_experts" in its config
-        num_experts = getattr(self.config, "num_experts", 8)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="linear",  # Grok1 specific
-            ckpt_down_proj_name="linear_1",  # Grok1 specific
-            ckpt_up_proj_name="linear_v",  # Grok1 specific
-            num_experts=num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if (self.quant_config is not None and
                 (scale_name := self.quant_config.get_cache_scale(name))):
@@ -544,3 +545,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             skip_prefixes=skip_prefixes,
         )
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
diff --git a/vllm/model_executor/models/hunyuan_v1_moe.py b/vllm/model_executor/models/hunyuan_v1_moe.py
index 43ffba00721e..b3baec98b0fc 100644
--- a/vllm/model_executor/models/hunyuan_v1_moe.py
+++ b/vllm/model_executor/models/hunyuan_v1_moe.py
@@ -56,7 +56,9 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
+from .interfaces import SupportsLoRA
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
+                    make_layers)
 
 
 def _get_cla_factor(config: PretrainedConfig) -> int:
@@ -617,86 +619,6 @@ class HunYuanModel(nn.Module):
             hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-
-class HunYuanMoEV1ForCausalLM(nn.Module):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        lora_config = vllm_config.lora_config
-        self.config = config
-        self.quant_config = quant_config
-        self.lora_config = lora_config
-
-        self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
-        if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = config.vocab_size
-            if lora_config:
-                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-            self.lm_head = ParallelLMHead(
-                self.unpadded_vocab_size,
-                config.hidden_size,
-                org_num_embeddings=config.vocab_size,
-                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-                quant_config=quant_config,
-            )
-            if config.tie_word_embeddings:
-                self.lm_head.weight = self.model.embed_tokens.weight
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-        else:
-            self.lm_head = PPMissingLayer()
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
-        return model_output
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
     def _split_qkv_weight(self, qkv: torch.Tensor):
         num_attention_heads = self.config.num_attention_heads
         num_kv_heads = getattr(self.config, "num_key_value_heads",
@@ -719,6 +641,17 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
             v = v.reshape(-1, hidden_size)
         return torch.concat((q, k, v))
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         cla_factor = _get_cla_factor(self.config)
         stacked_params_mapping = [
@@ -745,16 +678,9 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
             ),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts,
-        )
-
         params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -806,7 +732,7 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param, loaded_weight, shard_id)
-
+                    loaded_params.add(name)
                     is_found = True
                     break
             if is_found:
@@ -885,3 +811,93 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        return loaded_params
+
+
+class HunYuanMoEV1ForCausalLM(nn.Module, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+
+        self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings:
+                self.lm_head.weight = self.model.embed_tokens.weight
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+        else:
+            self.lm_head = PPMissingLayer()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+        return model_output
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
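
Usage note (not part of the patch): the pattern shared by all four models is that the expert-weight mapping is now exposed through a public `get_expert_mapping()` method instead of being built inline in `load_weights()`, so a model loader can query the `(param_name, weight_name, expert_id, shard_id)` tuples for the fused MoE experts up front — presumably what the bitsandbytes loader needs to locate expert weights before quantizing them. A minimal sketch of the resulting BNB support, with the model name taken from the docs table above and assuming a vLLM build recent enough to select the bitsandbytes load format automatically when `quantization="bitsandbytes"` is set:

    # Minimal sketch: in-flight 4-bit bitsandbytes quantization of one of
    # the newly supported MoE models (requires `pip install bitsandbytes`).
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="baidu/ERNIE-4.5-21B-A3B-PT",
        quantization="bitsandbytes",
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)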