[Bugfix] Weight loading fix for OPT model (#9042)

Co-authored-by: dvres <dvres@fri.uni-lj.si>
This commit is contained in:
Domen Vreš 2024-10-04 01:53:29 +02:00 committed by GitHub
parent 91add85ec4
commit 2838d6b38e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -353,7 +353,7 @@ class OPTForCausalLM(nn.Module):
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name:
if "lm_head.weight" in name and self.config.tie_word_embeddings:
continue
if name.startswith("decoder."):
name = "model." + name