Fix Arcee model weight loading: Add custom load_weights (#21725)

Signed-off-by: alyosha-swamy <raghav@arcee.ai>
This commit is contained in:
Raghav Ravishankar 2025-08-04 16:39:56 +05:30 committed by GitHub
parent 1539ced93a
commit a5fff3bd49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 80 additions and 6 deletions

View File

@ -139,8 +139,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
trust_remote_code=True),
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base",
is_available_online=False),
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",

View File

@ -24,10 +24,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer,
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
@ -260,6 +262,81 @@ class ArceeModel(nn.Module):
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
"""Load weights, mapping q/k/v projections to fused qkv_proj."""
stacked_params_mapping = [
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is None:
continue
name = remapped_name
mapped = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name.endswith(".bias") and name not in params_dict:
mapped = True
break
if is_pp_missing_parameter(name, self):
mapped = True
break
param = params_dict[name]
weight_loader = param.weight_loader # type: ignore[attr-defined]
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
mapped = True
break
if mapped:
continue
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"""Arcee Model for causal language modeling, integrated with vLLM
@ -304,8 +381,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else:
# Placeholder for lm_head on non-last ranks
self.lm_head = PPMissingLayer()
# Provide a reference to the model's method for generating empty
# tensors (used in pipeline parallel schedule)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@ -316,7 +392,6 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
# Forward pass through the Arcee model backbone
model_output = self.model(input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,