[Bugfix] Fix TeleChat2ForCausalLM weights mapper (#11546)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2024-12-27 18:39:15 +08:00 committed by GitHub
parent d003f3ea39
commit 2c9b8ea2b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -31,19 +31,6 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
# NOTE(review): diff fragment (indentation and +/- markers stripped by the
# commit viewer). Per the hunk header this span is the REMOVED side: the
# weights mapper is being deleted from TeleChat2Model and relocated to
# TeleChat2ForCausalLM in the second hunk below.
class TeleChat2Model(LlamaModel):
# Translates HuggingFace TeleChat2 checkpoint parameter names into the
# names used by vLLM's Llama implementation: the "transformer." prefix is
# rewritten to "model.", and TeleChat-specific submodule names (h,
# self_attention, word_embeddings, dense, ln_f) are mapped to their Llama
# equivalents.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"transformer.": "model.",
},
orig_to_new_substr={
".h.": ".layers.",
".self_attention.": ".self_attn.",
".word_embeddings.": ".embed_tokens.",
".dense.": ".o_proj.",
".ln_f.": ".norm.",
},
)
# Constructor body is truncated by the hunk boundary; only the first
# statement (forcing bias=True on the HF config before LlamaModel init)
# is visible here.
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# 1. Initialize the LlamaModel with bias
vllm_config.model_config.hf_config.bias = True
@@ -118,6 +105,19 @@ class TeleChat2Model(LlamaModel):
# NOTE(review): diff fragment — this is the ADDED side of the fix: the
# hf_to_vllm_mapper now lives on the ForCausalLM wrapper, presumably so the
# name translation is applied where the top-level weight loading happens
# rather than on the inner model — TODO confirm against vLLM's
# AutoWeightsLoader call site.
class TeleChat2ForCausalLM(LlamaForCausalLM):
# Same mapping as before: rewrite the HF "transformer." prefix to "model."
# and translate TeleChat submodule names (h, self_attention,
# word_embeddings, dense, ln_f) to the Llama-style names vLLM expects.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"transformer.": "model.",
},
orig_to_new_substr={
".h.": ".layers.",
".self_attention.": ".self_attn.",
".word_embeddings.": ".embed_tokens.",
".dense.": ".o_proj.",
".ln_f.": ".norm.",
},
)
# Hook used by LlamaForCausalLM to build the backbone; returns the
# TeleChat2-specific model instead of the stock LlamaModel.
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)