Fix Arcee model weight loading: Add custom load_weights (#21725)
Signed-off-by: alyosha-swamy <raghav@arcee.ai>
parent 1539ced93a
commit a5fff3bd49
```diff
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -139,8 +139,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                          trust_remote_code=True),
     "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
                                          trust_remote_code=True),
-    "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base",
-                                        is_available_online=False),
+    "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"),
     "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
                                          trust_remote_code=True),
     "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
```
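Dropping `is_available_online=False` marks the `arcee-ai/AFM-4.5B-Base` checkpoint as fetchable from the Hugging Face Hub, so the example-model tests exercise it for real. A minimal sketch of the effect, assuming `_HfExamplesInfo` defaults `is_available_online` to `True` (the import path below points at vLLM's test registry and is part of that assumption):

```python
# Sketch: the registry entry now advertises a downloadable checkpoint.
# Assumes _HfExamplesInfo defaults is_available_online to True.
from tests.models.registry import _HfExamplesInfo  # vLLM test helper

info = _HfExamplesInfo("arcee-ai/AFM-4.5B-Base")
assert info.is_available_online  # tests may now pull the real weights
```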
```diff
--- a/vllm/model_executor/models/arcee.py
+++ b/vllm/model_executor/models/arcee.py
@@ -24,10 +24,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, PPMissingLayer,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
```
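The two new imports back the custom loader added below: `maybe_remap_kv_scale_name` translates KV-cache quantization scale names from checkpoint conventions to vLLM's parameter names, and `is_pp_missing_parameter` reports whether a parameter lives on a different pipeline-parallel rank. A simplified stand-in for the remapping idea (the `.k_proj.k_scale` → `.attn.k_scale` rule below is illustrative, not vLLM's full remapping table):

```python
from typing import Optional

# Toy stand-in for maybe_remap_kv_scale_name: rewrite a checkpoint-style
# scale name to the model's parameter name, or return None when the model
# has no matching parameter (the caller then skips the tensor).
def remap_scale_name(name: str, params_dict: dict) -> Optional[str]:
    for src, dst in ((".k_proj.k_scale", ".attn.k_scale"),
                     (".v_proj.v_scale", ".attn.v_scale")):
        if name.endswith(src):
            candidate = name.replace(src, dst)
            return candidate if candidate in params_dict else None
    return name if name in params_dict else None

params = {"model.layers.0.self_attn.attn.k_scale": None}
assert remap_scale_name("model.layers.0.self_attn.k_proj.k_scale",
                        params) == "model.layers.0.self_attn.attn.k_scale"
```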
```diff
@@ -260,6 +262,81 @@ class ArceeModel(nn.Module):
             return hidden_states, aux_hidden_states
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        """Load weights, mapping q/k/v projections to fused qkv_proj."""
+        stacked_params_mapping = [
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                continue
+
+            if (self.quant_config is not None and
+                    (scale_name := self.quant_config.get_cache_scale(name))):
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            if "scale" in name:
+                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+                if remapped_name is None:
+                    continue
+                name = remapped_name
+
+            mapped = False
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+
+                if name.endswith(".bias") and name not in params_dict:
+                    mapped = True
+                    break
+
+                if is_pp_missing_parameter(name, self):
+                    mapped = True
+                    break
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader  # type: ignore[attr-defined]
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
+                mapped = True
+                break
+
+            if mapped:
+                continue
+
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+
+        return loaded_params
+
 
 class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     """Arcee Model for causal language modeling, integrated with vLLM
```
```diff
@@ -304,8 +381,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         else:
             # Placeholder for lm_head on non-last ranks
             self.lm_head = PPMissingLayer()
         # Provide a reference to the model's method for generating empty
         # tensors (used in pipeline parallel schedule)
-
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
```
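Only the last pipeline-parallel rank materializes `lm_head`; earlier ranks hold a `PPMissingLayer` placeholder so attribute access stays uniform and `load_weights` can skip those tensors via `is_pp_missing_parameter`. A toy illustration of the placeholder pattern (not vLLM's actual `PPMissingLayer`):

```python
import torch.nn as nn

# Toy placeholder: every rank exposes .lm_head, but only the last rank
# holds real weights; using the placeholder by mistake fails loudly.
class PlaceholderLayer(nn.Module):
    def forward(self, *args, **kwargs):
        raise RuntimeError("lm_head lives on the last pipeline rank")
```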
```diff
@@ -316,7 +392,6 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None
     ) -> Union[torch.Tensor, IntermediateTensors]:
         # Forward pass through the Arcee model backbone
         model_output = self.model(input_ids=input_ids,
                                   positions=positions,
                                   intermediate_tensors=intermediate_tensors,
```
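With the loader fixed and the checkpoint public, the model should run end to end. A minimal smoke test, assuming the `arcee-ai/AFM-4.5B-Base` weights are downloadable and fit on the local GPU:

```python
# Exercises the new load_weights path via the public vLLM API.
from vllm import LLM, SamplingParams

llm = LLM(model="arcee-ai/AFM-4.5B-Base")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```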