From dbed69058c88ddf42914e6ab3a9b6ea12e15b12a Mon Sep 17 00:00:00 2001
From: Wen Sun <35923278+HermitSun@users.noreply.github.com>
Date: Fri, 14 Jul 2023 12:58:09 +0800
Subject: [PATCH] Fix the `KeyError` when loading bloom-based models (#441)

---
 vllm/model_executor/models/bloom.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index ffc47d01cb0e..4a3de8d46979 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -284,10 +284,17 @@ class BloomForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, use_np_cache):
-            if not name.startswith("transformer."):
-                name = "transformer." + name
+            if name == "lm_head.weight":
+                # Since hidden_states are parallelized, we need to
+                # load lm_head.weight in parallel.
+                self._column_parallel_weights.append(name)
+                # If lm_head is provided, use it instead.
+                param = self.lm_head_weight
+            else:
+                if not name.startswith("transformer."):
+                    name = "transformer." + name
+                param = state_dict[name]
 
-            param = state_dict[name]
             if "query_key_value" in name:
                 # NOTE(woosuk): BLOOM's fused QKV has the shape of
                 # [num_heads * 3 * head_size, hidden_size], while the