mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 16:25:01 +08:00
Fix Arcee model weight loading: Add custom load_weights (#21725)
Signed-off-by: alyosha-swamy <raghav@arcee.ai>
This commit is contained in:
parent
1539ced93a
commit
a5fff3bd49
@ -139,8 +139,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
|
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base",
|
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"),
|
||||||
is_available_online=False),
|
|
||||||
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
|
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
|
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
|
||||||
|
|||||||
@ -24,10 +24,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
from .utils import (AutoWeightsLoader, PPMissingLayer,
|
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers)
|
make_empty_intermediate_tensors_factory, make_layers)
|
||||||
|
|
||||||
|
|
||||||
@ -260,6 +262,81 @@ class ArceeModel(nn.Module):
|
|||||||
return hidden_states, aux_hidden_states
|
return hidden_states, aux_hidden_states
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
"""Load weights, mapping q/k/v projections to fused qkv_proj."""
|
||||||
|
stacked_params_mapping = [
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: set[str] = set()
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if ("rotary_emb.cos_cached" in name
|
||||||
|
or "rotary_emb.sin_cached" in name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (self.quant_config is not None and
|
||||||
|
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||||
|
param = params_dict[scale_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||||
|
loaded_weight[0])
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(scale_name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "scale" in name:
|
||||||
|
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if remapped_name is None:
|
||||||
|
continue
|
||||||
|
name = remapped_name
|
||||||
|
|
||||||
|
mapped = False
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
mapped = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
mapped = True
|
||||||
|
break
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader # type: ignore[attr-defined]
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
loaded_params.add(name)
|
||||||
|
mapped = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if mapped:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
"""Arcee Model for causal language modeling, integrated with vLLM
|
"""Arcee Model for causal language modeling, integrated with vLLM
|
||||||
@ -304,8 +381,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
else:
|
else:
|
||||||
# Placeholder for lm_head on non-last ranks
|
# Placeholder for lm_head on non-last ranks
|
||||||
self.lm_head = PPMissingLayer()
|
self.lm_head = PPMissingLayer()
|
||||||
# Provide a reference to the model's method for generating empty
|
|
||||||
# tensors (used in pipeline parallel schedule)
|
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.model.make_empty_intermediate_tensors)
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
@ -316,7 +392,6 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None
|
inputs_embeds: Optional[torch.Tensor] = None
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
# Forward pass through the Arcee model backbone
|
|
||||||
model_output = self.model(input_ids=input_ids,
|
model_output = self.model(input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user