Normalize head weights for Baichuan 2 (#1876)
This commit is contained in:
parent d06980dfa7
commit e5452ddfd6
@@ -47,7 +47,7 @@ vLLM is flexible and easy to use with:
 vLLM seamlessly supports many Hugging Face models, including the following architectures:
 
 - Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
-- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
+- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
 - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
@@ -19,7 +19,7 @@ Alongside each architecture, we include some popular models that use it.
     - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
   * - :code:`BaiChuanForCausalLM`
     - Baichuan
-    - :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
+    - :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc.
   * - :code:`ChatGLMModel`
     - ChatGLM
     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
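Since the table above only lists architectures and example checkpoints, a brief usage sketch may help show how any of these models is loaded through vLLM's `LLM` entry point. The prompt string and sampling settings below are arbitrary illustrations; `trust_remote_code=True` is passed because the Baichuan checkpoints ship custom modeling/tokenizer code on the Hub:

from vllm import LLM, SamplingParams

# Load a supported architecture by its Hugging Face model ID.
# trust_remote_code=True lets the Hub's custom Baichuan code run.
llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat", trust_remote_code=True)

sampling_params = SamplingParams(temperature=0.8, max_tokens=128)
outputs = llm.generate(["What is the capital of France?"], sampling_params)

for output in outputs:
    # Each RequestOutput carries the prompt and its generated completions.
    print(output.prompt, output.outputs[0].text)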
@@ -341,6 +341,17 @@ class BaiChuanBaseForCausalLM(nn.Module):
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
+            if name == "lm_head.weight":
+                # Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to:
+                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+                # Distinguish between Baichuan and Baichuan2 by checking the
+                # vocab size. This is suggested by
+                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
+                is_baichuan2 = self.config.vocab_size == 125696
+                if is_baichuan2:
+                    loaded_weight = torch.nn.functional.normalize(
+                        loaded_weight)
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
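For context on what the new branch does: `torch.nn.functional.normalize` defaults to `p=2, dim=1`, so each row of `lm_head.weight` (one row per vocabulary token) is rescaled to unit L2 norm, mirroring the head-weight normalization in Baichuan2's reference modeling code linked above. A self-contained sketch with illustrative shapes (not the real checkpoint dimensions):

import torch
import torch.nn.functional as F

# Illustrative stand-in for a Baichuan2 lm_head weight matrix of shape
# (vocab_size, hidden_size); the real Baichuan2 vocab_size is 125696.
lm_head_weight = torch.randn(8, 4)

# F.normalize defaults to p=2, dim=1: each row is divided by its L2 norm.
normalized = F.normalize(lm_head_weight)

# Every row now has unit L2 norm (up to floating-point error).
print(normalized.norm(dim=1))  # values all ~1.0

# Equivalent manual computation; clamp_min matches F.normalize's eps=1e-12.
manual = lm_head_weight / lm_head_weight.norm(dim=1, keepdim=True).clamp_min(1e-12)
assert torch.allclose(normalized, manual)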