Fix the KeyError when loading bloom-based models (#441)

Wen Sun authored on 2023-07-14 12:58:09 +08:00, committed by GitHub
parent 7b6ae94059
commit dbed69058c


@@ -284,10 +284,17 @@ class BloomForCausalLM(nn.Module):
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if name == "lm_head.weight":
                # Since hidden_states are parallelized, we need to
                # load lm_head.weight in parallel.
                self._column_parallel_weights.append(name)
                # If lm_head is provided, use it instead.
                param = self.lm_head_weight
            else:
                if not name.startswith("transformer."):
                    name = "transformer." + name
                param = state_dict[name]
            if "query_key_value" in name:
                # NOTE(woosuk): BLOOM's fused QKV has the shape of
                # [num_heads * 3 * head_size, hidden_size], while the
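
For context, a minimal standalone sketch of the name-handling idea this commit adds; the TinyTransformer/TinyBloomLike classes and the toy checkpoint dict are hypothetical stand-ins for illustration, not vLLM code. Bloom-based checkpoints can ship weight names without the "transformer." prefix and can include "lm_head.weight", which is tied to the word embeddings; prefixing that name and looking it up in state_dict is presumably what raised the KeyError, so the fix routes it to the tied weight instead.

# Hypothetical, self-contained illustration (not the vLLM implementation).
import torch
import torch.nn as nn


class TinyTransformer(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)


class TinyBloomLike(nn.Module):
    """Toy stand-in: the LM head is tied to the word embeddings, so there is
    no "lm_head.weight" (or "transformer.lm_head.weight") key in state_dict."""

    def __init__(self, vocab_size: int = 8, hidden_size: int = 4):
        super().__init__()
        self.transformer = TinyTransformer(vocab_size, hidden_size)
        self.lm_head_weight = self.transformer.word_embeddings.weight

    def load_weights(self, checkpoint: dict):
        state_dict = self.state_dict()
        for name, loaded_weight in checkpoint.items():
            if name == "lm_head.weight":
                # Route to the tied embedding weight instead of doing a
                # state_dict lookup, which would raise a KeyError.
                param = self.lm_head_weight
            else:
                if not name.startswith("transformer."):
                    name = "transformer." + name
                param = state_dict[name]
            param.data.copy_(loaded_weight)


# A bloom-based checkpoint may use un-prefixed names and include lm_head.weight.
ckpt = {
    "word_embeddings.weight": torch.randn(8, 4),
    "lm_head.weight": torch.randn(8, 4),
}
TinyBloomLike().load_weights(ckpt)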