mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:05:01 +08:00
Fix the KeyError when loading bloom-based models (#441)
This commit is contained in:
parent
7b6ae94059
commit
dbed69058c
@ -284,10 +284,17 @@ class BloomForCausalLM(nn.Module):
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
if name == "lm_head.weight":
|
||||
# Since hidden_states are parallelized, we need to
|
||||
# load lm_head.weight in parallel.
|
||||
self._column_parallel_weights.append(name)
|
||||
# If lm_head is provided, use it instead.
|
||||
param = self.lm_head_weight
|
||||
else:
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
param = state_dict[name]
|
||||
|
||||
param = state_dict[name]
|
||||
if "query_key_value" in name:
|
||||
# NOTE(woosuk): BLOOM's fused QKV has the shape of
|
||||
# [num_heads * 3 * head_size, hidden_size], while the
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user