expand coverage of gpt2 model loading (#271)

twaka 2023-06-27 22:27:41 +09:00 committed by GitHub
parent 43710e8d09
commit 4026a049d3

@@ -228,11 +228,13 @@ class GPT2LMHeadModel(nn.Module):
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue
-            if ".attn.bias" in name:
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
                 # Skip attention mask.
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
-            name = "transformer." + name
+            if not name.startswith("transformer."):
+                name = "transformer." + name
             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
             # Because of this, we need to transpose the weights.
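
For context: Hugging Face GPT-2 checkpoints register the causal mask as non-weight buffers named "attn.bias" and "attn.masked_bias", and checkpoints saved from GPT2LMHeadModel typically already carry the "transformer." prefix on parameter names, while bare GPT2Model checkpoints do not. The hunk above skips both mask buffers and adds the prefix only when it is missing, so both checkpoint layouts load. Below is a minimal standalone sketch of that name-handling logic, assuming a plain dict of HF GPT-2 weights; the normalize_gpt2_names helper and hf_state_dict argument are illustrative, not vLLM's actual loader API:

def normalize_gpt2_names(hf_state_dict):
    """Yield (name, tensor) pairs with GPT-2 checkpoint names normalized."""
    for name, tensor in hf_state_dict.items():
        if "lm_head.weight" in name:
            # GPT-2 ties lm_head to the token embedding, so the duplicate
            # weight is skipped.
            continue
        if ".attn.bias" in name or ".attn.masked_bias" in name:
            # Causal-mask buffers, not learned weights. The leading dot in
            # the pattern keeps the real weight "c_attn.bias" from matching.
            continue
        if not name.startswith("transformer."):
            # Add the prefix only when it is missing, so names that already
            # have it are not doubled to "transformer.transformer.".
            name = "transformer." + name
        yield name, tensor

Checking startswith rather than unconditionally prepending is what expands coverage here: a checkpoint whose keys already begin with "transformer." would previously end up with a doubled prefix and fail to match any model parameter.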