From a5fff3bd49a5ea888cf0dbdfe7ecf140455fa8d4 Mon Sep 17 00:00:00 2001 From: Raghav Ravishankar <113712354+alyosha-swamy@users.noreply.github.com> Date: Mon, 4 Aug 2025 16:39:56 +0530 Subject: [PATCH] Fix Arcee model weight loading: Add custom load_weights (#21725) Signed-off-by: alyosha-swamy --- tests/models/registry.py | 3 +- vllm/model_executor/models/arcee.py | 83 +++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index ffa6b755adf4..d86bd20fb0e3 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -139,8 +139,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True), - "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base", - is_available_online=False), + "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"), "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", trust_remote_code=True), "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 4e3ba107ba7e..4cf73e2e0ea5 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -24,10 +24,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -260,6 +262,81 @@ class ArceeModel(nn.Module): return hidden_states, aux_hidden_states return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + """Load weights, mapping q/k/v projections to fused qkv_proj.""" + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + if "scale" in name: + remapped_name = maybe_remap_kv_scale_name(name, params_dict) + if remapped_name is None: + continue + name = remapped_name + + mapped = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + + if name.endswith(".bias") and name not in params_dict: + mapped = True + break + + if is_pp_missing_parameter(name, self): + mapped = True + break + + param = params_dict[name] + weight_loader = param.weight_loader # type: ignore[attr-defined] + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + mapped = True + break + + if mapped: + continue + + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): """Arcee Model for causal language modeling, integrated with vLLM @@ -304,8 +381,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): else: # Placeholder for lm_head on non-last ranks self.lm_head = PPMissingLayer() - # Provide a reference to the model's method for generating empty - # tensors (used in pipeline parallel schedule) + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -316,7 +392,6 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, IntermediateTensors]: - # Forward pass through the Arcee model backbone model_output = self.model(input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors,