diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index eba8207d2cd4..3bd6332c11ca 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -50,8 +50,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers, - maybe_prefix) +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_layers, maybe_prefix) class GraniteMLP(nn.Module): @@ -260,6 +260,7 @@ class GraniteModel(nn.Module): lora_config = vllm_config.lora_config self.config = config + self.quant_config = quant_config lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 self.vocab_size = config.vocab_size + lora_vocab @@ -321,6 +322,65 @@ class GraniteModel(nn.Module): hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -428,71 +488,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), + skip_prefixes = [ + "rotary_emb.inv_freq", + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ "rotary_emb.cos_cached", + "rotary_emb.sin_cached", ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings: + skip_prefixes.append("lm_head.weight") - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 5152539c68f6..367722126e56 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors from . 
import mixtral from .interfaces import SupportsLoRA, SupportsPP -from .utils import make_layers, maybe_prefix +from .utils import AutoWeightsLoader, make_layers, maybe_prefix class GraniteMoeMoE(nn.Module): @@ -252,6 +252,8 @@ class GraniteMoeModel(nn.Module): quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + self.config = config + self.quant_config = quant_config # Required by MixtralModel lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 self.vocab_size = config.vocab_size + lora_vocab @@ -304,6 +306,40 @@ class GraniteMoeModel(nn.Module): hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + new_weights = {} + for n, p in weights: + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w1.weight") + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w3.weight") + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + f".block_sparse_moe.experts.{e}.w2.weight") + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + assert gate_name not in new_weights + new_weights[gate_name] = p + else: + new_weights[n] = p + return mixtral.MixtralModel.load_weights(self, new_weights.items()) + class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False @@ -331,7 +367,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.config = config self.lora_config = lora_config - self.quant_config = quant_config # Required by MixtralForCausalLM self.model = GraniteMoeModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) @@ -403,37 +438,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - new_weights = {} - for n, p in weights: - if n.endswith('.block_sparse_moe.input_linear.weight'): - for e in range(p.size(0)): - w1_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w1.weight") - w3_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w3.weight") - w1_param, w3_param = p[e].chunk(2, dim=0) - assert w1_name not in new_weights - assert w3_name not in new_weights - new_weights[w1_name] = w1_param - new_weights[w3_name] = w3_param - elif n.endswith('.block_sparse_moe.output_linear.weight'): - for e in range(p.size(0)): - w2_name = n.replace( - '.block_sparse_moe.output_linear.weight', - f".block_sparse_moe.experts.{e}.w2.weight") - w2_param = p[e] - assert w2_name not in new_weights - new_weights[w2_name] = w2_param - elif n.endswith('.block_sparse_moe.router.layer.weight'): - gate_name = n.replace('.block_sparse_moe.router.layer.weight', - ".block_sparse_moe.gate.weight") - assert gate_name not in new_weights - new_weights[gate_name] = 
p - elif n == 'lm_head.weight' and self.config.tie_word_embeddings: - pass - else: - new_weights[n] = p - return mixtral.MixtralForCausalLM.load_weights(self, - new_weights.items()) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index 7e2e4cdcbfa3..cf8c969e118f 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors from . import mixtral from .granitemoe import GraniteMoeAttention, GraniteMoeMoE from .interfaces import SupportsLoRA, SupportsPP -from .utils import make_layers, maybe_prefix +from .utils import AutoWeightsLoader, make_layers, maybe_prefix class GraniteMoeSharedMLP(nn.Module): @@ -152,6 +152,8 @@ class GraniteMoeSharedModel(nn.Module): quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + self.config = config + self.quant_config = quant_config # Required by MixtralModel self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -207,6 +209,40 @@ class GraniteMoeSharedModel(nn.Module): hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + new_weights = {} + for n, p in weights: + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w1.weight") + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w3.weight") + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + f".block_sparse_moe.experts.{e}.w2.weight") + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + assert gate_name not in new_weights + new_weights[gate_name] = p + else: + new_weights[n] = p + return mixtral.MixtralModel.load_weights(self, new_weights.items()) + class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False @@ -234,7 +270,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.config = config self.lora_config = lora_config - self.quant_config = quant_config self.model = GraniteMoeSharedModel(vllm_config=vllm_config, prefix=maybe_prefix( @@ -307,37 +342,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - new_weights = {} - for n, p in weights: - if n.endswith('.block_sparse_moe.input_linear.weight'): - for e in range(p.size(0)): - w1_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w1.weight") - w3_name = n.replace( - '.block_sparse_moe.input_linear.weight', - 
f".block_sparse_moe.experts.{e}.w3.weight") - w1_param, w3_param = p[e].chunk(2, dim=0) - assert w1_name not in new_weights - assert w3_name not in new_weights - new_weights[w1_name] = w1_param - new_weights[w3_name] = w3_param - elif n.endswith('.block_sparse_moe.output_linear.weight'): - for e in range(p.size(0)): - w2_name = n.replace( - '.block_sparse_moe.output_linear.weight', - f".block_sparse_moe.experts.{e}.w2.weight") - w2_param = p[e] - assert w2_name not in new_weights - new_weights[w2_name] = w2_param - elif n.endswith('.block_sparse_moe.router.layer.weight'): - gate_name = n.replace('.block_sparse_moe.router.layer.weight', - ".block_sparse_moe.gate.weight") - assert gate_name not in new_weights - new_weights[gate_name] = p - elif n == 'lm_head.weight' and self.config.tie_word_embeddings: - pass - else: - new_weights[n] = p - return mixtral.MixtralForCausalLM.load_weights(self, - new_weights.items()) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index f2e82017f653..ef96257ba4bb 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -302,6 +302,8 @@ class Grok1Model(nn.Module): quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + self.config = config + self.quant_config = quant_config self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -370,6 +372,105 @@ class Grok1Model(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Map Grok1's unique expert parameter names to standard names + # Grok1 uses "num_experts" in its config + num_experts = getattr(self.config, "num_experts", 8) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", # Grok1 specific + ckpt_down_proj_name="linear_1", # Grok1 specific + ckpt_up_proj_name="linear_v", # Grok1 specific + num_experts=num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + # Handle Grok1-specific norm.scale naming + if "norm.scale" in name: + name = name.replace("scale", "weight") + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False @@ -460,106 +561,10 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] + skip_prefixes = ["rotary_emb.inv_freq"] + # Skip lm_head when tie_word_embeddings is True + if self.config.tie_word_embeddings: + skip_prefixes.append("lm_head") - # Map Grok1's unique expert parameter names to standard names - # Grok1 uses "num_experts" in its config - num_experts = getattr(self.config, "num_experts", 8) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="linear", # Grok1 specific - ckpt_down_proj_name="linear_1", # Grok1 specific - ckpt_up_proj_name="linear_v", # Grok1 specific - num_experts=num_experts) - - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. 
- if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - # Handle Grok1-specific norm.scale naming - if "norm.scale" in name: - name = name.replace("scale", "weight") - - # Skip lm_head when tie_word_embeddings is True - if "lm_head" in name and self.config.tie_word_embeddings: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 6bdb623593a7..b0ac99f21ead 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -260,6 +260,8 @@ class MixtralModel(nn.Module): quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + self.config = config + self.quant_config = quant_config lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 self.vocab_size = config.vocab_size + lora_vocab @@ -313,6 +315,98 @@ class MixtralModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + + params_dict = dict(self.named_parameters()) 
+ loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False @@ -397,95 +491,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="w1", - ckpt_down_proj_name="w2", - ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) - - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"]) + return loader.load_weights(weights)
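
Note on the pattern these hunks converge on: in every file the per-architecture weight-loading loop moves into the inner *Model class, and the *ForCausalLM class now only decides which checkpoint prefixes to skip before handing the weight stream to AutoWeightsLoader. The snippet below is a minimal, self-contained sketch of that skip-and-delegate behaviour; the SimpleAutoWeightsLoader class, its copy_-based loading, and the toy module are illustrative assumptions, not vLLM's actual AutoWeightsLoader (which also walks submodules recursively and dispatches to their own load_weights).

    # Simplified sketch of the skip-and-delegate loading pattern (assumed names).
    from typing import Iterable, List, Optional, Set, Tuple

    import torch
    import torch.nn as nn


    class SimpleAutoWeightsLoader:
        """Filter checkpoint entries by prefix, then load the rest directly.

        The real AutoWeightsLoader additionally recurses into child modules
        and honours their own load_weights; this sketch keeps only the
        prefix-skipping behaviour that the refactored classes rely on.
        """

        def __init__(self, module: nn.Module,
                     skip_prefixes: Optional[List[str]] = None):
            self.module = module
            self.skip_prefixes = skip_prefixes or []

        def load_weights(
                self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
            params = dict(self.module.named_parameters())
            loaded: Set[str] = set()
            for name, tensor in weights:
                if any(name.startswith(p) for p in self.skip_prefixes):
                    continue
                if name in params:
                    # Assumes matching shapes; real loaders go through
                    # per-parameter weight_loader callbacks instead.
                    params[name].data.copy_(tensor)
                    loaded.add(name)
            return loaded


    # Toy usage mirroring the refactored GraniteForCausalLM.load_weights.
    class _ToyLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(4, 4)
            self.lm_head = nn.Linear(4, 4, bias=False)


    toy = _ToyLM()
    ckpt = [("embed.weight", torch.zeros(4, 4)),
            ("lm_head.weight", torch.zeros(4, 4)),
            ("rotary_emb.inv_freq", torch.zeros(2))]
    loader = SimpleAutoWeightsLoader(
        toy, skip_prefixes=["rotary_emb.inv_freq", "lm_head."])
    print(loader.load_weights(ckpt))  # {'embed.weight'}: lm_head and rope skipped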
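
The GraniteMoe/GraniteMoeShared hunks keep the checkpoint-specific remapping inside the Model class: the fused block_sparse_moe.input_linear tensor is split per expert into w1/w3 via chunk(2, dim=0), output_linear becomes per-expert w2, and router.layer becomes the gate, after which MixtralModel.load_weights handles the renamed entries. A toy shape check of that split is sketched below; the tensor dimensions (and the orientation of output_linear) are made-up assumptions for illustration only.

    # Toy shape check for the fused-expert split used in GraniteMoeModel.load_weights.
    import torch

    num_experts, intermediate, hidden = 4, 8, 16
    input_linear = torch.randn(num_experts, 2 * intermediate, hidden)   # fused w1|w3
    output_linear = torch.randn(num_experts, hidden, intermediate)      # per-expert w2

    for e in range(input_linear.size(0)):
        w1, w3 = input_linear[e].chunk(2, dim=0)   # each (intermediate, hidden)
        w2 = output_linear[e]                      # (hidden, intermediate)
        assert w1.shape == (intermediate, hidden)
        assert w3.shape == (intermediate, hidden)
        assert w2.shape == (hidden, intermediate)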