diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 3bd3c6fb1898..d900b0f9bcfb 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -235,17 +235,28 @@ class GPTBigCodeForCausalLM(nn.Module):
                 continue
             param = state_dict[name]
-
+
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+
+            if name == "transformer.wte.weight":
+                # Consider padding in the vocab size.
+                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+
             def _expand_mqa_mha(qkv_array, n_head, head_dim):
                 """manipulates along axis=0 from MQA to MHA
                 inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
                     with n_heads for q, then 1 for k, 1 for v, times head dim
                 return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
-
+
+                TODO: this function is no longer needed once vllm supports MQA.
                 """
                 qkv_array = qkv_array.numpy()
-
+
                 dims_q = n_head * head_dim
                 q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
                 # q is fine, but k & v have not replicated shape along the first axis
@@ -285,6 +296,7 @@ class GPTBigCodeForCausalLM(nn.Module):
                     loaded_weight = loaded_weight.reshape(-1)
                 else:
                     raise ValueError(f"Unexpected parameter name {name}")
+
                 load_tensor_parallel_weights(param, loaded_weight, name,
                                              self._column_parallel_weights,
                                              self._row_parallel_weights,
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index cef67fde587f..a7a002435da4 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -98,7 +98,9 @@ def load_tensor_parallel_weights(
                 shard_size * tensor_model_parallel_rank
                 :shard_size * (tensor_model_parallel_rank + 1)]
             break
-    assert param.shape == loaded_weight.shape
+    assert param.shape == loaded_weight.shape, (
+        f"{param_name} shape mismatch between model and checkpoint: "
+        f"{param.shape} != {loaded_weight.shape}")
     param.data.copy_(loaded_weight)
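
Reviewer note: below is a minimal standalone sketch of the vocab-padding step added to transformer.wte.weight handling above. All sizes and names here (vocab_size, world_size, hidden_dim, the 50257/50304 values) are assumed toy values, not taken from the diff. The point is that each tensor-parallel rank holds an equal slice of the embedding table, so the checkpoint's rows must be padded up to param.shape[0] * world_size before the shards line up.

import torch

# Assumed toy sizes -- purely illustrative.
vocab_size = 50257            # rows present in the checkpoint
world_size = 8                # tensor-parallel ranks
padded_vocab_size = 50304     # model-side vocab, a multiple of world_size
hidden_dim = 16

# param is one rank's shard of the embedding table.
param = torch.empty(padded_vocab_size // world_size, hidden_dim)
loaded_weight = torch.randn(vocab_size, hidden_dim)

# Mirrors the new code: infer the padded size from the shard shape and
# append uninitialized rows so the full table divides evenly across ranks.
num_extra_rows = param.shape[0] * world_size - vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)   # match dtype and device
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

assert loaded_weight.shape[0] % world_size == 0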
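
Similarly, a sketch of what _expand_mqa_mha's docstring describes: replicating the single shared k and v heads so the packed qkv weight takes the MHA layout. The toy shapes are assumptions, and np.tile is one plausible replication strategy; the hunk only shows the function's header, so the actual body may differ.

import numpy as np

# Assumed toy shapes -- purely illustrative.
n_head, head_dim, hidden_dim = 4, 2, 8

# MQA layout: n_head query heads, then one shared k head and one shared
# v head, stacked along axis 0.
qkv_mqa = np.random.randn((n_head + 2) * head_dim, hidden_dim)

dims_q = n_head * head_dim
q, k, v = np.split(qkv_mqa, (dims_q, dims_q + head_dim), axis=0)

# Replicate the shared k/v head once per query head.
k = np.tile(k, (n_head, 1))
v = np.tile(v, (n_head, 1))
qkv_mha = np.concatenate([q, k, v], axis=0)

assert qkv_mha.shape == (3 * n_head * head_dim, hidden_dim)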