mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:35:50 +08:00
Fix the KeyError when loading bloom-based models (#441)
This commit is contained in:
parent
7b6ae94059
commit
dbed69058c
@ -284,10 +284,17 @@ class BloomForCausalLM(nn.Module):
|
|||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, use_np_cache):
|
||||||
if not name.startswith("transformer."):
|
if name == "lm_head.weight":
|
||||||
name = "transformer." + name
|
# Since hidden_states are parallelized, we need to
|
||||||
|
# load lm_head.weight in parallel.
|
||||||
|
self._column_parallel_weights.append(name)
|
||||||
|
# If lm_head is provided, use it instead.
|
||||||
|
param = self.lm_head_weight
|
||||||
|
else:
|
||||||
|
if not name.startswith("transformer."):
|
||||||
|
name = "transformer." + name
|
||||||
|
param = state_dict[name]
|
||||||
|
|
||||||
param = state_dict[name]
|
|
||||||
if "query_key_value" in name:
|
if "query_key_value" in name:
|
||||||
# NOTE(woosuk): BLOOM's fused QKV has the shape of
|
# NOTE(woosuk): BLOOM's fused QKV has the shape of
|
||||||
# [num_heads * 3 * head_size, hidden_size], while the
|
# [num_heads * 3 * head_size, hidden_size], while the
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user