[Bugfix] Fix KeyError on loading GPT-NeoX (#3925)

This commit is contained in:
Junichi Sato 2024-04-10 04:11:31 +09:00 committed by GitHub
parent e7c7067b45
commit e23a43aef8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -274,6 +274,11 @@ class GPTNeoXForCausalLM(nn.Module):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
param = params_dict[name]
if "query_key_value" in name: