mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 13:05:44 +08:00
Fix Baichuan2-7B-Chat (#1987)
This commit is contained in:
parent
6ccc0bfffb
commit
2b981012a6
@@ -366,12 +366,16 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
|
class BaichuanForCausalLM(BaiChuanBaseForCausalLM
|
||||||
|
): # baichuan 13b, baichuan2 13b, baichuan2 7b
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config,
|
config,
|
||||||
linear_method: Optional[LinearMethodBase] = None):
|
linear_method: Optional[LinearMethodBase] = None):
|
||||||
super().__init__(config, "ALIBI", linear_method)
|
if config.hidden_size == 4096: # baichuan2 7b
|
||||||
|
super().__init__(config, "ROPE", linear_method)
|
||||||
|
else: # baichuan 13b, baichuan2 13b
|
||||||
|
super().__init__(config, "ALIBI", linear_method)
|
||||||
|
|
||||||
|
|
||||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
|
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user